diff options
| author | Tim Foley <tfoley@nvidia.com> | 2017-06-09 11:34:21 -0700 |
|---|---|---|
| committer | Tim Foley <tfoley@nvidia.com> | 2017-06-09 13:44:59 -0700 |
| commit | fcf83dbf9effab3bd98bad2b83b2468b7eb05cfd (patch) | |
| tree | 41047c94883b86ec085a81597391ce3ef557cd43 /source/slang | |
| parent | 52e8d4b9a27ab0060f874c3a63ab531847be35c0 (diff) | |
Initial import of code.
Diffstat (limited to 'source/slang')
40 files changed, 28143 insertions, 0 deletions
diff --git a/source/slang/check.cpp b/source/slang/check.cpp new file mode 100644 index 000000000..7d50c5978 --- /dev/null +++ b/source/slang/check.cpp @@ -0,0 +1,4973 @@ +#include "syntax-visitors.h" + +#include "lookup.h" +#include "compiler.h" + +#include <assert.h> + +namespace Slang +{ + namespace Compiler + { + bool IsNumeric(BaseType t) + { + return t == BaseType::Int || t == BaseType::Float || t == BaseType::UInt; + } + + String TranslateHLSLTypeNames(String name) + { + if (name == "float2" || name == "half2") + return "vec2"; + else if (name == "float3" || name == "half3") + return "vec3"; + else if (name == "float4" || name == "half4") + return "vec4"; + else if (name == "half") + return "float"; + else if (name == "int2") + return "ivec2"; + else if (name == "int3") + return "ivec3"; + else if (name == "int4") + return "ivec4"; + else if (name == "uint2") + return "uvec2"; + else if (name == "uint3") + return "uvec3"; + else if (name == "uint4") + return "uvec4"; + else if (name == "float3x3" || name == "half3x3") + return "mat3"; + else if (name == "float4x4" || name == "half4x4") + return "mat4"; + else + return name; + } + + class SemanticsVisitor : public SyntaxVisitor + { + ProgramSyntaxNode * program = nullptr; + FunctionSyntaxNode * function = nullptr; + CompileOptions const* options = nullptr; + + // lexical outer statements + List<StatementSyntaxNode*> outerStmts; + public: + SemanticsVisitor( + DiagnosticSink * pErr, + CompileOptions const& options) + : SyntaxVisitor(pErr) + , options(&options) + { + } + + CompileOptions const& getOptions() { return *options; } + + public: + // Translate Types + RefPtr<ExpressionType> typeResult; + RefPtr<ExpressionSyntaxNode> TranslateTypeNodeImpl(const RefPtr<ExpressionSyntaxNode> & node) + { + if (!node) return nullptr; + auto expr = node->Accept(this).As<ExpressionSyntaxNode>(); + expr = ExpectATypeRepr(expr); + return expr; + } + RefPtr<ExpressionType> ExtractTypeFromTypeRepr(const RefPtr<ExpressionSyntaxNode>& typeRepr) + { + if (!typeRepr) return nullptr; + if (auto typeType = typeRepr->Type->As<TypeType>()) + { + return typeType->type; + } + return ExpressionType::Error; + } + RefPtr<ExpressionType> TranslateTypeNode(const RefPtr<ExpressionSyntaxNode> & node) + { + if (!node) return nullptr; + auto typeRepr = TranslateTypeNodeImpl(node); + return ExtractTypeFromTypeRepr(typeRepr); + } + TypeExp TranslateTypeNode(TypeExp const& typeExp) + { + // HACK(tfoley): It seems that in some cases we end up re-checking + // syntax that we've already checked. We need to root-cause that + // issue, but for now a quick fix in this case is to early + // exist if we've already got a type associated here: + if (typeExp.type) + { + return typeExp; + } + + + auto typeRepr = TranslateTypeNodeImpl(typeExp.exp); + + TypeExp result; + result.exp = typeRepr; + result.type = ExtractTypeFromTypeRepr(typeRepr); + return result; + } + + RefPtr<ExpressionSyntaxNode> ConstructDeclRefExpr( + DeclRef declRef, + RefPtr<ExpressionSyntaxNode> baseExpr, + RefPtr<ExpressionSyntaxNode> originalExpr) + { + if (baseExpr) + { + auto expr = new MemberExpressionSyntaxNode(); + expr->Position = originalExpr->Position; + expr->BaseExpression = baseExpr; + expr->MemberName = declRef.GetName(); + expr->Type = GetTypeForDeclRef(declRef); + expr->declRef = declRef; + return expr; + } + else + { + auto expr = new VarExpressionSyntaxNode(); + expr->Position = originalExpr->Position; + expr->Variable = declRef.GetName(); + expr->Type = GetTypeForDeclRef(declRef); + expr->declRef = declRef; + return expr; + } + } + + RefPtr<ExpressionSyntaxNode> ConstructDerefExpr( + RefPtr<ExpressionSyntaxNode> base, + RefPtr<ExpressionSyntaxNode> originalExpr) + { + auto ptrLikeType = base->Type->As<PointerLikeType>(); + assert(ptrLikeType); + + auto derefExpr = new DerefExpr(); + derefExpr->Position = originalExpr->Position; + derefExpr->base = base; + derefExpr->Type = ptrLikeType->elementType; + + // TODO(tfoley): handle l-value status here + + return derefExpr; + } + + RefPtr<ExpressionSyntaxNode> ConstructLookupResultExpr( + LookupResultItem const& item, + RefPtr<ExpressionSyntaxNode> baseExpr, + RefPtr<ExpressionSyntaxNode> originalExpr) + { + // If we collected any breadcrumbs, then these represent + // additional segments of the lookup path that we need + // to expand here. + auto bb = baseExpr; + for (auto breadcrumb = item.breadcrumbs; breadcrumb; breadcrumb = breadcrumb->next) + { + switch (breadcrumb->kind) + { + case LookupResultItem::Breadcrumb::Kind::Member: + bb = ConstructDeclRefExpr(breadcrumb->declRef, bb, originalExpr); + break; + case LookupResultItem::Breadcrumb::Kind::Deref: + bb = ConstructDerefExpr(bb, originalExpr); + break; + default: + SLANG_UNREACHABLE("all cases handle"); + } + } + + return ConstructDeclRefExpr(item.declRef, bb, originalExpr); + } + + RefPtr<ExpressionSyntaxNode> createLookupResultExpr( + LookupResult const& lookupResult, + RefPtr<ExpressionSyntaxNode> baseExpr, + RefPtr<ExpressionSyntaxNode> originalExpr) + { + if (lookupResult.isOverloaded()) + { + auto overloadedExpr = new OverloadedExpr(); + overloadedExpr->Position = originalExpr->Position; + overloadedExpr->Type = ExpressionType::Overloaded; + overloadedExpr->base = baseExpr; + overloadedExpr->lookupResult2 = lookupResult; + return overloadedExpr; + } + else + { + return ConstructLookupResultExpr(lookupResult.item, baseExpr, originalExpr); + } + } + + RefPtr<ExpressionSyntaxNode> ResolveOverloadedExpr(RefPtr<OverloadedExpr> overloadedExpr, LookupMask mask) + { + auto lookupResult = overloadedExpr->lookupResult2; + assert(lookupResult.isValid() && lookupResult.isOverloaded()); + + // Take the lookup result we had, and refine it based on what is expected in context. + lookupResult = refineLookup(lookupResult, mask); + + if (!lookupResult.isValid()) + { + // If we didn't find any symbols after filtering, then just + // use the original and report errors that way + return overloadedExpr; + } + + if (lookupResult.isOverloaded()) + { + // We had an ambiguity anyway, so report it. + getSink()->diagnose(overloadedExpr, Diagnostics::ambiguousReference, lookupResult.items[0].declRef.GetName()); + + for(auto item : lookupResult.items) + { + String declString = getDeclSignatureString(item); + getSink()->diagnose(item.declRef, Diagnostics::overloadCandidate, declString); + } + + // TODO(tfoley): should we construct a new ErrorExpr here? + overloadedExpr->Type = ExpressionType::Error; + return overloadedExpr; + } + + // otherwise, we had a single decl and it was valid, hooray! + return ConstructLookupResultExpr(lookupResult.item, overloadedExpr->base, overloadedExpr); + } + + RefPtr<ExpressionSyntaxNode> ExpectATypeRepr(RefPtr<ExpressionSyntaxNode> expr) + { + if (auto overloadedExpr = expr.As<OverloadedExpr>()) + { + expr = ResolveOverloadedExpr(overloadedExpr, LookupMask::Type); + } + + if (auto typeType = expr->Type.type->As<TypeType>()) + { + return expr; + } + else if (expr->Type.type->Equals(ExpressionType::Error)) + { + return expr; + } + + getSink()->diagnose(expr, Diagnostics::unimplemented, "expected a type"); + // TODO: construct some kind of `ErrorExpr`? + return expr; + } + + RefPtr<ExpressionType> ExpectAType(RefPtr<ExpressionSyntaxNode> expr) + { + auto typeRepr = ExpectATypeRepr(expr); + if (auto typeType = typeRepr->Type->As<TypeType>()) + { + return typeType->type; + } + return ExpressionType::Error; + } + + RefPtr<ExpressionType> ExtractGenericArgType(RefPtr<ExpressionSyntaxNode> exp) + { + return ExpectAType(exp); + } + + RefPtr<IntVal> ExtractGenericArgInteger(RefPtr<ExpressionSyntaxNode> exp) + { + return CheckIntegerConstantExpression(exp.Ptr()); + } + + RefPtr<Val> ExtractGenericArgVal(RefPtr<ExpressionSyntaxNode> exp) + { + if (auto overloadedExpr = exp.As<OverloadedExpr>()) + { + // assume that if it is overloaded, we want a type + exp = ResolveOverloadedExpr(overloadedExpr, LookupMask::Type); + } + + if (auto typeType = exp->Type->As<TypeType>()) + { + return typeType->type; + } + else if (exp->Type->Equals(ExpressionType::Error)) + { + return exp->Type.type; + } + else + { + return ExtractGenericArgInteger(exp); + } + } + + // Construct a type reprsenting the instantiation of + // the given generic declaration for the given arguments. + // The arguments should already be checked against + // the declaration. + RefPtr<ExpressionType> InstantiateGenericType( + GenericDeclRef genericDeclRef, + List<RefPtr<ExpressionSyntaxNode>> const& args) + { + RefPtr<Substitutions> subst = new Substitutions(); + subst->genericDecl = genericDeclRef.GetDecl(); + subst->outer = genericDeclRef.substitutions; + + for (auto argExpr : args) + { + subst->args.Add(ExtractGenericArgVal(argExpr)); + } + + DeclRef innerDeclRef; + innerDeclRef.decl = genericDeclRef.GetInner(); + innerDeclRef.substitutions = subst; + + return DeclRefType::Create(innerDeclRef); + } + + // Make sure a declaration has been checked, so we can refer to it. + // Note that this may lead to us recursively invoking checking, + // so this may not be the best way to handle things. + void EnsureDecl(RefPtr<Decl> decl, DeclCheckState state = DeclCheckState::CheckedHeader) + { + if (decl->IsChecked(state)) return; + if (decl->checkState == DeclCheckState::CheckingHeader) + { + // We tried to reference the same declaration while checking it! + throw "circularity"; + } + + if (DeclCheckState::CheckingHeader > decl->checkState) + { + decl->SetCheckState(DeclCheckState::CheckingHeader); + } + + // TODO: not all of the `Visit` cases are ready to + // handle this being called on-the-fly + decl->Accept(this); + + decl->SetCheckState(DeclCheckState::Checked); + } + + void EnusreAllDeclsRec(RefPtr<Decl> decl) + { + EnsureDecl(decl, DeclCheckState::Checked); + if (auto containerDecl = decl.As<ContainerDecl>()) + { + for (auto m : containerDecl->Members) + { + EnusreAllDeclsRec(m); + } + } + } + + // A "proper" type is one that can be used as the type of an expression. + // Put simply, it can be a concrete type like `int`, or a generic + // type that is applied to arguments, like `Texture2D<float4>`. + // The type `void` is also a proper type, since we can have expressions + // that return a `void` result (e.g., many function calls). + // + // A "non-proper" type is any type that can't actually have values. + // A simple example of this in C++ is `std::vector` - you can't have + // a value of this type. + // + // Part of what this function does is give errors if somebody tries + // to use a non-proper type as the type of a variable (or anything + // else that needs a proper type). + // + // The other thing it handles is the fact that HLSL lets you use + // the name of a non-proper type, and then have the compiler fill + // in the default values for its type arguments (e.g., a variable + // given type `Texture2D` will actually have type `Texture2D<float4>`). + bool CoerceToProperTypeImpl(TypeExp const& typeExp, RefPtr<ExpressionType>* outProperType) + { + ExpressionType* type = typeExp.type.Ptr(); + if (auto genericDeclRefType = type->As<GenericDeclRefType>()) + { + // We are using a reference to a generic declaration as a concrete + // type. This means we should substitute in any default parameter values + // if they are available. + // + // TODO(tfoley): A more expressive type system would substitute in + // "fresh" variables and then solve for their values... + // + + auto genericDeclRef = genericDeclRefType->GetDeclRef(); + EnsureDecl(genericDeclRef.decl); + List<RefPtr<ExpressionSyntaxNode>> args; + for (RefPtr<Decl> member : genericDeclRef.GetDecl()->Members) + { + if (auto typeParam = member.As<GenericTypeParamDecl>()) + { + if (!typeParam->initType.exp) + { + if (outProperType) + { + getSink()->diagnose(typeExp.exp.Ptr(), Diagnostics::unimplemented, "can't fill in default for generic type parameter"); + *outProperType = ExpressionType::Error; + } + return false; + } + + // TODO: this is one place where syntax should get cloned! + if(outProperType) + args.Add(typeParam->initType.exp); + } + else if (auto valParam = member.As<GenericValueParamDecl>()) + { + if (!valParam->Expr) + { + if (outProperType) + { + getSink()->diagnose(typeExp.exp.Ptr(), Diagnostics::unimplemented, "can't fill in default for generic type parameter"); + *outProperType = ExpressionType::Error; + } + return false; + } + + // TODO: this is one place where syntax should get cloned! + if(outProperType) + args.Add(valParam->Expr); + } + else + { + // ignore non-parameter members + } + } + + if (outProperType) + { + *outProperType = InstantiateGenericType(genericDeclRef, args); + } + return true; + } + else + { + // default case: we expect this to already be a proper type + if (outProperType) + { + *outProperType = type; + } + return true; + } + } + + + + TypeExp CoerceToProperType(TypeExp const& typeExp) + { + TypeExp result = typeExp; + CoerceToProperTypeImpl(typeExp, &result.type); + return result; + } + + bool CanCoerceToProperType(TypeExp const& typeExp) + { + return CoerceToProperTypeImpl(typeExp, nullptr); + } + + // Check a type, and coerce it to be proper + TypeExp CheckProperType(TypeExp typeExp) + { + return CoerceToProperType(TranslateTypeNode(typeExp)); + } + + // For our purposes, a "usable" type is one that can be + // used to declare a function parameter, variable, etc. + // These turn out to be all the proper types except + // `void`. + // + // TODO(tfoley): consider just allowing `void` as a + // simple example of a "unit" type, and get rid of + // this check. + TypeExp CoerceToUsableType(TypeExp const& typeExp) + { + TypeExp result = CoerceToProperType(typeExp); + ExpressionType* type = result.type.Ptr(); + if (auto basicType = type->As<BasicExpressionType>()) + { + // TODO: `void` shouldn't be a basic type, to make this easier to avoid + if (basicType->BaseType == BaseType::Void) + { + // TODO(tfoley): pick the right diagnostic message + getSink()->diagnose(result.exp.Ptr(), Diagnostics::invalidTypeVoid); + result.type = ExpressionType::Error; + return result; + } + } + return result; + } + + // Check a type, and coerce it to be usable + TypeExp CheckUsableType(TypeExp typeExp) + { + return CoerceToUsableType(TranslateTypeNode(typeExp)); + } + + RefPtr<ExpressionSyntaxNode> CheckTerm(RefPtr<ExpressionSyntaxNode> term) + { + if (!term) return nullptr; + return term->Accept(this).As<ExpressionSyntaxNode>(); + } + + RefPtr<ExpressionSyntaxNode> CreateErrorExpr(ExpressionSyntaxNode* expr) + { + expr->Type = ExpressionType::Error; + return expr; + } + + bool IsErrorExpr(RefPtr<ExpressionSyntaxNode> expr) + { + // TODO: we may want other cases here... + + if (expr->Type->Equals(ExpressionType::Error)) + return true; + + return false; + } + + // Capture the "base" expression in case this is a member reference + RefPtr<ExpressionSyntaxNode> GetBaseExpr(RefPtr<ExpressionSyntaxNode> expr) + { + if (auto memberExpr = expr.As<MemberExpressionSyntaxNode>()) + { + return memberExpr->BaseExpression; + } + else if(auto overloadedExpr = expr.As<OverloadedExpr>()) + { + return overloadedExpr->base; + } + return nullptr; + } + + public: + + typedef unsigned int ConversionCost; + enum : ConversionCost + { + // No conversion at all + kConversionCost_None = 0, + + // Conversion that is lossless and keeps the "kind" of the value the same + kConversionCost_RankPromotion = 100, + + // Conversions that are lossless, but change "kind" + kConversionCost_UnsignedToSignedPromotion = 200, + + // Conversion from signed->unsigned integer of same or greater size + kConversionCost_SignedToUnsignedConversion = 300, + + // Cost of converting an integer to a floating-point type + kConversionCost_IntegerToFloatConversion = 400, + + // Catch-all for conversions that should be discouraged + // (i.e., that really shouldn't be made implicitly) + // + // TODO: make these conversions not be allowed implicitly in "Slang mode" + kConversionCost_GeneralConversion = 900, + + // Additional conversion cost to add when promoting from a scalar to + // a vector (this will be added to the cost, if any, of converting + // the element type of the vector) + kConversionCost_ScalarToVector = 1, + }; + + enum BaseTypeConversionKind : uint8_t + { + kBaseTypeConversionKind_Signed, + kBaseTypeConversionKind_Unsigned, + kBaseTypeConversionKind_Float, + kBaseTypeConversionKind_Error, + }; + + enum BaseTypeConversionRank : uint8_t + { + kBaseTypeConversionRank_Bool, + kBaseTypeConversionRank_Int8, + kBaseTypeConversionRank_Int16, + kBaseTypeConversionRank_Int32, + kBaseTypeConversionRank_IntPtr, + kBaseTypeConversionRank_Int64, + kBaseTypeConversionRank_Error, + }; + + struct BaseTypeConversionInfo + { + BaseTypeConversionKind kind; + BaseTypeConversionRank rank; + }; + static BaseTypeConversionInfo GetBaseTypeConversionInfo(BaseType baseType) + { + switch (baseType) + { + #define CASE(TAG, KIND, RANK) \ + case BaseType::TAG: { BaseTypeConversionInfo info = {kBaseTypeConversionKind_##KIND, kBaseTypeConversionRank_##RANK}; return info; } break + + CASE(Bool, Unsigned, Bool); + CASE(Int, Signed, Int32); + CASE(UInt, Unsigned, Int32); + CASE(UInt64, Unsigned, Int64); + CASE(Float, Float, Int32); + CASE(Void, Error, Error); + + #undef CASE + + default: + break; + } + SLANG_UNREACHABLE("all cases handled"); + } + + bool ValuesAreEqual( + RefPtr<IntVal> left, + RefPtr<IntVal> right) + { + if(left == right) return true; + + if(auto leftConst = left.As<ConstantIntVal>()) + { + if(auto rightConst = right.As<ConstantIntVal>()) + { + return leftConst->value == rightConst->value; + } + } + + if(auto leftVar = left.As<GenericParamIntVal>()) + { + if(auto rightVar = right.As<GenericParamIntVal>()) + { + return leftVar->declRef.Equals(rightVar->declRef); + } + } + + return false; + } + + // Central engine for implementing implicit coercion logic + bool TryCoerceImpl( + RefPtr<ExpressionType> toType, // the target type for conversion + RefPtr<ExpressionSyntaxNode>* outToExpr, // (optional) a place to stuff the target expression + RefPtr<ExpressionType> fromType, // the source type for the conversion + RefPtr<ExpressionSyntaxNode> fromExpr, // the source expression + ConversionCost* outCost) // (optional) a place to stuff the conversion cost + { + // Easy case: the types are equal + if (toType->Equals(fromType)) + { + if (outToExpr) + *outToExpr = fromExpr; + if (outCost) + *outCost = kConversionCost_None; + return true; + } + + // If either type is an error, then let things pass. + if (toType->As<ErrorType>() || fromType->As<ErrorType>()) + { + if (outToExpr) + *outToExpr = CreateImplicitCastExpr(toType, fromExpr); + if (outCost) + *outCost = kConversionCost_None; + return true; + } + + // Coercion from an initializer list is allowed for many types + if( auto fromInitializerListExpr = fromExpr.As<InitializerListExpr>()) + { + auto argCount = fromInitializerListExpr->args.Count(); + + // In the case where we need to build a reuslt expression, + // we will collect the new arguments here + List<RefPtr<ExpressionSyntaxNode>> coercedArgs; + + if(auto toDeclRefType = toType->As<DeclRefType>()) + { + auto toTypeDeclRef = toDeclRefType->declRef; + if(auto toStructDeclRef = toTypeDeclRef.As<StructDeclRef>()) + { + // Trying to initialize a `struct` type given an initializer list. + // We will go through the fields in order and try to match them + // up with initializer arguments. + + + int argIndex = 0; + for(auto& fieldDeclRef : toStructDeclRef.GetMembersOfType<FieldDeclRef>()) + { + if(argIndex >= argCount) + { + // We've consumed all the arguments, so we should stop + break; + } + + auto arg = fromInitializerListExpr->args[argIndex++]; + + // + RefPtr<ExpressionSyntaxNode> coercedArg; + ConversionCost argCost; + + bool argResult = TryCoerceImpl( + fieldDeclRef.GetType(), + outToExpr ? &coercedArg : nullptr, + arg->Type, + arg, + outCost ? &argCost : nullptr); + + // No point in trying further if any argument fails + if(!argResult) + return false; + + // TODO(tfoley): what to do with cost? + // This only matters if/when we allow an initializer list as an argument to + // an overloaded call. + + if( outToExpr ) + { + coercedArgs.Add(coercedArg); + } + } + } + } + else if(auto toArrayType = toType->As<ArrayExpressionType>()) + { + // TODO(tfoley): If we can compute the size of the array statically, + // then we want to check that there aren't too many initializers present + + auto toElementType = toArrayType->BaseType; + + for(auto& arg : fromInitializerListExpr->args) + { + RefPtr<ExpressionSyntaxNode> coercedArg; + ConversionCost argCost; + + bool argResult = TryCoerceImpl( + toElementType, + outToExpr ? &coercedArg : nullptr, + arg->Type, + arg, + outCost ? &argCost : nullptr); + + // No point in trying further if any argument fails + if(!argResult) + return false; + + if( outToExpr ) + { + coercedArgs.Add(coercedArg); + } + } + } + else + { + // By default, we don't allow a type to be initialized using + // an initializer list. + return false; + } + + // For now, coercion from an initializer list has no cost + if(outCost) + { + *outCost = kConversionCost_None; + } + + // We were able to coerce all the arguments given, and so + // we need to construct a suitable expression to remember the result + if(outToExpr) + { + auto toInitializerListExpr = new InitializerListExpr(); + toInitializerListExpr->Position = fromInitializerListExpr->Position; + toInitializerListExpr->Type = toType; + toInitializerListExpr->args = coercedArgs; + + + *outToExpr = toInitializerListExpr; + } + + return true; + } + + // + + if (auto toBasicType = toType->AsBasicType()) + { + if (auto fromBasicType = fromType->AsBasicType()) + { + // Conversions between base types are always allowed, + // and the only question is what the cost will be. + + auto toInfo = GetBaseTypeConversionInfo(toBasicType->BaseType); + auto fromInfo = GetBaseTypeConversionInfo(fromBasicType->BaseType); + + // We expect identical types to have been dealt with already. + assert(toInfo.kind != fromInfo.kind || toInfo.rank != fromInfo.rank); + + if (outToExpr) + *outToExpr = CreateImplicitCastExpr(toType, fromExpr); + + + if (outCost) + { + // Conversions within the same kind are easist to handle + if (toInfo.kind == fromInfo.kind) + { + // If we are converting to a "larger" type, then + // we are doing a lossless promotion, and otherwise + // we are doing a demotion. + if( toInfo.rank > fromInfo.rank) + *outCost = kConversionCost_RankPromotion; + else + *outCost = kConversionCost_GeneralConversion; + } + // If we are converting from an unsigned integer type to + // a signed integer type that is guaranteed to be larger, + // then that is also a lossless promotion. + else if(toInfo.kind == kBaseTypeConversionKind_Signed + && fromInfo.kind == kBaseTypeConversionKind_Unsigned + && toInfo.rank > fromInfo.rank) + { + // TODO: probably need to weed out cases involving + // "pointer-sized" integers if these are treated + // as distinct from 32- and 64-bit types. + // E.g., there is no guarantee that conversion + // from 32-bit unsigned to pointer-sized signed + // is lossless, because pointers could be 32-bit, + // and the same applies for conversion from + // `uintptr` to `uint64`. + *outCost = kConversionCost_UnsignedToSignedPromotion; + } + // Conversion from signed to unsigned is always lossy, + // but it is preferred over conversions from unsigned + // to signed, for same-size types. + else if(toInfo.kind == kBaseTypeConversionKind_Unsigned + && fromInfo.kind == kBaseTypeConversionKind_Signed + && toInfo.rank >= fromInfo.rank) + { + *outCost = kConversionCost_SignedToUnsignedConversion; + } + // Conversion from an integer to a floating-point type + // is never considered a promotion (even when the value + // would fit in the available bits). + // If the destination type is at least 32 bits we consider + // this a reasonably good conversion, though. + else if (toInfo.kind == kBaseTypeConversionKind_Float + && toInfo.rank >= kBaseTypeConversionRank_Int32) + { + *outCost = kConversionCost_IntegerToFloatConversion; + } + // All other cases are considered as "general" conversions, + // where we don't consider any one conversion better than + // any others. + else + { + *outCost = kConversionCost_GeneralConversion; + } + } + + return true; + } + } + + if (auto toVectorType = toType->AsVectorType()) + { + if (auto fromVectorType = fromType->AsVectorType()) + { + // Conversion between vector types. + + // If element counts don't match, then bail: + if (!ValuesAreEqual(toVectorType->elementCount, fromVectorType->elementCount)) + return false; + + // Otherwise, if we can convert the element types, we are golden + ConversionCost elementCost; + if (CanCoerce(toVectorType->elementType, fromVectorType->elementType, &elementCost)) + { + if (outToExpr) + *outToExpr = CreateImplicitCastExpr(toType, fromExpr); + if (outCost) + *outCost = elementCost; + return true; + } + } + else if (auto fromScalarType = fromType->AsBasicType()) + { + // Conversion from scalar to vector. + // Should allow as long as we can coerce the scalar to our element type. + ConversionCost elementCost; + if (CanCoerce(toVectorType->elementType, fromScalarType, &elementCost)) + { + if (outToExpr) + *outToExpr = CreateImplicitCastExpr(toType, fromExpr); + if (outCost) + *outCost = elementCost + kConversionCost_ScalarToVector; + return true; + } + } + } + + // TODO: more cases! + + return false; + } + + // Check whether a type coercion is possible + bool CanCoerce( + RefPtr<ExpressionType> toType, // the target type for conversion + RefPtr<ExpressionType> fromType, // the source type for the conversion + ConversionCost* outCost = 0) // (optional) a place to stuff the conversion cost + { + return TryCoerceImpl( + toType, + nullptr, + fromType, + nullptr, + outCost); + } + + RefPtr<ExpressionSyntaxNode> CreateImplicitCastExpr( + RefPtr<ExpressionType> toType, + RefPtr<ExpressionSyntaxNode> fromExpr) + { + auto castExpr = new TypeCastExpressionSyntaxNode(); + castExpr->Position = fromExpr->Position; + castExpr->TargetType.type = toType; + castExpr->Type = toType; + castExpr->Expression = fromExpr; + return castExpr; + } + + + // Perform type coercion, and emit errors if it isn't possible + RefPtr<ExpressionSyntaxNode> Coerce( + RefPtr<ExpressionType> toType, + RefPtr<ExpressionSyntaxNode> fromExpr) + { + // If semantic checking is being suppressed, then we might see + // expressions without a type, and we need to ignore them. + if( !fromExpr->Type.type ) + { + if(getOptions().flags & SLANG_COMPILE_FLAG_NO_CHECKING ) + return fromExpr; + } + + RefPtr<ExpressionSyntaxNode> expr; + if (!TryCoerceImpl( + toType, + &expr, + fromExpr->Type.Ptr(), + fromExpr.Ptr(), + nullptr)) + { + if(!(getOptions().flags & SLANG_COMPILE_FLAG_NO_CHECKING)) + { + getSink()->diagnose(fromExpr->Position, Diagnostics::typeMismatch, toType, fromExpr->Type); + } + + // Note(tfoley): We don't call `CreateErrorExpr` here, because that would + // clobber the type on `fromExpr`, and an invariant here is that coercion + // really shouldn't *change* the expression that is passed in, but should + // introduce new AST nodes to coerce its value to a different type... + return CreateImplicitCastExpr(ExpressionType::Error, fromExpr); + } + return expr; + } + + void CheckVarDeclCommon(RefPtr<VarDeclBase> varDecl) + { + // Check the type, if one was given + TypeExp type = CheckUsableType(varDecl->Type); + + // TODO: Additional validation rules on types should go here, + // but we need to deal with the fact that some cases might be + // allowed in one context (e.g., an unsized array parameter) + // but not in othters (e.g., an unsized array field in a struct). + + // Check the initializers, if one was given + RefPtr<ExpressionSyntaxNode> initExpr = CheckTerm(varDecl->Expr); + + // If a type was given, ... + if (type.Ptr()) + { + // then coerce any initializer to the type + if (initExpr) + { + initExpr = Coerce(type, initExpr); + } + } + else + { + // TODO: infer a type from the initializers + if (!initExpr) + { + getSink()->diagnose(varDecl, Diagnostics::unimplemented, "variable declaration with no type must have initializer"); + } + else + { + getSink()->diagnose(varDecl, Diagnostics::unimplemented, "type inference for variable declaration"); + } + } + + varDecl->Type = type; + varDecl->Expr = initExpr; + } + + void CheckGenericConstraintDecl(GenericTypeConstraintDecl* decl) + { + // TODO: are there any other validations we can do at this point? + // + // There probably needs to be a kind of "occurs check" to make + // sure that the constraint actually applies to at least one + // of the parameters of the generic. + + decl->sub = TranslateTypeNode(decl->sub); + decl->sup = TranslateTypeNode(decl->sup); + } + + virtual RefPtr<GenericDecl> VisitGenericDecl(GenericDecl* genericDecl) override + { + // check the parameters + for (auto m : genericDecl->Members) + { + if (auto typeParam = m.As<GenericTypeParamDecl>()) + { + typeParam->initType = CheckProperType(typeParam->initType); + } + else if (auto valParam = m.As<GenericValueParamDecl>()) + { + // TODO: some real checking here... + CheckVarDeclCommon(valParam); + } + else if(auto constraint = m.As<GenericTypeConstraintDecl>()) + { + CheckGenericConstraintDecl(constraint.Ptr()); + } + } + + // check the nested declaration + // TODO: this needs to be done in an appropriate environment... + genericDecl->inner->Accept(this); + return genericDecl; + } + + virtual void VisitTraitConformanceDecl(TraitConformanceDecl* conformanceDecl) override + { + // check the type being conformed to + auto base = conformanceDecl->base; + base = TranslateTypeNode(base); + conformanceDecl->base = base; + + if(auto declRefType = base.type->As<DeclRefType>()) + { + if(auto traitDeclRef = declRefType->declRef.As<TraitDeclRef>()) + { + conformanceDecl->traitDeclRef = traitDeclRef; + return; + } + } + + // We expected a trait here + getSink()->diagnose( conformanceDecl, Diagnostics::expectedATraitGot, base.type); + } + + RefPtr<ConstantIntVal> checkConstantIntVal( + RefPtr<ExpressionSyntaxNode> expr) + { + // First type-check the expression as normal + expr = CheckExpr(expr); + + auto intVal = CheckIntegerConstantExpression(expr.Ptr()); + if(!intVal) + return nullptr; + + auto constIntVal = intVal.As<ConstantIntVal>(); + if(!constIntVal) + { + getSink()->diagnose(expr->Position, Diagnostics::expectedIntegerConstantNotLiteral); + return nullptr; + } + return constIntVal; + } + + RefPtr<Modifier> checkModifier( + RefPtr<Modifier> m, + Decl* decl) + { + if(auto hlslUncheckedAttribute = m.As<HLSLUncheckedAttribute>()) + { + // We have an HLSL `[name(arg,...)]` attribute, and we'd like + // to check that it is provides all the expected arguments + // + // For now we will do this in a completely ad hoc fashion, + // but it would be nice to have some generic routine to + // do the needed type checking/coercion. + if(hlslUncheckedAttribute->nameToken.Content == "numthreads") + { + if(hlslUncheckedAttribute->args.Count() != 3) + return m; + + auto xVal = checkConstantIntVal(hlslUncheckedAttribute->args[0]); + auto yVal = checkConstantIntVal(hlslUncheckedAttribute->args[1]); + auto zVal = checkConstantIntVal(hlslUncheckedAttribute->args[2]); + + if(!xVal) return m; + if(!yVal) return m; + if(!zVal) return m; + + auto hlslNumThreadsAttribute = new HLSLNumThreadsAttribute(); + + hlslNumThreadsAttribute->Position = hlslUncheckedAttribute->Position; + hlslNumThreadsAttribute->nameToken = hlslUncheckedAttribute->nameToken; + hlslNumThreadsAttribute->args = hlslUncheckedAttribute->args; + hlslNumThreadsAttribute->x = xVal->value; + hlslNumThreadsAttribute->y = yVal->value; + hlslNumThreadsAttribute->z = zVal->value; + + return hlslNumThreadsAttribute; + } + } + + // Default behavior is to leave things as they are, + // and assume that modifiers are mostly already checked. + // + // TODO: This would be a good place to validate that + // a modifier is actually valid for the thing it is + // being applied to, and potentially to check that + // it isn't in conflict with any other modifiers + // on the same declaration. + + return m; + } + + + void checkModifiers(Decl* decl) + { + // TODO(tfoley): need to make sure this only + // performs semantic checks on a `SharedModifier` once... + + // The process of checking a modifier may produce a new modifier in its place, + // so we will build up a new linked list of modifiers that will replace + // the old list. + RefPtr<Modifier> resultModifiers; + RefPtr<Modifier>* resultModifierLink = &resultModifiers; + + RefPtr<Modifier> modifier = decl->modifiers.first; + while(modifier) + { + // Because we are rewriting the list in place, we need to extract + // the next modifier here (not at the end of the loop). + auto next = modifier->next; + + // We also go ahead and clobber the `next` field on the modifier + // itself, so that the default behavior of `checkModifier()` can + // be to return a single unlinked modifier. + modifier->next = nullptr; + + auto checkedModifier = checkModifier(modifier, decl); + if(checkedModifier) + { + // If checking gave us a modifier to add, then we + // had better add it. + + // Just in case `checkModifier` ever returns multiple + // modifiers, lets advance to the end of the list we + // are building. + while(*resultModifierLink) + resultModifierLink = &(*resultModifierLink)->next; + + // attach the new modifier at the end of the list, + // and now set the "link" to point to its `next` field + *resultModifierLink = checkedModifier; + resultModifierLink = &checkedModifier->next; + } + + // Move along to the next modifier + modifier = next; + } + + // Whether we actually re-wrote anything or note, lets + // install the new list of modifiers on the declaration + decl->modifiers.first = resultModifiers; + } + + virtual RefPtr<ProgramSyntaxNode> VisitProgram(ProgramSyntaxNode * programNode) override + { + // Try to register all the builtin decls + for (auto decl : programNode->Members) + { + auto inner = decl; + if (auto genericDecl = decl.As<GenericDecl>()) + { + inner = genericDecl->inner; + } + + if (auto builtinMod = inner->FindModifier<BuiltinTypeModifier>()) + { + RegisterBuiltinDecl(decl, builtinMod); + } + if (auto magicMod = inner->FindModifier<MagicTypeModifier>()) + { + RegisterMagicDecl(decl, magicMod); + } + } + + // + + HashSet<String> funcNames; + this->program = programNode; + this->function = nullptr; + + for (auto & s : program->GetTypeDefs()) + VisitTypeDefDecl(s.Ptr()); + for (auto & s : program->GetStructs()) + { + VisitStruct(s.Ptr()); + } + for (auto & s : program->GetClasses()) + { + VisitClass(s.Ptr()); + } + // HACK(tfoley): Visiting all generic declarations here, + // because otherwise they won't get visited. + for (auto & g : program->GetMembersOfType<GenericDecl>()) + { + VisitGenericDecl(g.Ptr()); + } + + for (auto & func : program->GetFunctions()) + { + if (!func->IsChecked(DeclCheckState::Checked)) + { + VisitFunctionDeclaration(func.Ptr()); + } + } + for (auto & func : program->GetFunctions()) + { + EnsureDecl(func); + } + + if (sink->GetErrorCount() != 0) + return programNode; + + // Force everything to be fully checked, just in case + // Note that we don't just call this on the program, + // because we'd end up recursing into this very code path... + for (auto d : programNode->Members) + { + EnusreAllDeclsRec(d); + } + + // Do any semantic checking required on modifiers? + for (auto d : programNode->Members) + { + checkModifiers(d.Ptr()); + } + + return programNode; + } + + virtual RefPtr<ClassSyntaxNode> VisitClass(ClassSyntaxNode * classNode) override + { + if (classNode->IsChecked(DeclCheckState::Checked)) + return classNode; + classNode->SetCheckState(DeclCheckState::Checked); + + for (auto field : classNode->GetFields()) + { + field->Type = CheckUsableType(field->Type); + field->SetCheckState(DeclCheckState::Checked); + } + return classNode; + } + + virtual RefPtr<StructSyntaxNode> VisitStruct(StructSyntaxNode * structNode) override + { + if (structNode->IsChecked(DeclCheckState::Checked)) + return structNode; + structNode->SetCheckState(DeclCheckState::Checked); + + for (auto field : structNode->GetFields()) + { + field->Type = CheckUsableType(field->Type); + field->SetCheckState(DeclCheckState::Checked); + } + return structNode; + } + + virtual RefPtr<TypeDefDecl> VisitTypeDefDecl(TypeDefDecl* decl) override + { + if (decl->IsChecked(DeclCheckState::Checked)) return decl; + + decl->SetCheckState(DeclCheckState::CheckingHeader); + decl->Type = CheckProperType(decl->Type); + decl->SetCheckState(DeclCheckState::Checked); + return decl; + } + + virtual RefPtr<FunctionSyntaxNode> VisitFunction(FunctionSyntaxNode *functionNode) override + { + if (functionNode->IsChecked(DeclCheckState::Checked)) + return functionNode; + + VisitFunctionDeclaration(functionNode); + functionNode->SetCheckState(DeclCheckState::Checked); + + if (!functionNode->IsExtern()) + { + this->function = functionNode; + if (functionNode->Body) + { + functionNode->Body->Accept(this); + } + this->function = nullptr; + } + return functionNode; + } + + // Check if two functions have the same signature for the purposes + // of overload resolution. + bool DoFunctionSignaturesMatch( + FunctionSyntaxNode* fst, + FunctionSyntaxNode* snd) + { + // TODO(tfoley): This function won't do anything sensible for generics, + // so we need to figure out a plan for that... + + // TODO(tfoley): This copies the parameter array, which is bad for performance. + auto fstParams = fst->GetParameters().ToArray(); + auto sndParams = snd->GetParameters().ToArray(); + + // If the functions have different numbers of parameters, then + // their signatures trivially don't match. + auto fstParamCount = fstParams.Count(); + auto sndParamCount = sndParams.Count(); + if (fstParamCount != sndParamCount) + return false; + + for (int ii = 0; ii < fstParamCount; ++ii) + { + auto fstParam = fstParams[ii]; + auto sndParam = sndParams[ii]; + + // If a given parameter type doesn't match, then signatures don't match + if (!fstParam->Type.Equals(sndParam->Type)) + return false; + + // If one parameter is `out` and the other isn't, then they don't match + // + // Note(tfoley): we don't consider `out` and `inout` as distinct here, + // because there is no way for overload resolution to pick between them. + if (fstParam->HasModifier<OutModifier>() != sndParam->HasModifier<OutModifier>()) + return false; + } + + // Note(tfoley): return type doesn't enter into it, because we can't take + // calling context into account during overload resolution. + + return true; + } + + void ValidateFunctionRedeclaration(FunctionSyntaxNode* funcDecl) + { + auto parentDecl = funcDecl->ParentDecl; + assert(parentDecl); + if (!parentDecl) return; + + // Look at previously-declared functions with the same name, + // in the same container + buildMemberDictionary(parentDecl); + + for (auto prevDecl = funcDecl->nextInContainerWithSameName; prevDecl; prevDecl = prevDecl->nextInContainerWithSameName) + { + // Look through generics to the declaration underneath + auto prevGenericDecl = dynamic_cast<GenericDecl*>(prevDecl); + if (prevGenericDecl) + prevDecl = prevGenericDecl->inner.Ptr(); + + // We only care about previously-declared functions + // Note(tfoley): although we should really error out if the + // name is already in use for something else, like a variable... + auto prevFuncDecl = dynamic_cast<FunctionSyntaxNode*>(prevDecl); + if (!prevFuncDecl) + continue; + + // If the parameter signatures don't match, then don't worry + if (!DoFunctionSignaturesMatch(funcDecl, prevFuncDecl)) + continue; + + // If we get this far, then we've got two declarations in the same + // scope, with the same name and signature. + // + // They might just be redeclarations, which we would want to allow. + + // First, check if the return types match. + // TODO(tfolye): this code won't work for generics + if (!funcDecl->ReturnType.Equals(prevFuncDecl->ReturnType)) + { + // Bad dedeclaration + getSink()->diagnose(funcDecl, Diagnostics::unimplemented, "redeclaration has a different return type"); + + // Don't bother emitting other errors at this point + break; + } + + // TODO(tfoley): track the fact that there is redeclaration going on, + // so that we can detect it and react accordingly during overload resolution + // (e.g., by only considering one declaration as the canonical one...) + + // If both have a body, then there is trouble + if (funcDecl->Body && prevFuncDecl->Body) + { + // Redefinition + getSink()->diagnose(funcDecl, Diagnostics::unimplemented, "function redefinition"); + + // Don't bother emitting other errors + break; + } + + // TODO(tfoley): If both specific default argument expressions + // for the same value, then that is an error too... + } + } + + void VisitFunctionDeclaration(FunctionSyntaxNode *functionNode) + { + if (functionNode->IsChecked(DeclCheckState::CheckedHeader)) return; + functionNode->SetCheckState(DeclCheckState::CheckingHeader); + + this->function = functionNode; + auto returnType = CheckProperType(functionNode->ReturnType); + functionNode->ReturnType = returnType; + HashSet<String> paraNames; + for (auto & para : functionNode->GetParameters()) + { + if (paraNames.Contains(para->Name.Content)) + getSink()->diagnose(para, Diagnostics::parameterAlreadyDefined, para->Name); + else + paraNames.Add(para->Name.Content); + para->Type = CheckUsableType(para->Type); + if (para->Type.Equals(ExpressionType::GetVoid())) + getSink()->diagnose(para, Diagnostics::parameterCannotBeVoid); + } + this->function = NULL; + functionNode->SetCheckState(DeclCheckState::CheckedHeader); + + // One last bit of validation: check if we are redeclaring an existing function + ValidateFunctionRedeclaration(functionNode); + } + + virtual RefPtr<StatementSyntaxNode> VisitBlockStatement(BlockStatementSyntaxNode *stmt) override + { + for (auto & node : stmt->Statements) + { + node->Accept(this); + } + return stmt; + } + + template<typename T> + T* FindOuterStmt() + { + int outerStmtCount = outerStmts.Count(); + for (int ii = outerStmtCount - 1; ii >= 0; --ii) + { + auto outerStmt = outerStmts[ii]; + auto found = dynamic_cast<T*>(outerStmt); + if (found) + return found; + } + return nullptr; + } + + virtual RefPtr<StatementSyntaxNode> VisitBreakStatement(BreakStatementSyntaxNode *stmt) override + { + auto outer = FindOuterStmt<BreakableStmt>(); + if (!outer) + { + getSink()->diagnose(stmt, Diagnostics::breakOutsideLoop); + } + stmt->parentStmt = outer; + return stmt; + } + virtual RefPtr<StatementSyntaxNode> VisitContinueStatement(ContinueStatementSyntaxNode *stmt) override + { + auto outer = FindOuterStmt<LoopStmt>(); + if (!outer) + { + getSink()->diagnose(stmt, Diagnostics::continueOutsideLoop); + } + stmt->parentStmt = outer; + return stmt; + } + + void PushOuterStmt(StatementSyntaxNode* stmt) + { + outerStmts.Add(stmt); + } + + void PopOuterStmt(StatementSyntaxNode* /*stmt*/) + { + outerStmts.RemoveAt(outerStmts.Count() - 1); + } + + virtual RefPtr<StatementSyntaxNode> VisitDoWhileStatement(DoWhileStatementSyntaxNode *stmt) override + { + PushOuterStmt(stmt); + if (stmt->Predicate != NULL) + stmt->Predicate = stmt->Predicate->Accept(this).As<ExpressionSyntaxNode>(); + if (!stmt->Predicate->Type->Equals(ExpressionType::GetError()) && + !stmt->Predicate->Type->Equals(ExpressionType::GetInt()) && + !stmt->Predicate->Type->Equals(ExpressionType::GetBool())) + { + getSink()->diagnose(stmt, Diagnostics::whilePredicateTypeError); + } + stmt->Statement->Accept(this); + + PopOuterStmt(stmt); + return stmt; + } + virtual RefPtr<StatementSyntaxNode> VisitForStatement(ForStatementSyntaxNode *stmt) override + { + PushOuterStmt(stmt); + if (stmt->InitialStatement) + { + stmt->InitialStatement = stmt->InitialStatement->Accept(this).As<StatementSyntaxNode>(); + } + if (stmt->PredicateExpression) + { + stmt->PredicateExpression = stmt->PredicateExpression->Accept(this).As<ExpressionSyntaxNode>(); + if (!stmt->PredicateExpression->Type->Equals(ExpressionType::GetBool()) && + !stmt->PredicateExpression->Type->Equals(ExpressionType::GetInt()) && + !stmt->PredicateExpression->Type->Equals(ExpressionType::GetUInt())) + { + getSink()->diagnose(stmt->PredicateExpression.Ptr(), Diagnostics::forPredicateTypeError); + } + } + if (stmt->SideEffectExpression) + { + stmt->SideEffectExpression = stmt->SideEffectExpression->Accept(this).As<ExpressionSyntaxNode>(); + } + stmt->Statement->Accept(this); + + PopOuterStmt(stmt); + return stmt; + } + virtual RefPtr<SwitchStmt> VisitSwitchStmt(SwitchStmt* stmt) override + { + PushOuterStmt(stmt); + // TODO(tfoley): need to coerce condition to an integral type... + stmt->condition = CheckExpr(stmt->condition); + stmt->body->Accept(this); + PopOuterStmt(stmt); + return stmt; + } + virtual RefPtr<CaseStmt> VisitCaseStmt(CaseStmt* stmt) override + { + auto expr = CheckExpr(stmt->expr); + auto switchStmt = FindOuterStmt<SwitchStmt>(); + + if (!switchStmt) + { + getSink()->diagnose(stmt, Diagnostics::caseOutsideSwitch); + } + else + { + // TODO: need to do some basic matching to ensure the type + // for the `case` is consistent with the type for the `switch`... + } + + stmt->expr = expr; + stmt->parentStmt = switchStmt; + + return stmt; + } + virtual RefPtr<DefaultStmt> VisitDefaultStmt(DefaultStmt* stmt) override + { + auto switchStmt = FindOuterStmt<SwitchStmt>(); + if (!switchStmt) + { + getSink()->diagnose(stmt, Diagnostics::defaultOutsideSwitch); + } + stmt->parentStmt = switchStmt; + return stmt; + } + virtual RefPtr<StatementSyntaxNode> VisitIfStatement(IfStatementSyntaxNode *stmt) override + { + auto condition = stmt->Predicate; + condition = CheckTerm(condition); + condition = Coerce(ExpressionType::GetBool(), condition); + + stmt->Predicate = condition; + +#if 0 + if (stmt->Predicate != NULL) + stmt->Predicate = stmt->Predicate->Accept(this).As<ExpressionSyntaxNode>(); + if (!stmt->Predicate->Type->Equals(ExpressionType::GetError()) + && (!stmt->Predicate->Type->Equals(ExpressionType::GetInt()) && + !stmt->Predicate->Type->Equals(ExpressionType::GetBool()))) + getSink()->diagnose(stmt, Diagnostics::ifPredicateTypeError); +#endif + + if (stmt->PositiveStatement != NULL) + stmt->PositiveStatement->Accept(this); + + if (stmt->NegativeStatement != NULL) + stmt->NegativeStatement->Accept(this); + return stmt; + } + virtual RefPtr<StatementSyntaxNode> VisitReturnStatement(ReturnStatementSyntaxNode *stmt) override + { + if (!stmt->Expression) + { + if (function && !function->ReturnType.Equals(ExpressionType::GetVoid())) + getSink()->diagnose(stmt, Diagnostics::returnNeedsExpression); + } + else + { + stmt->Expression = stmt->Expression->Accept(this).As<ExpressionSyntaxNode>(); + if (!stmt->Expression->Type->Equals(ExpressionType::Error.Ptr())) + { + if (function) + { + stmt->Expression = Coerce(function->ReturnType, stmt->Expression); + } + else + { + // TODO(tfoley): this case currently gets triggered for member functions, + // which aren't being checked consistently (because of the whole symbol + // table idea getting in the way). + +// getSink()->diagnose(stmt, Diagnostics::unimplemented, "case for return stmt"); + } + } + } + return stmt; + } + + int GetMinBound(RefPtr<IntVal> val) + { + if (auto constantVal = val.As<ConstantIntVal>()) + return constantVal->value; + + // TODO(tfoley): Need to track intervals so that this isn't just a lie... + return 1; + } + + void maybeInferArraySizeForVariable(Variable* varDecl) + { + // Not an array? + auto arrayType = varDecl->Type->AsArrayType(); + if (!arrayType) return; + + // Explicit element count given? + auto elementCount = arrayType->ArrayLength; + if (elementCount) return; + + // No initializer? + auto initExpr = varDecl->Expr; + if(!initExpr) return; + + // Is the initializer an initializer list? + if(auto initializerListExpr = initExpr.As<InitializerListExpr>()) + { + auto argCount = initializerListExpr->args.Count(); + elementCount = new ConstantIntVal(argCount); + } + // Is the type of the initializer an array type? + else if(auto arrayInitType = initExpr->Type->As<ArrayExpressionType>()) + { + elementCount = arrayInitType->ArrayLength; + } + else + { + // Nothing to do: we couldn't infer a size + return; + } + + // Create a new array type based on the size we found, + // and install it into our type. + auto newArrayType = new ArrayExpressionType(); + newArrayType->BaseType = arrayType->BaseType; + newArrayType->ArrayLength = elementCount; + + // Okay we are good to go! + varDecl->Type.type = newArrayType; + } + + void ValidateArraySizeForVariable(Variable* varDecl) + { + auto arrayType = varDecl->Type->AsArrayType(); + if (!arrayType) return; + + auto elementCount = arrayType->ArrayLength; + if (!elementCount) + { + // Note(tfoley): For now we allow arrays of unspecified size + // everywhere, because some source languages (e.g., GLSL) + // allow them in specific cases. +#if 0 + getSink()->diagnose(varDecl, Diagnostics::invalidArraySize); +#endif + return; + } + + // TODO(tfoley): How to handle the case where bound isn't known? + if (GetMinBound(elementCount) <= 0) + { + getSink()->diagnose(varDecl, Diagnostics::invalidArraySize); + return; + } + } + + virtual RefPtr<Variable> VisitDeclrVariable(Variable* varDecl) + { + TypeExp typeExp = CheckUsableType(varDecl->Type); +#if 0 + if (typeExp.type->GetBindableResourceType() != BindableResourceType::NonBindable) + { + // We don't want to allow bindable resource types as local variables (at least for now). + auto parentDecl = varDecl->ParentDecl; + if (auto parentScopeDecl = dynamic_cast<ScopeDecl*>(parentDecl)) + { + getSink()->diagnose(varDecl->Type, Diagnostics::invalidTypeForLocalVariable); + } + } +#endif + varDecl->Type = typeExp; + if (varDecl->Type.Equals(ExpressionType::GetVoid())) + getSink()->diagnose(varDecl, Diagnostics::invalidTypeVoid); + + if(auto initExpr = varDecl->Expr) + { + initExpr = CheckTerm(initExpr); + varDecl->Expr = initExpr; + } + + // If this is an array variable, then we first want to give + // it a chance to infer an array size from its initializer + // + // TODO(tfoley): May need to extend this to handle the + // multi-dimensional case... + maybeInferArraySizeForVariable(varDecl); + // + // Next we want to make sure that the declared (or inferred) + // size for the array meets whatever language-specific + // constraints we want to enforce (e.g., disallow empty + // arrays in specific cases) + ValidateArraySizeForVariable(varDecl); + + + if(auto initExpr = varDecl->Expr) + { + // TODO(tfoley): should coercion of initializer lists be special-cased + // here, or handled as a general case for coercion? + + initExpr = Coerce(varDecl->Type, initExpr); + varDecl->Expr = initExpr; + } + + varDecl->SetCheckState(DeclCheckState::Checked); + + return varDecl; + } + + virtual RefPtr<StatementSyntaxNode> VisitWhileStatement(WhileStatementSyntaxNode *stmt) override + { + PushOuterStmt(stmt); + stmt->Predicate = stmt->Predicate->Accept(this).As<ExpressionSyntaxNode>(); + if (!stmt->Predicate->Type->Equals(ExpressionType::GetError()) && + !stmt->Predicate->Type->Equals(ExpressionType::GetInt()) && + !stmt->Predicate->Type->Equals(ExpressionType::GetBool())) + getSink()->diagnose(stmt, Diagnostics::whilePredicateTypeError2); + + stmt->Statement->Accept(this); + PopOuterStmt(stmt); + return stmt; + } + virtual RefPtr<StatementSyntaxNode> VisitExpressionStatement(ExpressionStatementSyntaxNode *stmt) override + { + stmt->Expression = stmt->Expression->Accept(this).As<ExpressionSyntaxNode>(); + return stmt; + } + virtual RefPtr<ExpressionSyntaxNode> VisitOperatorExpression(OperatorExpressionSyntaxNode *expr) override + { +#if 0 + + for (int i = 0; i < expr->Arguments.Count(); i++) + expr->Arguments[i] = expr->Arguments[i]->Accept(this).As<ExpressionSyntaxNode>(); + auto & leftType = expr->Arguments[0]->Type; + QualType rightType; + if (expr->Arguments.Count() == 2) + rightType = expr->Arguments[1]->Type; + RefPtr<ExpressionType> matchedType; + auto checkAssign = [&]() + { + if (!leftType.IsLeftValue && + !leftType->Equals(ExpressionType::Error.Ptr())) + getSink()->diagnose(expr->Arguments[0].Ptr(), Diagnostics::assignNonLValue); + if (expr->Operator == Operator::AndAssign || + expr->Operator == Operator::OrAssign || + expr->Operator == Operator::XorAssign || + expr->Operator == Operator::LshAssign || + expr->Operator == Operator::RshAssign) + { +#if 0 + if (!(leftType->IsIntegral() && rightType->IsIntegral())) + { + // TODO(tfoley): This diagnostic shouldn't be handled here +// getSink()->diagnose(expr, Diagnostics::bitOperationNonIntegral); + } +#endif + } + + // TODO(tfoley): Need to actual insert coercion here... + if(CanCoerce(leftType, expr->Type)) + expr->Type = leftType; + else + expr->Type = ExpressionType::Error; + }; +#if 0 + if (expr->Operator == Operator::Assign) + { + expr->Type = rightType; + checkAssign(); + } + else +#endif + { + expr->FunctionExpr = CheckExpr(expr->FunctionExpr); + CheckInvokeExprWithCheckedOperands(expr); + if (expr->Operator > Operator::Assign) + checkAssign(); + } + return expr; +#endif + + // Treat operator application just like a function call + return VisitInvokeExpression(expr); + } + virtual RefPtr<ExpressionSyntaxNode> VisitConstantExpression(ConstantExpressionSyntaxNode *expr) override + { + switch (expr->ConstType) + { + case ConstantExpressionSyntaxNode::ConstantType::Int: + expr->Type = ExpressionType::GetInt(); + break; + case ConstantExpressionSyntaxNode::ConstantType::Bool: + expr->Type = ExpressionType::GetBool(); + break; + case ConstantExpressionSyntaxNode::ConstantType::Float: + expr->Type = ExpressionType::GetFloat(); + break; + default: + expr->Type = ExpressionType::Error; + throw "Invalid constant type."; + break; + } + return expr; + } + + IntVal* GetIntVal(ConstantExpressionSyntaxNode* expr) + { + // TODO(tfoley): don't keep allocating here! + return new ConstantIntVal(expr->IntValue); + } + + RefPtr<IntVal> TryConstantFoldExpr( + InvokeExpressionSyntaxNode* invokeExpr) + { + // We need all the operands to the expression + + // Check if the callee is an operation that is amenable to constant-folding. + // + // For right now we will look for calls to intrinsic functions, and then inspect + // their names (this is bad and slow). + auto funcDeclRefExpr = invokeExpr->FunctionExpr.As<DeclRefExpr>(); + if (!funcDeclRefExpr) return nullptr; + + auto funcDeclRef = funcDeclRefExpr->declRef; + auto intrinsicMod = funcDeclRef.GetDecl()->FindModifier<IntrinsicModifier>(); + if (!intrinsicMod) return nullptr; + + // Let's not constant-fold operations with more than a certain number of arguments, for simplicity + static const int kMaxArgs = 8; + if (invokeExpr->Arguments.Count() > kMaxArgs) + return nullptr; + + // Before checking the operation name, let's look at the arguments + RefPtr<IntVal> argVals[kMaxArgs]; + int constArgVals[kMaxArgs]; + int argCount = 0; + bool allConst = true; + for (auto argExpr : invokeExpr->Arguments) + { + auto argVal = TryCheckIntegerConstantExpression(argExpr.Ptr()); + if (!argVal) + return nullptr; + + argVals[argCount] = argVal; + + if (auto constArgVal = argVal.As<ConstantIntVal>()) + { + constArgVals[argCount] = constArgVal->value; + } + else + { + allConst = false; + } + argCount++; + } + + if (!allConst) + { + // TODO(tfoley): We probably want to support a very limited number of operations + // on "constants" that aren't actually known, to be able to handle a generic + // that takes an integer `N` but then constructs a vector of size `N+1`. + // + // The hard part there is implementing the rules for value unification in the + // presence of more complicated `IntVal` subclasses, like `SumIntVal`. You'd + // need inference to be smart enough to know that `2 + N` and `N + 2` are the + // same value, as are `N + M + 1 + 1` and `M + 2 + N`. + // + // For now we can just bail in this case. + return nullptr; + } + + // At this point, all the operands had simple integer values, so we are golden. + int resultValue = 0; + auto opName = funcDeclRef.GetName(); + + // handle binary operators + if (opName == "-") + { + if (argCount == 1) + { + resultValue = -constArgVals[0]; + } + else if (argCount == 2) + { + resultValue = constArgVals[0] - constArgVals[1]; + } + } + + // simple binary operators +#define CASE(OP) \ + else if(opName == #OP) do { \ + if(argCount != 2) return nullptr; \ + resultValue = constArgVals[0] OP constArgVals[1]; \ + } while(0) + + CASE(+); // TODO: this can also be unary... + CASE(*); +#undef CASE + + // binary operators with chance of divide-by-zero + // TODO: issue a suitable error in that case +#define CASE(OP) \ + else if(opName == #OP) do { \ + if(argCount != 2) return nullptr; \ + if(!constArgVals[1]) return nullptr; \ + resultValue = constArgVals[0] OP constArgVals[1]; \ + } while(0) + + CASE(/); + CASE(%); +#undef CASE + + // TODO(tfoley): more cases + else + { + return nullptr; + } + + RefPtr<IntVal> result = new ConstantIntVal(resultValue); + return result; + } + + RefPtr<IntVal> TryConstantFoldExpr( + ExpressionSyntaxNode* expr) + { + // TODO(tfoley): more serious constant folding here + if (auto constExp = dynamic_cast<ConstantExpressionSyntaxNode*>(expr)) + { + return GetIntVal(constExp); + } + + // it is possible that we are referring to a generic value param + if (auto declRefExpr = dynamic_cast<DeclRefExpr*>(expr)) + { + auto declRef = declRefExpr->declRef; + + if (auto genericValParamRef = declRef.As<GenericValueParamDeclRef>()) + { + // TODO(tfoley): handle the case of non-`int` value parameters... + return new GenericParamIntVal(genericValParamRef); + } + + // We may also need to check for references to variables that + // are defined in a way that can be used as a constant expression: + if(auto varRef = declRef.As<VarDeclBaseRef>()) + { + auto varDecl = varRef.GetDecl(); + + switch(sourceLanguage) + { + case SourceLanguage::Slang: + case SourceLanguage::HLSL: + // HLSL: `static const` is used to mark compile-time constant expressions + if(auto staticAttr = varDecl->FindModifier<HLSLStaticModifier>()) + { + if(auto constAttr = varDecl->FindModifier<ConstModifier>()) + { + // HLSL `static const` can be used as a constant expression + if(auto initExpr = varRef.getInitExpr()) + { + return TryConstantFoldExpr(initExpr.Ptr()); + } + } + } + break; + + case SourceLanguage::GLSL: + // GLSL: `const` indicates compile-time constant expression + // + // TODO(tfoley): The current logic here isn't robust against + // GLSL "specialization constants" - we will extract the + // initializer for a `const` variable and use it to extract + // a value, when we really should be using an opaque + // reference to the variable. + if(auto constAttr = varDecl->FindModifier<ConstModifier>()) + { + // We need to handle a "specialization constant" (with a `constant_id` layout modifier) + // differently from an ordinary compile-time constant. The latter can/should be reduced + // to a value, while the former should be kept as a symbolic reference + + if(auto constantIDModifier = varDecl->FindModifier<GLSLConstantIDLayoutModifier>()) + { + // Retain the specialization constant as a symbolic reference + // + // TODO(tfoley): handle the case of non-`int` value parameters... + // + // TODO(tfoley): this is cloned from the case above that handles generic value parameters + return new GenericParamIntVal(varRef); + } + else if(auto initExpr = varRef.getInitExpr()) + { + // This is an ordinary constant, and not a specialization constant, so we + // can try to fold its value right now. + return TryConstantFoldExpr(initExpr.Ptr()); + } + } + break; + } + + } + } + + if (auto invokeExpr = dynamic_cast<InvokeExpressionSyntaxNode*>(expr)) + { + auto val = TryConstantFoldExpr(invokeExpr); + if (val) + return val; + } + else if(auto castExpr = dynamic_cast<TypeCastExpressionSyntaxNode*>(expr)) + { + auto val = TryConstantFoldExpr(castExpr->Expression.Ptr()); + if(val) + return val; + } + + return nullptr; + } + + // Try to check an integer constant expression, either returning the value, + // or NULL if the expression isn't recognized as a constant. + RefPtr<IntVal> TryCheckIntegerConstantExpression(ExpressionSyntaxNode* exp) + { + if (!exp->Type.type->Equals(ExpressionType::GetInt())) + { + return nullptr; + } + + + + // Otherwise, we need to consider operations that we might be able to constant-fold... + return TryConstantFoldExpr(exp); + } + + // Enforce that an expression resolves to an integer constant, and get its value + RefPtr<IntVal> CheckIntegerConstantExpression(ExpressionSyntaxNode* inExpr) + { + // First coerce the expression to the expected type + auto expr = Coerce(ExpressionType::GetInt(),inExpr); + auto result = TryCheckIntegerConstantExpression(expr.Ptr()); + if (!result) + { + getSink()->diagnose(expr, Diagnostics::expectedIntegerConstantNotConstant); + } + return result; + } + + + + RefPtr<ExpressionSyntaxNode> CheckSimpleSubscriptExpr( + RefPtr<IndexExpressionSyntaxNode> subscriptExpr, + RefPtr<ExpressionType> elementType) + { + auto baseExpr = subscriptExpr->BaseExpression; + auto indexExpr = subscriptExpr->IndexExpression; + + if (!indexExpr->Type->Equals(ExpressionType::GetInt()) && + !indexExpr->Type->Equals(ExpressionType::GetUInt())) + { + getSink()->diagnose(indexExpr, Diagnostics::subscriptIndexNonInteger); + return CreateErrorExpr(subscriptExpr.Ptr()); + } + + subscriptExpr->Type = elementType; + + // TODO(tfoley): need to be more careful about this stuff + subscriptExpr->Type.IsLeftValue = baseExpr->Type.IsLeftValue; + + return subscriptExpr; + } + + // The way that we have designed out type system, pretyt much *every* + // type is a reference to some declaration in the standard library. + // That means that when we construct a new type on the fly, we need + // to make sure that it is wired up to reference the appropriate + // declaration, or else it won't compare as equal to other types + // that *do* reference the declaration. + // + // This function is used to construct a `vector<T,N>` type + // programmatically, so that it will work just like a type of + // that form constructed by the user. + RefPtr<VectorExpressionType> createVectorType( + RefPtr<ExpressionType> elementType, + RefPtr<IntVal> elementCount) + { + auto vectorGenericDecl = findMagicDecl("Vector").As<GenericDecl>(); + auto vectorTypeDecl = vectorGenericDecl->inner; + + auto substitutions = new Substitutions(); + substitutions->genericDecl = vectorGenericDecl.Ptr(); + substitutions->args.Add(elementType); + substitutions->args.Add(elementCount); + + auto declRef = DeclRef(vectorTypeDecl.Ptr(), substitutions); + + return DeclRefType::Create(declRef)->As<VectorExpressionType>(); + } + + virtual RefPtr<ExpressionSyntaxNode> VisitIndexExpression(IndexExpressionSyntaxNode* subscriptExpr) override + { + auto baseExpr = subscriptExpr->BaseExpression; + baseExpr = CheckExpr(baseExpr); + + RefPtr<ExpressionSyntaxNode> indexExpr = subscriptExpr->IndexExpression; + if (indexExpr) + { + indexExpr = CheckExpr(indexExpr); + } + + subscriptExpr->BaseExpression = baseExpr; + subscriptExpr->IndexExpression = indexExpr; + + // If anything went wrong in the base expression, + // then just move along... + if (IsErrorExpr(baseExpr)) + return CreateErrorExpr(subscriptExpr); + + // Otherwise, we need to look at the type of the base expression, + // to figure out how subscripting should work. + auto baseType = baseExpr->Type.Ptr(); + if (auto baseTypeType = baseType->As<TypeType>()) + { + // We are trying to "index" into a type, so we have an expression like `float[2]` + // which should be interpreted as resolving to an array type. + + RefPtr<IntVal> elementCount = nullptr; + if (indexExpr) + { + elementCount = CheckIntegerConstantExpression(indexExpr.Ptr()); + } + + auto elementType = CoerceToUsableType(TypeExp(baseExpr, baseTypeType->type)); + auto arrayType = new ArrayExpressionType(); + arrayType->BaseType = elementType; + arrayType->ArrayLength = elementCount; + + typeResult = arrayType; + subscriptExpr->Type = new TypeType(arrayType); + return subscriptExpr; + } + else if (auto baseArrayType = baseType->As<ArrayExpressionType>()) + { + return CheckSimpleSubscriptExpr( + subscriptExpr, + baseArrayType->BaseType); + } + else if (auto vecType = baseType->As<VectorExpressionType>()) + { + return CheckSimpleSubscriptExpr( + subscriptExpr, + vecType->elementType); + } + else if (auto matType = baseType->As<MatrixExpressionType>()) + { + // TODO(tfoley): We shouldn't go and recompute + // row types over and over like this... :( + auto rowType = createVectorType( + matType->getElementType(), + matType->getColumnCount()); + + return CheckSimpleSubscriptExpr( + subscriptExpr, + rowType); + } + + // Default behavior is to look at all available `__subscript` + // declarations on the type and try to call one of them. + + if (auto declRefType = baseType->AsDeclRefType()) + { + if (auto aggTypeDeclRef = declRefType->declRef.As<AggTypeDeclRef>()) + { + // Checking of the type must be complete before we can reference its members safely + EnsureDecl(aggTypeDeclRef.GetDecl(), DeclCheckState::Checked); + + // Note(tfoley): The name used for lookup here is a bit magical, since + // it must match what the parser installed in subscript declarations. + LookupResult lookupResult = LookUpLocal("operator[]", aggTypeDeclRef); + if (!lookupResult.isValid()) + { + goto fail; + } + + RefPtr<ExpressionSyntaxNode> subscriptFuncExpr = createLookupResultExpr( + lookupResult, subscriptExpr->BaseExpression, subscriptExpr); + + // Now that we know there is at least one subscript member, + // we will construct a reference to it and try to call it + + RefPtr<InvokeExpressionSyntaxNode> subscriptCallExpr = new InvokeExpressionSyntaxNode(); + subscriptCallExpr->Position = subscriptExpr->Position; + subscriptCallExpr->FunctionExpr = subscriptFuncExpr; + + // TODO(tfoley): This path can support multiple arguments easily + subscriptCallExpr->Arguments.Add(subscriptExpr->IndexExpression); + + return CheckInvokeExprWithCheckedOperands(subscriptCallExpr.Ptr()); + } + } + + fail: + { + getSink()->diagnose(subscriptExpr, Diagnostics::subscriptNonArray, baseType); + return CreateErrorExpr(subscriptExpr); + } + } + + bool MatchArguments(FunctionSyntaxNode * functionNode, List <RefPtr<ExpressionSyntaxNode>> &args) + { + if (functionNode->GetParameters().Count() != args.Count()) + return false; + int i = 0; + for (auto param : functionNode->GetParameters()) + { + if (!param->Type.Equals(args[i]->Type.Ptr())) + return false; + i++; + } + return true; + } + + // Coerce an expression to a specific type that it is expected to have in context + RefPtr<ExpressionSyntaxNode> CoerceExprToType( + RefPtr<ExpressionSyntaxNode> expr, + RefPtr<ExpressionType> type) + { + // TODO(tfoley): clean this up so there is only one version... + return Coerce(type, expr); + } + + // Resolve a call to a function, represented here + // by a symbol with a `FuncType` type. + RefPtr<ExpressionSyntaxNode> ResolveFunctionApp( + RefPtr<FuncType> funcType, + InvokeExpressionSyntaxNode* /*appExpr*/) + { + // TODO(tfoley): Actual checking logic needs to go here... +#if 0 + auto& args = appExpr->Arguments; + List<RefPtr<ParameterSyntaxNode>> params; + RefPtr<ExpressionType> resultType; + if (auto funcDeclRef = funcType->declRef) + { + EnsureDecl(funcDeclRef.GetDecl()); + + params = funcDeclRef->GetParameters().ToArray(); + resultType = funcDecl->ReturnType; + } + else if (auto funcSym = funcType->Func) + { + auto funcDecl = funcSym->SyntaxNode; + EnsureDecl(funcDecl); + + params = funcDecl->GetParameters().ToArray(); + resultType = funcDecl->ReturnType; + } + else if (auto componentFuncSym = funcType->Component) + { + auto componentFuncDecl = componentFuncSym->Implementations.First()->SyntaxNode; + params = componentFuncDecl->GetParameters().ToArray(); + resultType = componentFuncDecl->Type; + } + + auto argCount = args.Count(); + auto paramCount = params.Count(); + if (argCount != paramCount) + { + getSink()->diagnose(appExpr, Diagnostics::unimplemented, "wrong number of arguments for call"); + appExpr->Type = ExpressionType::Error; + return appExpr; + } + + for (int ii = 0; ii < argCount; ++ii) + { + auto arg = args[ii]; + auto param = params[ii]; + + arg = CoerceExprToType(arg, param->Type); + + args[ii] = arg; + } + + assert(resultType); + appExpr->Type = resultType; + return appExpr; +#else + throw "unimplemented"; +#endif + } + + // Resolve a constructor call, formed by apply a type to arguments + RefPtr<ExpressionSyntaxNode> ResolveConstructorApp( + RefPtr<ExpressionType> type, + InvokeExpressionSyntaxNode* appExpr) + { + // TODO(tfoley): Actual checking logic needs to go here... + + appExpr->Type = type; + return appExpr; + } + + + // + + virtual void VisitExtensionDecl(ExtensionDecl* decl) override + { + if (decl->IsChecked(DeclCheckState::Checked)) return; + + decl->SetCheckState(DeclCheckState::CheckingHeader); + decl->targetType = CheckProperType(decl->targetType); + + // TODO: need to check that the target type names a declaration... + + if (auto targetDeclRefType = decl->targetType->As<DeclRefType>()) + { + // Attach our extension to that type as a candidate... + if (auto aggTypeDeclRef = targetDeclRefType->declRef.As<AggTypeDeclRef>()) + { + auto aggTypeDecl = aggTypeDeclRef.GetDecl(); + decl->nextCandidateExtension = aggTypeDecl->candidateExtensions; + aggTypeDecl->candidateExtensions = decl; + } + else + { + getSink()->diagnose(decl->targetType.exp, Diagnostics::unimplemented, "expected a nominal type here"); + } + } + else if (decl->targetType->Equals(ExpressionType::Error)) + { + // there was an error, so ignore + } + else + { + getSink()->diagnose(decl->targetType.exp, Diagnostics::unimplemented, "expected a nominal type here"); + } + + decl->SetCheckState(DeclCheckState::CheckedHeader); + + // now check the members of the extension + for (auto m : decl->Members) + { + EnsureDecl(m); + } + + decl->SetCheckState(DeclCheckState::Checked); + } + + virtual void VisitConstructorDecl(ConstructorDecl* decl) override + { + if (decl->IsChecked(DeclCheckState::Checked)) return; + decl->SetCheckState(DeclCheckState::CheckingHeader); + + for (auto& paramDecl : decl->GetParameters()) + { + paramDecl->Type = CheckUsableType(paramDecl->Type); + } + decl->SetCheckState(DeclCheckState::CheckedHeader); + + // TODO(tfoley): check body + decl->SetCheckState(DeclCheckState::Checked); + } + + + virtual void visitSubscriptDecl(SubscriptDecl* decl) override + { + if (decl->IsChecked(DeclCheckState::Checked)) return; + decl->SetCheckState(DeclCheckState::CheckingHeader); + + for (auto& paramDecl : decl->GetParameters()) + { + paramDecl->Type = CheckUsableType(paramDecl->Type); + } + + decl->ReturnType = CheckUsableType(decl->ReturnType); + + decl->SetCheckState(DeclCheckState::CheckedHeader); + + decl->SetCheckState(DeclCheckState::Checked); + } + + virtual void visitAccessorDecl(AccessorDecl* decl) override + { + // TODO: check the body! + + decl->SetCheckState(DeclCheckState::Checked); + } + + + // + + struct Constraint + { + Decl* decl; // the declaration of the thing being constraints + RefPtr<Val> val; // the value to which we are constraining it + bool satisfied = false; // Has this constraint been met? + }; + + // A collection of constraints that will need to be satisified (solved) + // in order for checking to suceed. + struct ConstraintSystem + { + List<Constraint> constraints; + }; + + RefPtr<ExpressionType> TryJoinVectorAndScalarType( + RefPtr<VectorExpressionType> vectorType, + RefPtr<BasicExpressionType> scalarType) + { + // Join( vector<T,N>, S ) -> vetor<Join(T,S), N> + // + // That is, the join of a vector and a scalar type is + // a vector type with a joined element type. + auto joinElementType = TryJoinTypes( + vectorType->elementType, + scalarType); + if(!joinElementType) + return nullptr; + + return createVectorType( + joinElementType, + vectorType->elementCount); + } + + bool DoesTypeConformToTrait( + RefPtr<ExpressionType> type, + TraitDeclRef traitDeclRef) + { + // for now look up a conformance member... + if(auto declRefType = type->As<DeclRefType>()) + { + if( auto aggTypeDeclRef = declRefType->declRef.As<AggTypeDeclRef>() ) + { + for( auto conformanceRef : aggTypeDeclRef.GetMembersOfType<TraitConformanceDeclRef>()) + { + EnsureDecl(conformanceRef.GetDecl()); + + if(traitDeclRef.Equals(conformanceRef.GetTraitDeclRef())) + return true; + } + } + } + + // default is failure + return false; + } + + RefPtr<ExpressionType> TryJoinTypeWithTrait( + RefPtr<ExpressionType> type, + TraitDeclRef traitDeclRef) + { + // The most basic test here should be: does the type declare conformance to the trait. + if(DoesTypeConformToTrait(type, traitDeclRef)) + return type; + + // There is a more nuanced case if `type` is a builtin type, and we need to make it + // conform to a trait that some but not all builtin types support (the main problem + // here is when an operation wants an integer type, but one of our operands is a `float`. + // The HLSL rules will allow that, with implicit conversion, but our default join rules + // will end up picking `float` and we don't want that...). + + // For now we don't handle the hard case and just bail + return nullptr; + } + + // Try to compute the "join" between two types + RefPtr<ExpressionType> TryJoinTypes( + RefPtr<ExpressionType> left, + RefPtr<ExpressionType> right) + { + // Easy case: they are the same type! + if (left->Equals(right)) + return left; + + // We can join two basic types by picking the "better" of the two + if (auto leftBasic = left->As<BasicExpressionType>()) + { + if (auto rightBasic = right->As<BasicExpressionType>()) + { + auto leftFlavor = leftBasic->BaseType; + auto rightFlavor = rightBasic->BaseType; + + // TODO(tfoley): Need a special-case rule here that if + // either operand is of type `half`, then we promote + // to at least `float` + + // Return the one that had higher rank... + if (leftFlavor > rightFlavor) + return left; + else + { + assert(rightFlavor > leftFlavor); + return right; + } + } + + // We can also join a vector and a scalar + if(auto rightVector = right->As<VectorExpressionType>()) + { + return TryJoinVectorAndScalarType(rightVector, leftBasic); + } + } + + // We can join two vector types by joining their element types + // (and also their sizes...) + if( auto leftVector = left->As<VectorExpressionType>()) + { + if(auto rightVector = right->As<VectorExpressionType>()) + { + // Check if the vector sizes match + if(!leftVector->elementCount->EqualsVal(rightVector->elementCount.Ptr())) + return nullptr; + + // Try to join the element types + auto joinElementType = TryJoinTypes( + leftVector->elementType, + rightVector->elementType); + if(!joinElementType) + return nullptr; + + return createVectorType( + joinElementType, + leftVector->elementCount); + } + + // We can also join a vector and a scalar + if(auto rightBasic = right->As<BasicExpressionType>()) + { + return TryJoinVectorAndScalarType(leftVector, rightBasic); + } + } + + // HACK: trying to work trait types in here... + if(auto leftDeclRefType = left->As<DeclRefType>()) + { + if( auto leftTraitRef = leftDeclRefType->declRef.As<TraitDeclRef>() ) + { + // + return TryJoinTypeWithTrait(right, leftTraitRef); + } + } + if(auto rightDeclRefType = right->As<DeclRefType>()) + { + if( auto rightTraitRef = rightDeclRefType->declRef.As<TraitDeclRef>() ) + { + // + return TryJoinTypeWithTrait(left, rightTraitRef); + } + } + + // TODO: all the cases for vectors apply to matrices too! + + // Default case is that we just fail. + return nullptr; + } + + // Try to solve a system of generic constraints. + // The `system` argument provides the constraints. + // The `varSubst` argument provides the list of constraint + // variables that were created for the system. + // + // Returns a new substitution representing the values that + // we solved for along the way. + RefPtr<Substitutions> TrySolveConstraintSystem( + ConstraintSystem* system, + GenericDeclRef genericDeclRef) + { + // For now the "solver" is going to be ridiculously simplistic. + + // The generic itself will have some constraints, so we need to try and solve those too + for( auto constraintDeclRef : genericDeclRef.GetMembersOfType<GenericTypeConstraintDeclRef>() ) + { + if(!TryUnifyTypes(*system, constraintDeclRef.GetSub(), constraintDeclRef.GetSup())) + return nullptr; + } + + // We will loop over the generic parameters, and for + // each we will try to find a way to satisfy all + // the constraints for that parameter + List<RefPtr<Val>> args; + for (auto m : genericDeclRef.GetMembers()) + { + if (auto typeParam = m.As<GenericTypeParamDeclRef>()) + { + RefPtr<ExpressionType> type = nullptr; + for (auto& c : system->constraints) + { + if (c.decl != typeParam.GetDecl()) + continue; + + auto cType = c.val.As<ExpressionType>(); + assert(cType.Ptr()); + + if (!type) + { + type = cType; + } + else + { + auto joinType = TryJoinTypes(type, cType); + if (!joinType) + { + // failure! + return nullptr; + } + type = joinType; + } + + c.satisfied = true; + } + + if (!type) + { + // failure! + return nullptr; + } + args.Add(type); + } + else if (auto valParam = m.As<GenericValueParamDeclRef>()) + { + // TODO(tfoley): maybe support more than integers some day? + // TODO(tfoley): figure out how this needs to interact with + // compile-time integers that aren't just constants... + RefPtr<IntVal> val = nullptr; + for (auto& c : system->constraints) + { + if (c.decl != valParam.GetDecl()) + continue; + + auto cVal = c.val.As<IntVal>(); + assert(cVal.Ptr()); + + if (!val) + { + val = cVal; + } + else + { + if(!val->EqualsVal(cVal.Ptr())) + { + // failure! + return nullptr; + } + } + + c.satisfied = true; + } + + if (!val) + { + // failure! + return nullptr; + } + args.Add(val); + } + else + { + // ignore anything that isn't a generic parameter + } + } + + // Make sure we haven't constructed any spurious constraints + // that we aren't able to satisfy: + for (auto c : system->constraints) + { + if (!c.satisfied) + { + return nullptr; + } + } + + // Consruct a reference to the extension with our constraint variables + // as the + RefPtr<Substitutions> solvedSubst = new Substitutions(); + solvedSubst->genericDecl = genericDeclRef.GetDecl(); + solvedSubst->outer = genericDeclRef.substitutions; + solvedSubst->args = args; + + return solvedSubst; + + +#if 0 + List<RefPtr<Val>> solvedArgs; + for (auto varArg : varSubst->args) + { + if (auto typeVar = dynamic_cast<ConstraintVarType*>(varArg.Ptr())) + { + RefPtr<ExpressionType> type = nullptr; + for (auto& c : system->constraints) + { + if (c.decl != typeVar->declRef.GetDecl()) + continue; + + auto cType = c.val.As<ExpressionType>(); + assert(cType.Ptr()); + + if (!type) + { + type = cType; + } + else + { + if (!type->Equals(cType)) + { + // failure! + return nullptr; + } + } + + c.satisfied = true; + } + + if (!type) + { + // failure! + return nullptr; + } + solvedArgs.Add(type); + } + else if (auto valueVar = dynamic_cast<ConstraintVarInt*>(varArg.Ptr())) + { + // TODO(tfoley): maybe support more than integers some day? + RefPtr<IntVal> val = nullptr; + for (auto& c : system->constraints) + { + if (c.decl != valueVar->declRef.GetDecl()) + continue; + + auto cVal = c.val.As<IntVal>(); + assert(cVal.Ptr()); + + if (!val) + { + val = cVal; + } + else + { + if (val->value != cVal->value) + { + // failure! + return nullptr; + } + } + + c.satisfied = true; + } + + if (!val) + { + // failure! + return nullptr; + } + solvedArgs.Add(val); + } + else + { + // ignore anything that isn't a generic parameter + } + } + + // Make sure we haven't constructed any spurious constraints + // that we aren't able to satisfy: + for (auto c : system->constraints) + { + if (!c.satisfied) + { + return nullptr; + } + } + + RefPtr<Substitutions> newSubst = new Substitutions(); + newSubst->genericDecl = varSubst->genericDecl; + newSubst->outer = varSubst->outer; + newSubst->args = solvedArgs; + return newSubst; + +#endif + } + + // + + struct OverloadCandidate + { + enum class Flavor + { + Func, + Generic, + UnspecializedGeneric, + }; + Flavor flavor; + + enum class Status + { + GenericArgumentInferenceFailed, + Unchecked, + ArityChecked, + FixityChecked, + TypeChecked, + Appicable, + }; + Status status = Status::Unchecked; + + // Reference to the declaration being applied + LookupResultItem item; + + // The type of the result expression if this candidate is selected + RefPtr<ExpressionType> resultType; + + // A system for tracking constraints introduced on generic parameters + ConstraintSystem constraintSystem; + + // How much conversion cost should be considered for this overload, + // when ranking candidates. + ConversionCost conversionCostSum = kConversionCost_None; + }; + + + + // State related to overload resolution for a call + // to an overloaded symbol + struct OverloadResolveContext + { + enum class Mode + { + // We are just checking if a candidate works or not + JustTrying, + + // We want to actually update the AST for a chosen candidate + ForReal, + }; + + RefPtr<AppExprBase> appExpr; + RefPtr<ExpressionSyntaxNode> baseExpr; + + // Are we still trying out candidates, or are we + // checking the chosen one for real? + Mode mode = Mode::JustTrying; + + // We store one candidate directly, so that we don't + // need to do dynamic allocation on the list every time + OverloadCandidate bestCandidateStorage; + OverloadCandidate* bestCandidate = nullptr; + + // Full list of all candidates being considered, in the ambiguous case + List<OverloadCandidate> bestCandidates; + }; + + struct ParamCounts + { + int required; + int allowed; + }; + + // count the number of parameters required/allowed for a callable + ParamCounts CountParameters(FilteredMemberRefList<ParamDeclRef> params) + { + ParamCounts counts = { 0, 0 }; + for (auto param : params) + { + counts.allowed++; + + // No initializer means no default value + // + // TODO(tfoley): The logic here is currently broken in two ways: + // + // 1. We are assuming that once one parameter has a default, then all do. + // This can/should be validated earlier, so that we can assume it here. + // + // 2. We are not handling the possibility of multiple declarations for + // a single function, where we'd need to merge default parameters across + // all the declarations. + if (!param.GetDecl()->Expr) + { + counts.required++; + } + } + return counts; + } + + // count the number of parameters required/allowed for a generic + ParamCounts CountParameters(GenericDeclRef genericRef) + { + ParamCounts counts = { 0, 0 }; + for (auto m : genericRef.GetDecl()->Members) + { + if (auto typeParam = m.As<GenericTypeParamDecl>()) + { + counts.allowed++; + if (!typeParam->initType.Ptr()) + { + counts.required++; + } + } + else if (auto valParam = m.As<GenericValueParamDecl>()) + { + counts.allowed++; + if (!valParam->Expr) + { + counts.required++; + } + } + } + return counts; + } + + bool TryCheckOverloadCandidateArity( + OverloadResolveContext& context, + OverloadCandidate const& candidate) + { + int argCount = context.appExpr->Arguments.Count(); + ParamCounts paramCounts = { 0, 0 }; + switch (candidate.flavor) + { + case OverloadCandidate::Flavor::Func: + paramCounts = CountParameters(candidate.item.declRef.As<CallableDeclRef>().GetParameters()); + break; + + case OverloadCandidate::Flavor::Generic: + paramCounts = CountParameters(candidate.item.declRef.As<GenericDeclRef>()); + break; + + default: + assert(!"unexpected"); + break; + } + + if (argCount >= paramCounts.required && argCount <= paramCounts.allowed) + return true; + + // Emit an error message if we are checking this call for real + if (context.mode != OverloadResolveContext::Mode::JustTrying) + { + if (argCount < paramCounts.required) + { + getSink()->diagnose(context.appExpr, Diagnostics::notEnoughArguments, argCount, paramCounts.required); + } + else + { + assert(argCount > paramCounts.allowed); + getSink()->diagnose(context.appExpr, Diagnostics::tooManyArguments, argCount, paramCounts.allowed); + } + } + + return false; + } + + bool TryCheckOverloadCandidateFixity( + OverloadResolveContext& context, + OverloadCandidate const& candidate) + { + auto expr = context.appExpr; + + auto decl = candidate.item.declRef.decl; + + if(auto prefixExpr = expr.As<PrefixExpr>()) + { + if(decl->HasModifier<PrefixModifier>()) + return true; + + if (context.mode != OverloadResolveContext::Mode::JustTrying) + { + getSink()->diagnose(context.appExpr, Diagnostics::expectedPrefixOperator); + getSink()->diagnose(decl, Diagnostics::seeDefinitionOf, decl->getName()); + } + + return false; + } + else if(auto postfixExpr = expr.As<PostfixExpr>()) + { + if(decl->HasModifier<PostfixModifier>()) + return true; + + if (context.mode != OverloadResolveContext::Mode::JustTrying) + { + getSink()->diagnose(context.appExpr, Diagnostics::expectedPostfixOperator); + getSink()->diagnose(decl, Diagnostics::seeDefinitionOf, decl->getName()); + } + + return false; + } + else + { + return true; + } + + return false; + } + + bool TryCheckGenericOverloadCandidateTypes( + OverloadResolveContext& context, + OverloadCandidate& candidate) + { + auto& args = context.appExpr->Arguments; + + auto genericDeclRef = candidate.item.declRef.As<GenericDeclRef>(); + + int aa = 0; + for (auto memberRef : genericDeclRef.GetMembers()) + { + if (auto typeParamRef = memberRef.As<GenericTypeParamDeclRef>()) + { + auto arg = args[aa++]; + + if (context.mode == OverloadResolveContext::Mode::JustTrying) + { + if (!CanCoerceToProperType(TypeExp(arg))) + { + return false; + } + } + else + { + TypeExp typeExp = CoerceToProperType(TypeExp(arg)); + } + } + else if (auto valParamRef = memberRef.As<GenericValueParamDeclRef>()) + { + auto arg = args[aa++]; + + if (context.mode == OverloadResolveContext::Mode::JustTrying) + { + ConversionCost cost = kConversionCost_None; + if (!CanCoerce(valParamRef.GetType(), arg->Type, &cost)) + { + return false; + } + candidate.conversionCostSum += cost; + } + else + { + arg = Coerce(valParamRef.GetType(), arg); + auto val = ExtractGenericArgInteger(arg); + } + } + else + { + continue; + } + } + + return true; + } + + bool TryCheckOverloadCandidateTypes( + OverloadResolveContext& context, + OverloadCandidate& candidate) + { + auto& args = context.appExpr->Arguments; + int argCount = args.Count(); + + List<ParamDeclRef> params; + switch (candidate.flavor) + { + case OverloadCandidate::Flavor::Func: + params = candidate.item.declRef.As<CallableDeclRef>().GetParameters().ToArray(); + break; + + case OverloadCandidate::Flavor::Generic: + return TryCheckGenericOverloadCandidateTypes(context, candidate); + + default: + assert(!"unexpected"); + break; + } + + // Note(tfoley): We might have fewer arguments than parameters in the + // case where one or more parameters had defaults. + assert(argCount <= params.Count()); + + for (int ii = 0; ii < argCount; ++ii) + { + auto& arg = args[ii]; + auto param = params[ii]; + + if (context.mode == OverloadResolveContext::Mode::JustTrying) + { + ConversionCost cost = kConversionCost_None; + if (!CanCoerce(param.GetType(), arg->Type, &cost)) + { + return false; + } + candidate.conversionCostSum += cost; + } + else + { + arg = Coerce(param.GetType(), arg); + } + } + return true; + } + + bool TryCheckOverloadCandidateDirections( + OverloadResolveContext& /*context*/, + OverloadCandidate const& /*candidate*/) + { + // TODO(tfoley): check `in` and `out` markers, as needed. + return true; + } + + // Try to check an overload candidate, but bail out + // if any step fails + void TryCheckOverloadCandidate( + OverloadResolveContext& context, + OverloadCandidate& candidate) + { + if (!TryCheckOverloadCandidateArity(context, candidate)) + return; + + candidate.status = OverloadCandidate::Status::ArityChecked; + if (!TryCheckOverloadCandidateFixity(context, candidate)) + return; + + candidate.status = OverloadCandidate::Status::FixityChecked; + if (!TryCheckOverloadCandidateTypes(context, candidate)) + return; + + candidate.status = OverloadCandidate::Status::TypeChecked; + if (!TryCheckOverloadCandidateDirections(context, candidate)) + return; + + candidate.status = OverloadCandidate::Status::Appicable; + } + + // Create the representation of a given generic applied to some arguments + RefPtr<ExpressionSyntaxNode> CreateGenericDeclRef( + RefPtr<ExpressionSyntaxNode> baseExpr, + RefPtr<AppExprBase> appExpr) + { + auto baseDeclRefExpr = baseExpr.As<DeclRefExpr>(); + if (!baseDeclRefExpr) + { + assert(!"unexpected"); + return CreateErrorExpr(appExpr.Ptr()); + } + auto baseGenericRef = baseDeclRefExpr->declRef.As<GenericDeclRef>(); + if (!baseGenericRef) + { + assert(!"unexpected"); + return CreateErrorExpr(appExpr.Ptr()); + } + + RefPtr<Substitutions> subst = new Substitutions(); + subst->genericDecl = baseGenericRef.GetDecl(); + subst->outer = baseGenericRef.substitutions; + + for (auto arg : appExpr->Arguments) + { + subst->args.Add(ExtractGenericArgVal(arg)); + } + + DeclRef innerDeclRef(baseGenericRef.GetInner(), subst); + + return ConstructDeclRefExpr( + innerDeclRef, + nullptr, + appExpr); + } + + // Take an overload candidate that previously got through + // `TryCheckOverloadCandidate` above, and try to finish + // up the work and turn it into a real expression. + // + // If the candidate isn't actually applicable, this is + // where we'd start reporting the issue(s). + RefPtr<ExpressionSyntaxNode> CompleteOverloadCandidate( + OverloadResolveContext& context, + OverloadCandidate& candidate) + { + // special case for generic argument inference failure + if (candidate.status == OverloadCandidate::Status::GenericArgumentInferenceFailed) + { + String callString = GetCallSignatureString(context.appExpr); + getSink()->diagnose( + context.appExpr, + Diagnostics::genericArgumentInferenceFailed, + callString); + + String declString = getDeclSignatureString(candidate.item); + getSink()->diagnose(candidate.item.declRef, Diagnostics::genericSignatureTried, declString); + goto error; + } + + context.mode = OverloadResolveContext::Mode::ForReal; + context.appExpr->Type = ExpressionType::Error; + + if (!TryCheckOverloadCandidateArity(context, candidate)) + goto error; + + if (!TryCheckOverloadCandidateFixity(context, candidate)) + goto error; + + if (!TryCheckOverloadCandidateTypes(context, candidate)) + goto error; + + if (!TryCheckOverloadCandidateDirections(context, candidate)) + goto error; + + { + auto baseExpr = ConstructLookupResultExpr( + candidate.item, context.baseExpr, context.appExpr->FunctionExpr); + + switch(candidate.flavor) + { + case OverloadCandidate::Flavor::Func: + context.appExpr->FunctionExpr = baseExpr; + context.appExpr->Type = candidate.resultType; + + // A call may yield an l-value, and we should take a look at the candidate to be sure + if(auto subscriptDeclRef = candidate.item.declRef.As<SubscriptDeclRef>()) + { + for(auto setter : subscriptDeclRef.GetDecl()->GetMembersOfType<SetterDecl>()) + { + context.appExpr->Type.IsLeftValue = true; + } + } + + // TODO: there may be other cases that confer l-value-ness + + return context.appExpr; + break; + + case OverloadCandidate::Flavor::Generic: + return CreateGenericDeclRef(baseExpr, context.appExpr); + break; + + default: + assert(!"unexpected"); + break; + } + } + + + error: + return CreateErrorExpr(context.appExpr.Ptr()); + } + + // Implement a comparison operation between overload candidates, + // so that the better candidate compares as less-than the other + int CompareOverloadCandidates( + OverloadCandidate* left, + OverloadCandidate* right) + { + // If one candidate got further along in validation, pick it + if (left->status != right->status) + return int(right->status) - int(left->status); + + // If both candidates are applicable, then we need to compare + // the costs of their type conversion sequences + if(left->status == OverloadCandidate::Status::Appicable) + { + if (left->conversionCostSum != right->conversionCostSum) + return left->conversionCostSum - right->conversionCostSum; + } + + return 0; + } + + void AddOverloadCandidateInner( + OverloadResolveContext& context, + OverloadCandidate& candidate) + { + // Filter our existing candidates, to remove any that are worse than our new one + + bool keepThisCandidate = true; // should this candidate be kept? + + if (context.bestCandidates.Count() != 0) + { + // We have multiple candidates right now, so filter them. + bool anyFiltered = false; + // Note that we are querying the list length on every iteration, + // because we might remove things. + for (int cc = 0; cc < context.bestCandidates.Count(); ++cc) + { + int cmp = CompareOverloadCandidates(&candidate, &context.bestCandidates[cc]); + if (cmp < 0) + { + // our new candidate is better! + + // remove it from the list (by swapping in a later one) + context.bestCandidates.FastRemoveAt(cc); + // and then reduce our index so that we re-visit the same index + --cc; + + anyFiltered = true; + } + else if(cmp > 0) + { + // our candidate is worse! + keepThisCandidate = false; + } + } + // It should not be possible that we removed some existing candidate *and* + // chose not to keep this candidate (otherwise the better-ness relation + // isn't transitive). Therefore we confirm that we either chose to keep + // this candidate (in which case filtering is okay), or we didn't filter + // anything. + assert(keepThisCandidate || !anyFiltered); + } + else if(context.bestCandidate) + { + // There's only one candidate so far + int cmp = CompareOverloadCandidates(&candidate, context.bestCandidate); + if(cmp < 0) + { + // our new candidate is better! + context.bestCandidate = nullptr; + } + else if (cmp > 0) + { + // our candidate is worse! + keepThisCandidate = false; + } + } + + // If our candidate isn't good enough, then drop it + if (!keepThisCandidate) + return; + + // Otherwise we want to keep the candidate + if (context.bestCandidates.Count() > 0) + { + // There were already multiple candidates, and we are adding one more + context.bestCandidates.Add(candidate); + } + else if (context.bestCandidate) + { + // There was a unique best candidate, but now we are ambiguous + context.bestCandidates.Add(*context.bestCandidate); + context.bestCandidates.Add(candidate); + context.bestCandidate = nullptr; + } + else + { + // This is the only candidate worthe keeping track of right now + context.bestCandidateStorage = candidate; + context.bestCandidate = &context.bestCandidateStorage; + } + } + + void AddOverloadCandidate( + OverloadResolveContext& context, + OverloadCandidate& candidate) + { + // Try the candidate out, to see if it is applicable at all. + TryCheckOverloadCandidate(context, candidate); + + // Now (potentially) add it to the set of candidate overloads to consider. + AddOverloadCandidateInner(context, candidate); + } + + void AddFuncOverloadCandidate( + LookupResultItem item, + CallableDeclRef funcDeclRef, + OverloadResolveContext& context) + { + EnsureDecl(funcDeclRef.GetDecl()); + + OverloadCandidate candidate; + candidate.flavor = OverloadCandidate::Flavor::Func; + candidate.item = item; + candidate.resultType = funcDeclRef.GetResultType(); + + AddOverloadCandidate(context, candidate); + } + + void AddFuncOverloadCandidate( + RefPtr<FuncType> /*funcType*/, + OverloadResolveContext& /*context*/) + { +#if 0 + if (funcType->decl) + { + AddFuncOverloadCandidate(funcType->decl, context); + } + else if (funcType->Func) + { + AddFuncOverloadCandidate(funcType->Func->SyntaxNode, context); + } + else if (funcType->Component) + { + AddComponentFuncOverloadCandidate(funcType->Component, context); + } +#else + throw "unimplemented"; +#endif + } + + void AddCtorOverloadCandidate( + LookupResultItem typeItem, + RefPtr<ExpressionType> type, + ConstructorDeclRef ctorDeclRef, + OverloadResolveContext& context) + { + EnsureDecl(ctorDeclRef.GetDecl()); + + // `typeItem` refers to the type being constructed (the thing + // that was applied as a function) so we need to construct + // a `LookupResultItem` that refers to the constructor instead + + LookupResultItem ctorItem; + ctorItem.declRef = ctorDeclRef; + ctorItem.breadcrumbs = new LookupResultItem::Breadcrumb(LookupResultItem::Breadcrumb::Kind::Member, typeItem.declRef, typeItem.breadcrumbs); + + OverloadCandidate candidate; + candidate.flavor = OverloadCandidate::Flavor::Func; + candidate.item = ctorItem; + candidate.resultType = type; + + AddOverloadCandidate(context, candidate); + } + + // If the given declaration has generic parameters, then + // return the corresponding `GenericDecl` that holds the + // parameters, etc. + GenericDecl* GetOuterGeneric(Decl* decl) + { + auto parentDecl = decl->ParentDecl; + if (!parentDecl) return nullptr; + auto parentGeneric = dynamic_cast<GenericDecl*>(parentDecl); + return parentGeneric; + } + + // Try to find a unification for two values + bool TryUnifyVals( + ConstraintSystem& constraints, + RefPtr<Val> fst, + RefPtr<Val> snd) + { + // if both values are types, then unify types + if (auto fstType = fst.As<ExpressionType>()) + { + if (auto sndType = snd.As<ExpressionType>()) + { + return TryUnifyTypes(constraints, fstType, sndType); + } + } + + // if both values are constant integers, then compare them + if (auto fstIntVal = fst.As<ConstantIntVal>()) + { + if (auto sndIntVal = snd.As<ConstantIntVal>()) + { + return fstIntVal->value == sndIntVal->value; + } + } + + // Check if both are integer values in general + if (auto fstInt = fst.As<IntVal>()) + { + if (auto sndInt = snd.As<IntVal>()) + { + auto fstParam = fstInt.As<GenericParamIntVal>(); + auto sndParam = sndInt.As<GenericParamIntVal>(); + + if (fstParam) + TryUnifyIntParam(constraints, fstParam->declRef, sndInt); + if (sndParam) + TryUnifyIntParam(constraints, sndParam->declRef, fstInt); + + if (fstParam || sndParam) + return true; + } + } + + throw "unimplemented"; + + // default: fail + return false; + } + + bool TryUnifySubstitutions( + ConstraintSystem& constraints, + RefPtr<Substitutions> fst, + RefPtr<Substitutions> snd) + { + // They must both be NULL or non-NULL + if (!fst || !snd) + return fst == snd; + + // They must be specializing the same generic + if (fst->genericDecl != snd->genericDecl) + return false; + + // Their arguments must unify + assert(fst->args.Count() == snd->args.Count()); + int argCount = fst->args.Count(); + for (int aa = 0; aa < argCount; ++aa) + { + if (!TryUnifyVals(constraints, fst->args[aa], snd->args[aa])) + return false; + } + + // Their "base" specializations must unify + if (!TryUnifySubstitutions(constraints, fst->outer, snd->outer)) + return false; + + return true; + } + + bool TryUnifyTypeParam( + ConstraintSystem& constraints, + RefPtr<GenericTypeParamDecl> typeParamDecl, + RefPtr<ExpressionType> type) + { + // We want to constrain the given type parameter + // to equal the given type. + Constraint constraint; + constraint.decl = typeParamDecl.Ptr(); + constraint.val = type; + + constraints.constraints.Add(constraint); + + return true; + } + + bool TryUnifyIntParam( + ConstraintSystem& constraints, + RefPtr<GenericValueParamDecl> paramDecl, + RefPtr<IntVal> val) + { + // We want to constrain the given parameter to equal the given value. + Constraint constraint; + constraint.decl = paramDecl.Ptr(); + constraint.val = val; + + constraints.constraints.Add(constraint); + + return true; + } + + bool TryUnifyIntParam( + ConstraintSystem& constraints, + VarDeclBaseRef const& varRef, + RefPtr<IntVal> val) + { + if(auto genericValueParamRef = varRef.As<GenericValueParamDeclRef>()) + { + return TryUnifyIntParam(constraints, genericValueParamRef.GetDecl(), val); + } + else + { + return false; + } + } + + bool TryUnifyTypesByStructuralMatch( + ConstraintSystem& constraints, + RefPtr<ExpressionType> fst, + RefPtr<ExpressionType> snd) + { + if (auto fstDeclRefType = fst->As<DeclRefType>()) + { + auto fstDeclRef = fstDeclRefType->declRef; + + if (auto typeParamDecl = dynamic_cast<GenericTypeParamDecl*>(fstDeclRef.GetDecl())) + return TryUnifyTypeParam(constraints, typeParamDecl, snd); + + if (auto sndDeclRefType = snd->As<DeclRefType>()) + { + auto sndDeclRef = sndDeclRefType->declRef; + + if (auto typeParamDecl = dynamic_cast<GenericTypeParamDecl*>(sndDeclRef.GetDecl())) + return TryUnifyTypeParam(constraints, typeParamDecl, fst); + + // can't be unified if they refer to differnt declarations. + if (fstDeclRef.GetDecl() != sndDeclRef.GetDecl()) return false; + + // next we need to unify the substitutions applied + // to each decalration reference. + if (!TryUnifySubstitutions( + constraints, + fstDeclRef.substitutions, + sndDeclRef.substitutions)) + { + return false; + } + + return true; + } + } + + return false; + } + + bool TryUnifyTypes( + ConstraintSystem& constraints, + RefPtr<ExpressionType> fst, + RefPtr<ExpressionType> snd) + { + if (fst->Equals(snd)) return true; + + // An error type can unify with anything, just so we avoid cascading errors. + + if (auto fstErrorType = fst->As<ErrorType>()) + return true; + + if (auto sndErrorType = snd->As<ErrorType>()) + return true; + + // A generic parameter type can unify with anything. + // TODO: there actually needs to be some kind of "occurs check" sort + // of thing here... + + if (auto fstDeclRefType = fst->As<DeclRefType>()) + { + auto fstDeclRef = fstDeclRefType->declRef; + + if (auto typeParamDecl = dynamic_cast<GenericTypeParamDecl*>(fstDeclRef.GetDecl())) + return TryUnifyTypeParam(constraints, typeParamDecl, snd); + } + + if (auto sndDeclRefType = snd->As<DeclRefType>()) + { + auto sndDeclRef = sndDeclRefType->declRef; + + if (auto typeParamDecl = dynamic_cast<GenericTypeParamDecl*>(sndDeclRef.GetDecl())) + return TryUnifyTypeParam(constraints, typeParamDecl, fst); + } + + // If we can unify the types structurally, then we are golden + if(TryUnifyTypesByStructuralMatch(constraints, fst, snd)) + return true; + + // Now we need to consider cases where coercion might + // need to be applied. For now we can try to do this + // in a completely ad hoc fashion, but eventually we'd + // want to do it more formally. + + if(auto fstVectorType = fst->As<VectorExpressionType>()) + { + if(auto sndScalarType = snd->As<BasicExpressionType>()) + { + return TryUnifyTypes( + constraints, + fstVectorType->elementType, + sndScalarType); + } + } + + if(auto fstScalarType = fst->As<BasicExpressionType>()) + { + if(auto sndVectorType = snd->As<VectorExpressionType>()) + { + return TryUnifyTypes( + constraints, + fstScalarType, + sndVectorType->elementType); + } + } + + // TODO: the same thing for vectors... + + return false; + } + + // Is the candidate extension declaration actually applicable to the given type + ExtensionDeclRef ApplyExtensionToType( + ExtensionDecl* extDecl, + RefPtr<ExpressionType> type) + { + if (auto extGenericDecl = GetOuterGeneric(extDecl)) + { + ConstraintSystem constraints; + + if (!TryUnifyTypes(constraints, extDecl->targetType, type)) + return DeclRef().As<ExtensionDeclRef>(); + + auto constraintSubst = TrySolveConstraintSystem(&constraints, DeclRef(extGenericDecl, nullptr).As<GenericDeclRef>()); + if (!constraintSubst) + { + return DeclRef().As<ExtensionDeclRef>(); + } + + // Consruct a reference to the extension with our constraint variables + // set as they were found by solving the constraint system. + ExtensionDeclRef extDeclRef = DeclRef(extDecl, constraintSubst).As<ExtensionDeclRef>(); + + // We expect/require that the result of unification is such that + // the target types are now equal + assert(extDeclRef.GetTargetType()->Equals(type)); + + return extDeclRef; + } + else + { + // The easy case is when the extension isn't generic: + // either it applies to the type or not. + if (!type->Equals(extDecl->targetType)) + return DeclRef().As<ExtensionDeclRef>(); + return DeclRef(extDecl, nullptr).As<ExtensionDeclRef>(); + } + } + + bool TryUnifyArgAndParamTypes( + ConstraintSystem& system, + RefPtr<ExpressionSyntaxNode> argExpr, + ParamDeclRef paramDeclRef) + { + // TODO(tfoley): potentially need a bit more + // nuance in case where argument might be + // an overload group... + return TryUnifyTypes(system, argExpr->Type, paramDeclRef.GetType()); + } + + // Take a generic declaration and try to specialize its parameters + // so that the resulting inner declaration can be applicable in + // a particular context... + DeclRef SpecializeGenericForOverload( + GenericDeclRef genericDeclRef, + OverloadResolveContext& context) + { + ConstraintSystem constraints; + + // Construct a reference to the inner declaration that has any generic + // parameter substitutions in place already, but *not* any substutions + // for the generic declaration we are currently trying to infer. + auto innerDecl = genericDeclRef.GetInner(); + DeclRef unspecializedInnerRef = DeclRef(innerDecl, genericDeclRef.substitutions); + + // Check what type of declaration we are dealing with, and then try + // to match it up with the arguments accordingly... + if (auto funcDeclRef = unspecializedInnerRef.As<CallableDeclRef>()) + { + auto& args = context.appExpr->Arguments; + auto params = funcDeclRef.GetParameters().ToArray(); + + int argCount = args.Count(); + int paramCount = params.Count(); + + // Bail out on mismatch. + // TODO(tfoley): need more nuance here + if (argCount != paramCount) + { + return DeclRef(nullptr, nullptr); + } + + for (int aa = 0; aa < argCount; ++aa) + { +#if 0 + if (!TryUnifyArgAndParamTypes(constraints, args[aa], params[aa])) + return DeclRef(nullptr, nullptr); +#else + // The question here is whether failure to "unify" an argument + // and parameter should lead to immediate failure. + // + // The case that is interesting is if we want to unify, say: + // `vector<float,N>` and `vector<int,3>` + // + // It is clear that we should solve with `N = 3`, and then + // a later step may find that the resulting types aren't + // actually a match. + // + // A more refined approach to "unification" could of course + // see that `int` can convert to `float` and use that fact. + // (and indeed we already use something like this to unify + // `float` and `vector<T,3>`) + // + // So the question is then whether a mismatch during the + // unification step should be taken as an immediate failure... + + TryUnifyArgAndParamTypes(constraints, args[aa], params[aa]); +#endif + } + } + else + { + // TODO(tfoley): any other cases needed here? + return DeclRef(nullptr, nullptr); + } + + auto constraintSubst = TrySolveConstraintSystem(&constraints, genericDeclRef); + if (!constraintSubst) + { + // constraint solving failed + return DeclRef(nullptr, nullptr); + } + + // We can now construct a reference to the inner declaration using + // the solution to our constraints. + return DeclRef(innerDecl, constraintSubst); + } + + void AddAggTypeOverloadCandidates( + LookupResultItem typeItem, + RefPtr<ExpressionType> type, + AggTypeDeclRef aggTypeDeclRef, + OverloadResolveContext& context) + { + for (auto ctorDeclRef : aggTypeDeclRef.GetMembersOfType<ConstructorDeclRef>()) + { + // now work through this candidate... + AddCtorOverloadCandidate(typeItem, type, ctorDeclRef, context); + } + + // Now walk through any extensions we can find for this types + for (auto ext = aggTypeDeclRef.GetCandidateExtensions(); ext; ext = ext->nextCandidateExtension) + { + auto extDeclRef = ApplyExtensionToType(ext, type); + if (!extDeclRef) + continue; + + for (auto ctorDeclRef : extDeclRef.GetMembersOfType<ConstructorDeclRef>()) + { + // TODO(tfoley): `typeItem` here should really reference the extension... + + // now work through this candidate... + AddCtorOverloadCandidate(typeItem, type, ctorDeclRef, context); + } + + // Also check for generic constructors + for (auto genericDeclRef : extDeclRef.GetMembersOfType<GenericDeclRef>()) + { + if (auto ctorDecl = genericDeclRef.GetDecl()->inner.As<ConstructorDecl>()) + { + DeclRef innerRef = SpecializeGenericForOverload(genericDeclRef, context); + if (!innerRef) + continue; + + ConstructorDeclRef innerCtorRef = innerRef.As<ConstructorDeclRef>(); + + AddCtorOverloadCandidate(typeItem, type, innerCtorRef, context); + + // TODO(tfoley): need a way to do the solving step for the constraint system + } + } + } + } + + void AddTypeOverloadCandidates( + RefPtr<ExpressionType> type, + OverloadResolveContext& context) + { + if (auto declRefType = type->As<DeclRefType>()) + { + if (auto aggTypeDeclRef = declRefType->declRef.As<AggTypeDeclRef>()) + { + AddAggTypeOverloadCandidates(LookupResultItem(aggTypeDeclRef), type, aggTypeDeclRef, context); + } + } + } + + void AddDeclRefOverloadCandidates( + LookupResultItem item, + OverloadResolveContext& context) + { + auto declRef = item.declRef; + + if (auto funcDeclRef = item.declRef.As<CallableDeclRef>()) + { + AddFuncOverloadCandidate(item, funcDeclRef, context); + } + else if (auto aggTypeDeclRef = item.declRef.As<AggTypeDeclRef>()) + { + auto type = DeclRefType::Create(aggTypeDeclRef); + AddAggTypeOverloadCandidates(item, type, aggTypeDeclRef, context); + } + else if (auto genericDeclRef = item.declRef.As<GenericDeclRef>()) + { + // Try to infer generic arguments, based on the context + DeclRef innerRef = SpecializeGenericForOverload(genericDeclRef, context); + + if (innerRef) + { + // If inference works, then we've now got a + // specialized declaration reference we can apply. + + LookupResultItem innerItem; + innerItem.breadcrumbs = item.breadcrumbs; + innerItem.declRef = innerRef; + + AddDeclRefOverloadCandidates(innerItem, context); + } + else + { + // If inference failed, then we need to create + // a candidate that can be used to reflect that fact + // (so we can report a good error) + OverloadCandidate candidate; + candidate.item = item; + candidate.flavor = OverloadCandidate::Flavor::UnspecializedGeneric; + candidate.status = OverloadCandidate::Status::GenericArgumentInferenceFailed; + + AddOverloadCandidateInner(context, candidate); + } + } + else if( auto typeDefDeclRef = item.declRef.As<TypeDefDeclRef>() ) + { + AddTypeOverloadCandidates(typeDefDeclRef.GetType(), context); + } + else + { + // TODO(tfoley): any other cases needed here? + } + } + + void AddOverloadCandidates( + RefPtr<ExpressionSyntaxNode> funcExpr, + OverloadResolveContext& context) + { + auto funcExprType = funcExpr->Type; + + if (auto funcDeclRefExpr = funcExpr.As<DeclRefExpr>()) + { + // The expression referenced a function declaration + AddDeclRefOverloadCandidates(LookupResultItem(funcDeclRefExpr->declRef), context); + } + else if (auto funcType = funcExprType->As<FuncType>()) + { + // TODO(tfoley): deprecate this path... + AddFuncOverloadCandidate(funcType, context); + } + else if (auto overloadedExpr = funcExpr.As<OverloadedExpr>()) + { + auto lookupResult = overloadedExpr->lookupResult2; + assert(lookupResult.isOverloaded()); + for(auto item : lookupResult.items) + { + AddDeclRefOverloadCandidates(item, context); + } + } + else if (auto typeType = funcExprType->As<TypeType>()) + { + // If none of the above cases matched, but we are + // looking at a type, then I suppose we have + // a constructor call on our hands. + // + // TODO(tfoley): are there any meaningful types left + // that aren't declaration references? + AddTypeOverloadCandidates(typeType->type, context); + return; + } + } + + void formatType(StringBuilder& sb, RefPtr<ExpressionType> type) + { + sb << type->ToString(); + } + + void formatVal(StringBuilder& sb, RefPtr<Val> val) + { + sb << val->ToString(); + } + + void formatDeclPath(StringBuilder& sb, DeclRef declRef) + { + // Find the parent declaration + auto parentDeclRef = declRef.GetParent(); + + // If the immediate parent is a generic, then we probably + // want the declaration above that... + auto parentGenericDeclRef = parentDeclRef.As<GenericDeclRef>(); + if(parentGenericDeclRef) + { + parentDeclRef = parentGenericDeclRef.GetParent(); + } + + // Depending on what the parent is, we may want to format things specially + if(auto aggTypeDeclRef = parentDeclRef.As<AggTypeDeclRef>()) + { + formatDeclPath(sb, aggTypeDeclRef); + sb << "."; + } + + sb << declRef.GetName(); + + // If the parent declaration is a generic, then we need to print out its + // signature + if( parentGenericDeclRef ) + { + assert(declRef.substitutions); + assert(declRef.substitutions->genericDecl == parentGenericDeclRef.GetDecl()); + + sb << "<"; + bool first = true; + for(auto arg : declRef.substitutions->args) + { + if(!first) sb << ", "; + formatVal(sb, arg); + first = false; + } + sb << ">"; + } + } + + void formatDeclParams(StringBuilder& sb, DeclRef declRef) + { + if (auto funcDeclRef = declRef.As<CallableDeclRef>()) + { + + // This is something callable, so we need to also print parameter types for overloading + sb << "("; + + bool first = true; + for (auto paramDeclRef : funcDeclRef.GetParameters()) + { + if (!first) sb << ", "; + + formatType(sb, paramDeclRef.GetType()); + + first = false; + + } + + sb << ")"; + } + else if(auto genericDeclRef = declRef.As<GenericDeclRef>()) + { + sb << "<"; + bool first = true; + for (auto paramDeclRef : genericDeclRef.GetMembers()) + { + if(auto genericTypeParam = paramDeclRef.As<GenericTypeParamDeclRef>()) + { + if (!first) sb << ", "; + first = false; + + sb << genericTypeParam.GetName(); + } + else if(auto genericValParam = paramDeclRef.As<GenericValueParamDeclRef>()) + { + if (!first) sb << ", "; + first = false; + + formatType(sb, genericValParam.GetType()); + sb << " "; + sb << genericValParam.GetName(); + } + else + {} + } + sb << ">"; + + formatDeclParams(sb, DeclRef(genericDeclRef.GetInner(), genericDeclRef.substitutions)); + } + else + { + } + } + + void formatDeclSignature(StringBuilder& sb, DeclRef declRef) + { + formatDeclPath(sb, declRef); + formatDeclParams(sb, declRef); + } + + String getDeclSignatureString(DeclRef declRef) + { + StringBuilder sb; + formatDeclSignature(sb, declRef); + return sb.ProduceString(); + } + + String getDeclSignatureString(LookupResultItem item) + { + return getDeclSignatureString(item.declRef); + } + + String GetCallSignatureString(RefPtr<AppExprBase> expr) + { + StringBuilder argsListBuilder; + argsListBuilder << "("; + bool first = true; + for (auto a : expr->Arguments) + { + if (!first) argsListBuilder << ", "; + argsListBuilder << a->Type->ToString(); + first = false; + } + argsListBuilder << ")"; + return argsListBuilder.ProduceString(); + } + + + RefPtr<ExpressionSyntaxNode> ResolveInvoke(InvokeExpressionSyntaxNode * expr) + { + // Look at the base expression for the call, and figure out how to invoke it. + auto funcExpr = expr->FunctionExpr; + auto funcExprType = funcExpr->Type; + + // If we are trying to apply an erroroneous expression, then just bail out now. + if(IsErrorExpr(funcExpr)) + { + return CreateErrorExpr(expr); + } + + OverloadResolveContext context; + context.appExpr = expr; + if (auto funcMemberExpr = funcExpr.As<MemberExpressionSyntaxNode>()) + { + context.baseExpr = funcMemberExpr->BaseExpression; + } + else if (auto funcOverloadExpr = funcExpr.As<OverloadedExpr>()) + { + context.baseExpr = funcOverloadExpr->base; + } + AddOverloadCandidates(funcExpr, context); + + if (context.bestCandidates.Count() > 0) + { + // Things were ambiguous. + + // It might be that things were only ambiguous because + // one of the argument expressions had an error, and + // so a bunch of candidates could match at that position. + // + // If any argument was an error, we skip out on printing + // another message, to avoid cascading errors. + for (auto arg : expr->Arguments) + { + if (IsErrorExpr(arg)) + { + return CreateErrorExpr(expr); + } + } + + String funcName; + if (auto baseVar = funcExpr.As<VarExpressionSyntaxNode>()) + funcName = baseVar->Variable; + else if(auto baseMemberRef = funcExpr.As<MemberExpressionSyntaxNode>()) + funcName = baseMemberRef->MemberName; + + String argsList = GetCallSignatureString(expr); + + if (context.bestCandidates[0].status != OverloadCandidate::Status::Appicable) + { + // There were multple equally-good candidates, but none actually usable. + // We will construct a diagnostic message to help out. + if (funcName.Length() != 0) + { + getSink()->diagnose(expr, Diagnostics::noApplicableOverloadForNameWithArgs, funcName, argsList); + } + else + { + getSink()->diagnose(expr, Diagnostics::noApplicableWithArgs, argsList); + } + } + else + { + // There were multiple applicable candidates, so we need to report them. + + if (funcName.Length() != 0) + { + getSink()->diagnose(expr, Diagnostics::ambiguousOverloadForNameWithArgs, funcName, argsList); + } + else + { + getSink()->diagnose(expr, Diagnostics::ambiguousOverloadWithArgs, argsList); + } + } + + int candidateCount = context.bestCandidates.Count(); + int maxCandidatesToPrint = 10; // don't show too many candidates at once... + int candidateIndex = 0; + for (auto candidate : context.bestCandidates) + { + String declString = getDeclSignatureString(candidate.item); + + declString = declString + "[" + String(candidate.conversionCostSum) + "]"; + + getSink()->diagnose(candidate.item.declRef, Diagnostics::overloadCandidate, declString); + + candidateIndex++; + if (candidateIndex == maxCandidatesToPrint) + break; + } + if (candidateIndex != candidateCount) + { + getSink()->diagnose(expr, Diagnostics::moreOverloadCandidates, candidateCount - candidateIndex); + } + + return CreateErrorExpr(expr); + } + else if (context.bestCandidate) + { + // There was one best candidate, even if it might not have been + // applicable in the end. + // We will report errors for this one candidate, then, to give + // the user the most help we can. + return CompleteOverloadCandidate(context, *context.bestCandidate); + } + else + { + // Nothing at all was found that we could even consider invoking + getSink()->diagnose(expr->FunctionExpr, Diagnostics::expectedFunction); + expr->Type = ExpressionType::Error; + return expr; + } + } + + void AddGenericOverloadCandidate( + LookupResultItem baseItem, + OverloadResolveContext& context) + { + if (auto genericDeclRef = baseItem.declRef.As<GenericDeclRef>()) + { + EnsureDecl(genericDeclRef.GetDecl()); + + OverloadCandidate candidate; + candidate.flavor = OverloadCandidate::Flavor::Generic; + candidate.item = baseItem; + candidate.resultType = nullptr; + + AddOverloadCandidate(context, candidate); + } + } + + void AddGenericOverloadCandidates( + RefPtr<ExpressionSyntaxNode> baseExpr, + OverloadResolveContext& context) + { + if(auto baseDeclRefExpr = baseExpr.As<DeclRefExpr>()) + { + auto declRef = baseDeclRefExpr->declRef; + AddGenericOverloadCandidate(LookupResultItem(declRef), context); + } + else if (auto overloadedExpr = baseExpr.As<OverloadedExpr>()) + { + // We are referring to a bunch of declarations, each of which might be generic + LookupResult result; + for (auto item : overloadedExpr->lookupResult2.items) + { + AddGenericOverloadCandidate(item, context); + } + } + else + { + // any other cases? + } + } + + RefPtr<ExpressionSyntaxNode> VisitGenericApp(GenericAppExpr * genericAppExpr) override + { + // We are applying a generic to arguments, but there might be multiple generic + // declarations with the same name, so this becomes a specialized case of + // overload resolution. + + + // Start by checking the base expression and arguments. + auto& baseExpr = genericAppExpr->FunctionExpr; + baseExpr = CheckTerm(baseExpr); + auto& args = genericAppExpr->Arguments; + for (auto& arg : args) + { + arg = CheckTerm(arg); + } + + // If there was an error in the base expression, or in any of + // the arguments, then just bail. + if (IsErrorExpr(baseExpr)) + { + return CreateErrorExpr(genericAppExpr); + } + for (auto argExpr : args) + { + if (IsErrorExpr(argExpr)) + { + return CreateErrorExpr(genericAppExpr); + } + } + + // Otherwise, let's start looking at how to find an overload... + + OverloadResolveContext context; + context.appExpr = genericAppExpr; + context.baseExpr = GetBaseExpr(baseExpr); + + AddGenericOverloadCandidates(baseExpr, context); + + if (context.bestCandidates.Count() > 0) + { + // Things were ambiguous. + if (context.bestCandidates[0].status != OverloadCandidate::Status::Appicable) + { + // There were multple equally-good candidates, but none actually usable. + // We will construct a diagnostic message to help out. + + // TODO(tfoley): print a reasonable message here... + + getSink()->diagnose(genericAppExpr, Diagnostics::unimplemented, "no applicable generic"); + + return CreateErrorExpr(genericAppExpr); + } + else + { + // There were multiple viable candidates, but that isn't an error: we just need + // to complete all of them and create an overloaded expression as a result. + + LookupResult result; + for (auto candidate : context.bestCandidates) + { + auto candidateExpr = CompleteOverloadCandidate(context, candidate); + } + + throw "what now?"; +// auto overloadedExpr = new OverloadedExpr(); +// return overloadedExpr; + } + } + else if (context.bestCandidate) + { + // There was one best candidate, even if it might not have been + // applicable in the end. + // We will report errors for this one candidate, then, to give + // the user the most help we can. + return CompleteOverloadCandidate(context, *context.bestCandidate); + } + else + { + // Nothing at all was found that we could even consider invoking + getSink()->diagnose(genericAppExpr, Diagnostics::unimplemented, "expected a generic"); + return CreateErrorExpr(genericAppExpr); + } + + +#if TIMREMOVED + + if (IsErrorExpr(base)) + { + return CreateErrorExpr(typeNode); + } + else if(auto baseDeclRefExpr = base.As<DeclRefExpr>()) + { + auto declRef = baseDeclRefExpr->declRef; + + if (auto genericDeclRef = declRef.As<GenericDeclRef>()) + { + int argCount = typeNode->Args.Count(); + int argIndex = 0; + for (RefPtr<Decl> member : genericDeclRef.GetDecl()->Members) + { + if (auto typeParam = member.As<GenericTypeParamDecl>()) + { + if (argIndex == argCount) + { + // Too few arguments! + + } + + // TODO: checking! + } + else if (auto valParam = member.As<GenericValueParamDecl>()) + { + // TODO: checking + } + else + { + + } + } + if (argIndex != argCount) + { + // Too many arguments! + } + + // Now instantiate the declaration given those arguments + auto type = InstantiateGenericType(genericDeclRef, args); + typeResult = type; + typeNode->Type = new TypeExpressionType(type); + return typeNode; + } + } + else if (auto overloadedExpr = base.As<OverloadedExpr>()) + { + // We are referring to a bunch of declarations, each of which might be generic + LookupResult result; + for (auto item : overloadedExpr->lookupResult2.items) + { + auto applied = TryApplyGeneric(item, typeNode); + if (!applied) + continue; + + AddToLookupResult(result, appliedItem); + } + } + + // TODO: correct diagnostic here! + getSink()->diagnose(typeNode, Diagnostics::expectedAGeneric, base->Type); + return CreateErrorExpr(typeNode); +#endif + } + + RefPtr<ExpressionSyntaxNode> VisitSharedTypeExpr(SharedTypeExpr* expr) override + { + if (!expr->Type.Ptr()) + { + expr->base = CheckProperType(expr->base); + expr->Type = expr->base.exp->Type; + } + return expr; + } + + + + + RefPtr<ExpressionSyntaxNode> CheckExpr(RefPtr<ExpressionSyntaxNode> expr) + { + return expr->Accept(this).As<ExpressionSyntaxNode>(); + } + + RefPtr<ExpressionSyntaxNode> CheckInvokeExprWithCheckedOperands(InvokeExpressionSyntaxNode *expr) + { + + auto rs = ResolveInvoke(expr); + if (auto invoke = dynamic_cast<InvokeExpressionSyntaxNode*>(rs.Ptr())) + { + // if this is still an invoke expression, test arguments passed to inout/out parameter are LValues + if(auto funcType = invoke->FunctionExpr->Type->As<FuncType>()) + { + List<RefPtr<ParameterSyntaxNode>> paramsStorage; + List<RefPtr<ParameterSyntaxNode>> * params = nullptr; + if (auto func = funcType->declRef.GetDecl()) + { + paramsStorage = func->GetParameters().ToArray(); + params = ¶msStorage; + } + if (params) + { + for (int i = 0; i < (*params).Count(); i++) + { + if ((*params)[i]->HasModifier<OutModifier>()) + { + if (i < expr->Arguments.Count() && expr->Arguments[i]->Type->AsBasicType() && + !expr->Arguments[i]->Type.IsLeftValue) + { + getSink()->diagnose(expr->Arguments[i], Diagnostics::argumentExpectedLValue, (*params)[i]->Name); + } + } + } + } + } + } + return rs; + } + + virtual RefPtr<ExpressionSyntaxNode> VisitInvokeExpression(InvokeExpressionSyntaxNode *expr) override + { + // check the base expression first + expr->FunctionExpr = CheckExpr(expr->FunctionExpr); + + // Next check the argument expressions + for (auto & arg : expr->Arguments) + { + arg = CheckExpr(arg); + } + + return CheckInvokeExprWithCheckedOperands(expr); + } + + + virtual RefPtr<ExpressionSyntaxNode> VisitVarExpression(VarExpressionSyntaxNode *expr) override + { + // If we've already resolved this expression, don't try again. + if (expr->declRef) + return expr; + + expr->Type = ExpressionType::Error; + + auto lookupResult = LookUp(expr->Variable, expr->scope); + if (lookupResult.isValid()) + { + return createLookupResultExpr( + lookupResult, + nullptr, + expr); + } + + getSink()->diagnose(expr, Diagnostics::undefinedIdentifier2, expr->Variable); + + return expr; + } + virtual RefPtr<ExpressionSyntaxNode> VisitTypeCastExpression(TypeCastExpressionSyntaxNode * expr) override + { + expr->Expression = expr->Expression->Accept(this).As<ExpressionSyntaxNode>(); + auto targetType = CheckProperType(expr->TargetType); + expr->TargetType = targetType; + + // The way to perform casting depends on the types involved + if (expr->Expression->Type->Equals(ExpressionType::Error.Ptr())) + { + // If the expression being casted has an error type, then just silently succeed + expr->Type = targetType; + return expr; + } + else if (auto targetArithType = targetType->AsArithmeticType()) + { + if (auto exprArithType = expr->Expression->Type->AsArithmeticType()) + { + // Both source and destination types are arithmetic, so we might + // have a valid cast + auto targetScalarType = targetArithType->GetScalarType(); + auto exprScalarType = exprArithType->GetScalarType(); + + if (!IsNumeric(exprScalarType->BaseType)) goto fail; + if (!IsNumeric(targetScalarType->BaseType)) goto fail; + + // TODO(tfoley): this checking is incomplete here, and could + // lead to downstream compilation failures + expr->Type = targetType; + return expr; + } + } + + fail: + // Default: in no other case succeds, then the cast failed and we emit a diagnostic. + getSink()->diagnose(expr, Diagnostics::invalidTypeCast, expr->Expression->Type, targetType->ToString()); + expr->Type = ExpressionType::Error; + return expr; + } +#if TIMREMOVED + virtual RefPtr<ExpressionSyntaxNode> VisitSelectExpression(SelectExpressionSyntaxNode * expr) override + { + auto selectorExpr = expr->SelectorExpr; + selectorExpr = CheckExpr(selectorExpr); + selectorExpr = Coerce(ExpressionType::GetBool(), selectorExpr); + expr->SelectorExpr = selectorExpr; + + // TODO(tfoley): We need a general purpose "join" on types for inferring + // generic argument types for builtins/intrinsics, so this should really + // be using the exact same logic... + // + expr->Expr0 = expr->Expr0->Accept(this).As<ExpressionSyntaxNode>(); + expr->Expr1 = expr->Expr1->Accept(this).As<ExpressionSyntaxNode>(); + if (!expr->Expr0->Type->Equals(expr->Expr1->Type.Ptr())) + { + getSink()->diagnose(expr, Diagnostics::selectValuesTypeMismatch); + } + expr->Type = expr->Expr0->Type; + return expr; + } +#endif + + // Get the type to use when referencing a declaration + QualType GetTypeForDeclRef(DeclRef declRef) + { + return getTypeForDeclRef( + this, + getSink(), + declRef, + &typeResult); + } + + RefPtr<ExpressionSyntaxNode> MaybeDereference(RefPtr<ExpressionSyntaxNode> inExpr) + { + RefPtr<ExpressionSyntaxNode> expr = inExpr; + for (;;) + { + auto& type = expr->Type; + if (auto pointerLikeType = type->As<PointerLikeType>()) + { + type = pointerLikeType->elementType; + + auto derefExpr = new DerefExpr(); + derefExpr->base = expr; + derefExpr->Type = pointerLikeType->elementType; + + // TODO(tfoley): deal with l-value-ness here + + expr = derefExpr; + continue; + } + + // Default case: just use the expression as-is + return expr; + } + } + + RefPtr<ExpressionSyntaxNode> CheckSwizzleExpr( + MemberExpressionSyntaxNode* memberRefExpr, + RefPtr<ExpressionType> baseElementType, + int baseElementCount) + { + RefPtr<SwizzleExpr> swizExpr = new SwizzleExpr(); + swizExpr->Position = memberRefExpr->Position; + swizExpr->base = memberRefExpr->BaseExpression; + + int limitElement = baseElementCount; + + int elementIndices[4]; + int elementCount = 0; + + bool elementUsed[4] = { false, false, false, false }; + bool anyDuplicates = false; + bool anyError = false; + + for (int i = 0; i < memberRefExpr->MemberName.Length(); i++) + { + auto ch = memberRefExpr->MemberName[i]; + int elementIndex = -1; + switch (ch) + { + case 'x': case 'r': elementIndex = 0; break; + case 'y': case 'g': elementIndex = 1; break; + case 'z': case 'b': elementIndex = 2; break; + case 'w': case 'a': elementIndex = 3; break; + default: + // An invalid character in the swizzle is an error + getSink()->diagnose(swizExpr, Diagnostics::unimplemented, "invalid component name for swizzle"); + anyError = true; + continue; + } + + // TODO(tfoley): GLSL requires that all component names + // come from the same "family"... + + // Make sure the index is in range for the source type + if (elementIndex >= limitElement) + { + getSink()->diagnose(swizExpr, Diagnostics::unimplemented, "swizzle component out of range for type"); + anyError = true; + continue; + } + + // Check if we've seen this index before + for (int ee = 0; ee < elementCount; ee++) + { + if (elementIndices[ee] == elementIndex) + anyDuplicates = true; + } + + // add to our list... + elementIndices[elementCount++] = elementIndex; + } + + for (int ee = 0; ee < elementCount; ++ee) + { + swizExpr->elementIndices[ee] = elementIndices[ee]; + } + swizExpr->elementCount = elementCount; + + if (anyError) + { + swizExpr->Type = ExpressionType::Error; + } + else if (elementCount == 1) + { + // single-component swizzle produces a scalar + // + // Note(tfoley): the official HLSL rules seem to be that it produces + // a one-component vector, which is then implicitly convertible to + // a scalar, but that seems like it just adds complexity. + swizExpr->Type = baseElementType; + } + else + { + // TODO(tfoley): would be nice to "re-sugar" type + // here if the input type had a sugared name... + swizExpr->Type = createVectorType( + baseElementType, + new ConstantIntVal(elementCount)); + } + + // A swizzle can be used as an l-value as long as there + // were no duplicates in the list of components + swizExpr->Type.IsLeftValue = !anyDuplicates; + + return swizExpr; + } + + RefPtr<ExpressionSyntaxNode> CheckSwizzleExpr( + MemberExpressionSyntaxNode* memberRefExpr, + RefPtr<ExpressionType> baseElementType, + RefPtr<IntVal> baseElementCount) + { + if (auto constantElementCount = baseElementCount.As<ConstantIntVal>()) + { + return CheckSwizzleExpr(memberRefExpr, baseElementType, constantElementCount->value); + } + else + { + getSink()->diagnose(memberRefExpr, Diagnostics::unimplemented, "swizzle on vector of unknown size"); + return CreateErrorExpr(memberRefExpr); + } + } + + + virtual RefPtr<ExpressionSyntaxNode> VisitMemberExpression(MemberExpressionSyntaxNode * expr) override + { + expr->BaseExpression = CheckExpr(expr->BaseExpression); + + expr->BaseExpression = MaybeDereference(expr->BaseExpression); + + auto & baseType = expr->BaseExpression->Type; + + // Note: Checking for vector types before declaration-reference types, + // because vectors are also declaration reference types... + if (auto baseVecType = baseType->AsVectorType()) + { + return CheckSwizzleExpr( + expr, + baseVecType->elementType, + baseVecType->elementCount); + } + else if(auto baseScalarType = baseType->AsBasicType()) + { + // Treat scalar like a 1-element vector when swizzling + return CheckSwizzleExpr( + expr, + baseScalarType, + 1); + } + else if (auto declRefType = baseType->AsDeclRefType()) + { + if (auto aggTypeDeclRef = declRefType->declRef.As<AggTypeDeclRef>()) + { + // Checking of the type must be complete before we can reference its members safely + EnsureDecl(aggTypeDeclRef.GetDecl(), DeclCheckState::Checked); + + + LookupResult lookupResult = LookUpLocal(expr->MemberName, aggTypeDeclRef); + if (!lookupResult.isValid()) + { + goto fail; + } + + return createLookupResultExpr( + lookupResult, + expr->BaseExpression, + expr); +#if 0 + DeclRef memberDeclRef(lookupResult.decl, aggTypeDeclRef.substitutions); + return ConstructDeclRefExpr(memberDeclRef, expr->BaseExpression, expr); +#endif + +#if 0 + + + // TODO(tfoley): It is unfortunate that the lookup strategy + // here isn't unified with the ordinary `Scope` case. + // In particular, if we add support for "transparent" declarations, + // etc. here then we would need to add them in ordinary lookup + // as well. + + Decl* memberDecl = nullptr; // The first declaration we found, if any + Decl* secondDecl = nullptr; // Another declaration with the same name, if any + for (auto m : aggTypeDeclRef.GetMembers()) + { + if (m.GetName() != expr->MemberName) + continue; + + if (!memberDecl) + { + memberDecl = m.GetDecl(); + } + else + { + secondDecl = m.GetDecl(); + break; + } + } + + // If we didn't find any member, then we signal an error + if (!memberDecl) + { + expr->Type = ExpressionType::Error; + getSink()->diagnose(expr, Diagnostics::noMemberOfNameInType, expr->MemberName, baseType); + return expr; + } + + // If we found only a single member, then we are fine + if (!secondDecl) + { + // TODO: need to + DeclRef memberDeclRef(memberDecl, aggTypeDeclRef.substitutions); + + expr->declRef = memberDeclRef; + expr->Type = GetTypeForDeclRef(memberDeclRef); + + // When referencing a member variable, the result is an l-value + // if and only if the base expression was. + if (auto memberVarDecl = dynamic_cast<VarDeclBase*>(memberDecl)) + { + expr->Type.IsLeftValue = expr->BaseExpression->Type.IsLeftValue; + } + return expr; + } + + // We found multiple members with the same name, and need + // to resolve the embiguity at some point... + expr->Type = ExpressionType::Error; + getSink()->diagnose(expr, Diagnostics::unimplemented, "ambiguous member reference"); + return expr; + +#endif + +#if 0 + + StructField* field = structDecl->FindField(expr->MemberName); + if (!field) + { + expr->Type = ExpressionType::Error; + getSink()->diagnose(expr, Diagnostics::noMemberOfNameInType, expr->MemberName, baseType); + } + else + expr->Type = field->Type; + + // A reference to a struct member is an l-value if the reference to the struct + // value was also an l-value. + expr->Type.IsLeftValue = expr->BaseExpression->Type.IsLeftValue; + return expr; +#endif + } + + // catch-all + fail: + getSink()->diagnose(expr, Diagnostics::noMemberOfNameInType, expr->MemberName, baseType); + expr->Type = ExpressionType::Error; + return expr; + } + // All remaining cases assume we have a `BasicType` + else if (!baseType->AsBasicType()) + expr->Type = ExpressionType::Error; + else + expr->Type = ExpressionType::Error; + if (!baseType->Equals(ExpressionType::Error.Ptr()) && + expr->Type->Equals(ExpressionType::Error.Ptr())) + { + getSink()->diagnose(expr, Diagnostics::typeHasNoPublicMemberOfName, baseType, expr->MemberName); + } + return expr; + } + SemanticsVisitor & operator = (const SemanticsVisitor &) = delete; + + + // + + virtual RefPtr<ExpressionSyntaxNode> visitInitializerListExpr(InitializerListExpr* expr) override + { + // When faced with an initializer list, we first just check the sub-expressions blindly. + // Actually making them conform to a desired type will wait for when we know the desired + // type based on context. + + for( auto& arg : expr->args ) + { + arg = CheckTerm(arg); + } + + expr->Type = ExpressionType::getInitializerListType(); + + return expr; + } + }; + + SyntaxVisitor * CreateSemanticsVisitor( + DiagnosticSink * err, + CompileOptions const& options) + { + return new SemanticsVisitor(err, options); + } + + // + + // Get the type to use when referencing a declaration + QualType getTypeForDeclRef( + SemanticsVisitor* sema, + DiagnosticSink* sink, + DeclRef declRef, + RefPtr<ExpressionType>* outTypeResult) + { + if( sema ) + { + sema->EnsureDecl(declRef.GetDecl()); + } + + // We need to insert an appropriate type for the expression, based on + // what we found. + if (auto varDeclRef = declRef.As<VarDeclBaseRef>()) + { + QualType qualType; + qualType.type = varDeclRef.GetType(); + qualType.IsLeftValue = true; // TODO(tfoley): allow explicit `const` or `let` variables + return qualType; + } + else if (auto typeAliasDeclRef = declRef.As<TypeDefDeclRef>()) + { + auto type = new NamedExpressionType(typeAliasDeclRef); + *outTypeResult = type; + return new TypeType(type); + } + else if (auto aggTypeDeclRef = declRef.As<AggTypeDeclRef>()) + { + auto type = DeclRefType::Create(aggTypeDeclRef); + *outTypeResult = type; + return new TypeType(type); + } + else if (auto simpleTypeDeclRef = declRef.As<SimpleTypeDeclRef>()) + { + auto type = DeclRefType::Create(simpleTypeDeclRef); + *outTypeResult = type; + return new TypeType(type); + } + else if (auto genericDeclRef = declRef.As<GenericDeclRef>()) + { + auto type = new GenericDeclRefType(genericDeclRef); + *outTypeResult = type; + return new TypeType(type); + } + else if (auto funcDeclRef = declRef.As<CallableDeclRef>()) + { + auto type = new FuncType(); + type->declRef = funcDeclRef; + return type; + } + + if( sink ) + { + sink->diagnose(declRef, Diagnostics::unimplemented, "cannot form reference to this kind of declaration"); + } + return ExpressionType::Error; + } + + QualType getTypeForDeclRef( + DeclRef declRef) + { + RefPtr<ExpressionType> typeResult; + return getTypeForDeclRef(nullptr, nullptr, declRef, &typeResult); + } + + } +}
\ No newline at end of file diff --git a/source/slang/compiled-program.h b/source/slang/compiled-program.h new file mode 100644 index 000000000..7a86dfe90 --- /dev/null +++ b/source/slang/compiled-program.h @@ -0,0 +1,96 @@ +#ifndef BAKER_SL_COMPILED_PROGRAM_H +#define BAKER_SL_COMPILED_PROGRAM_H + +#include "../core/basic.h" +#include "diagnostics.h" +#include "syntax.h" +#include "type-layout.h" + +namespace Slang +{ + namespace Compiler + { +#if 0 + class ShaderMetaData + { + public: + CoreLib::String ShaderName; + CoreLib::EnumerableDictionary<CoreLib::String, CoreLib::RefPtr<ILModuleParameterSet>> ParameterSets; // bindingName->DescSet + }; + + class StageSource + { + public: + String MainCode; + List<unsigned char> BinaryCode; + }; + + class CompiledShaderSource + { + public: + EnumerableDictionary<String, StageSource> Stages; + ShaderMetaData MetaData; + }; +#endif + + void IndentString(StringBuilder & sb, String src); + + struct EntryPointResult + { + String outputSource; + }; + + struct TranslationUnitResult + { + String outputSource; + List<EntryPointResult> entryPoints; + }; + + class CompileResult + { + public: + DiagnosticSink* mSink = nullptr; + +#if 0 + String ScheduleFile; + RefPtr<ILProgram> Program; + EnumerableDictionary<String, CompiledShaderSource> CompiledSource; // shader -> stage -> code +#endif + + // Per-translation-unit results + List<TranslationUnitResult> translationUnits; + +#if 0 + void PrintDiagnostics() + { + for (int i = 0; i < sink.diagnostics.Count(); i++) + { + fprintf(stderr, "%S(%d): %s %d: %S\n", + sink.diagnostics[i].Position.FileName.ToWString(), + sink.diagnostics[i].Position.Line, + getSeverityName(sink.diagnostics[i].severity), + sink.diagnostics[i].ErrorID, + sink.diagnostics[i].Message.ToWString()); + } + } +#endif + + CompileResult() + {} + ~CompileResult() + { + } + DiagnosticSink * GetErrorWriter() + { + return mSink; + } + int GetErrorCount() + { + return mSink->GetErrorCount(); + } + }; + + } +} + +#endif
\ No newline at end of file diff --git a/source/slang/compiler.cpp b/source/slang/compiler.cpp new file mode 100644 index 000000000..e78de9133 --- /dev/null +++ b/source/slang/compiler.cpp @@ -0,0 +1,659 @@ +// Compiler.cpp : Defines the entry point for the console application. +// +#include "../core/basic.h" +#include "../core/slang-io.h" +#include "compiler.h" +#include "lexer.h" +#include "parameter-binding.h" +#include "parser.h" +#include "preprocessor.h" +#include "syntax-visitors.h" +#include "slang-stdlib.h" + +#include "reflection.h" +#include "emit.h" + +// Utilities for pass-through modes +#include "../../tools/glslang/glslang.h" + + +#ifdef _WIN32 +#define WIN32_LEAN_AND_MEAN +#define NOMINMAX +#include <Windows.h> +#undef WIN32_LEAN_AND_MEAN +#undef NOMINMAX +#include <d3dcompiler.h> +#endif + +#ifdef CreateDirectory +#undef CreateDirectory +#endif + +using namespace CoreLib::Basic; +using namespace CoreLib::IO; +using namespace Slang::Compiler; + +namespace Slang +{ + namespace Compiler + { + // + + Profile Profile::LookUp(char const* name) + { + #define PROFILE(TAG, NAME, STAGE, VERSION) if(strcmp(name, #NAME) == 0) return Profile::TAG; + #define PROFILE_ALIAS(TAG, NAME) if(strcmp(name, #NAME) == 0) return Profile::TAG; + #include "profile-defs.h" + + return Profile::Unknown; + } + + + + // + + int compilerInstances = 0; + + class ShaderCompilerImpl : public ShaderCompiler + { + public: + + // Actual context for compilation... :( + struct ExtraContext + { + CompileOptions const* options = nullptr; + TranslationUnitOptions const* translationUnitOptions = nullptr; + + CompileResult* compileResult = nullptr; + + RefPtr<ProgramSyntaxNode> programSyntax; + ProgramLayout* programLayout; + + String sourceText; + String sourcePath; + + CompileOptions const& getOptions() { return *options; } + TranslationUnitOptions const& getTranslationUnitOptions() { return *translationUnitOptions; } + }; + + + String EmitHLSL(ExtraContext& context) + { + if (context.getOptions().passThrough != PassThroughMode::None) + { + return context.sourceText; + } + else + { + // TODO(tfoley): probably need a way to customize the emit logic... + return emitProgram( + context.programSyntax.Ptr(), + context.programLayout, + CodeGenTarget::HLSL); + } + } + + String emitGLSLForEntryPoint(ExtraContext& context, EntryPointOption const& entryPoint) + { + if (context.getOptions().passThrough != PassThroughMode::None) + { + return context.sourceText; + } + else + { + // TODO(tfoley): probably need a way to customize the emit logic... + return emitProgram( + context.programSyntax.Ptr(), + context.programLayout, + CodeGenTarget::GLSL); + } + } + + char const* GetHLSLProfileName(Profile profile) + { + switch(profile.raw) + { + #define PROFILE(TAG, NAME, STAGE, VERSION) case Profile::TAG: return #NAME; + #include "profile-defs.h" + + default: + // TODO: emit an error here! + return "unknown"; + } + } + +#ifdef _WIN32 + void* GetD3DCompilerDLL() + { + // TODO(tfoley): let user specify version of d3dcompiler DLL to use. + static HMODULE d3dCompiler = LoadLibraryA("d3dcompiler_47"); + // TODO(tfoley): handle case where we can't find it gracefully + assert(d3dCompiler); + return d3dCompiler; + } + + List<uint8_t> EmitDXBytecodeForEntryPoint( + ExtraContext& context, + EntryPointOption const& entryPoint) + { + static pD3DCompile D3DCompile_ = nullptr; + if (!D3DCompile_) + { + HMODULE d3dCompiler = (HMODULE)GetD3DCompilerDLL(); + assert(d3dCompiler); + + D3DCompile_ = (pD3DCompile)GetProcAddress(d3dCompiler, "D3DCompile"); + assert(D3DCompile_); + } + + // The HLSL compiler will try to "canonicalize" our input file path, + // and we don't want it to do that, because they it won't report + // the same locations on error messages that we would. + // + // To work around that, we prepend a custom `#line` directive. + + String rawHlslCode = EmitHLSL(context); + + StringBuilder hlslCodeBuilder; + hlslCodeBuilder << "#line 1 \""; + for(auto c : context.sourcePath) + { + char buffer[] = { c, 0 }; + switch(c) + { + default: + hlslCodeBuilder << buffer; + break; + + case '\\': + hlslCodeBuilder << "\\\\"; + } + } + hlslCodeBuilder << "\"\n"; + hlslCodeBuilder << rawHlslCode; + + auto hlslCode = hlslCodeBuilder.ProduceString(); + + ID3DBlob* codeBlob; + ID3DBlob* diagnosticsBlob; + HRESULT hr = D3DCompile_( + hlslCode.begin(), + hlslCode.Length(), + context.sourcePath.begin(), + nullptr, + nullptr, + entryPoint.name.begin(), + GetHLSLProfileName(entryPoint.profile), + 0, + 0, + &codeBlob, + &diagnosticsBlob); + List<uint8_t> data; + if (codeBlob) + { + data.AddRange((uint8_t const*)codeBlob->GetBufferPointer(), (int)codeBlob->GetBufferSize()); + codeBlob->Release(); + } + if (diagnosticsBlob) + { + // TODO(tfoley): need a better policy for how we translate diagnostics + // back into the Slang world (although we should always try to generate + // HLSL that doesn't produce any diagnostics...) + String diagnostics = (char const*) diagnosticsBlob->GetBufferPointer(); + fprintf(stderr, "%s", diagnostics.begin()); + OutputDebugStringA(diagnostics.begin()); + diagnosticsBlob->Release(); + } + if (FAILED(hr)) + { + // TODO(tfoley): What to do on failure? + } + return data; + } + + List<uint8_t> EmitDXBytecode( + ExtraContext& context) + { + if(context.getTranslationUnitOptions().entryPoints.Count() != 1) + { + if(context.getTranslationUnitOptions().entryPoints.Count() == 0) + { + // TODO(tfoley): need to write diagnostics into this whole thing... + fprintf(stderr, "no entry point specified\n"); + } + else + { + fprintf(stderr, "multiple entry points specified\n"); + } + return List<uint8_t>(); + } + + return EmitDXBytecodeForEntryPoint(context, context.getTranslationUnitOptions().entryPoints[0]); + } + + String EmitDXBytecodeAssemblyForEntryPoint( + ExtraContext& context, + EntryPointOption const& entryPoint) + { + static pD3DDisassemble D3DDisassemble_ = nullptr; + if (!D3DDisassemble_) + { + HMODULE d3dCompiler = (HMODULE)GetD3DCompilerDLL(); + assert(d3dCompiler); + + D3DDisassemble_ = (pD3DDisassemble)GetProcAddress(d3dCompiler, "D3DDisassemble"); + assert(D3DDisassemble_); + } + + List<uint8_t> dxbc = EmitDXBytecodeForEntryPoint(context, entryPoint); + if (!dxbc.Count()) + { + return ""; + } + + ID3DBlob* codeBlob; + HRESULT hr = D3DDisassemble_( + &dxbc[0], + dxbc.Count(), + 0, + nullptr, + &codeBlob); + + String result; + if (codeBlob) + { + result = String((char const*) codeBlob->GetBufferPointer()); + codeBlob->Release(); + } + if (FAILED(hr)) + { + // TODO(tfoley): need to figure out what to diagnose here... + } + return result; + } + + + String EmitDXBytecodeAssembly( + ExtraContext& context) + { + if(context.getTranslationUnitOptions().entryPoints.Count() == 0) + { + // TODO(tfoley): need to write diagnostics into this whole thing... + fprintf(stderr, "no entry point specified\n"); + return ""; + } + + StringBuilder sb; + for (auto entryPoint : context.getTranslationUnitOptions().entryPoints) + { + sb << EmitDXBytecodeAssemblyForEntryPoint(context, entryPoint); + } + return sb.ProduceString(); + } + + + HMODULE getGLSLCompilerDLL() + { + // TODO(tfoley): let user specify version of glslang DLL to use. + static HMODULE glslCompiler = LoadLibraryA("glslang"); + // TODO(tfoley): handle case where we can't find it gracefully + assert(glslCompiler); + return glslCompiler; + } + + + String emitSPIRVAssemblyForEntryPoint( + ExtraContext& context, + EntryPointOption const& entryPoint) + { + String rawGLSL = emitGLSLForEntryPoint(context, entryPoint); + + static glslang_CompileFunc glslang_compile = nullptr; + if (!glslang_compile) + { + HMODULE glslCompiler = getGLSLCompilerDLL(); + assert(glslCompiler); + + glslang_compile = (glslang_CompileFunc)GetProcAddress(glslCompiler, "glslang_compile"); + assert(glslang_compile); + } + + StringBuilder diagnosticBuilder; + StringBuilder outputBuilder; + + auto outputFunc = [](char const* text, void* userData) + { + *(StringBuilder*)userData << text; + }; + + glslang_CompileRequest request; + request.sourcePath = context.sourcePath.begin(); + request.sourceText = rawGLSL.begin(); + request.slangStage = (SlangStage) entryPoint.profile.GetStage(); + + request.diagnosticFunc = outputFunc; + request.diagnosticUserData = &diagnosticBuilder; + + request.outputFunc = outputFunc; + request.outputUserData = &outputBuilder; + + int err = glslang_compile(&request); + + String diagnostics = diagnosticBuilder.ProduceString(); + String output = outputBuilder.ProduceString(); + + if(err) + { + OutputDebugStringA(diagnostics.Buffer()); + fprintf(stderr, "%s", diagnostics.Buffer()); + exit(1); + } + + return output; + } +#endif + + String emitSPIRVAssembly( + ExtraContext& context) + { + if(context.getTranslationUnitOptions().entryPoints.Count() == 0) + { + // TODO(tfoley): need to write diagnostics into this whole thing... + fprintf(stderr, "no entry point specified\n"); + return ""; + } + + StringBuilder sb; + for (auto entryPoint : context.getTranslationUnitOptions().entryPoints) + { + sb << emitSPIRVAssemblyForEntryPoint(context, entryPoint); + } + return sb.ProduceString(); + } + + // Do emit logic for a single entry point + EntryPointResult emitEntryPoint(ExtraContext& context, EntryPointOption& entryPoint) + { + EntryPointResult result; + + switch (context.getOptions().Target) + { + case CodeGenTarget::GLSL: + { + String code = emitGLSLForEntryPoint(context, entryPoint); + result.outputSource = code; + } + break; + + case CodeGenTarget::DXBytecode: + { + auto code = EmitDXBytecodeForEntryPoint(context, entryPoint); + + // TODO(tfoley): Need to figure out an appropriate interface + // for returning binary code, in addition to source. +#if 0 + if (context.compileResult) + { + StringBuilder sb; + sb.Append((char*) code.begin(), code.Count()); + + String codeString = sb.ProduceString(); + result.outputSource = codeString; + } + else +#endif + { + int col = 0; + for(auto ii : code) + { + if(col != 0) fputs(" ", stdout); + fprintf(stdout, "%02X", ii); + col++; + if(col == 8) + { + fputs("\n", stdout); + col = 0; + } + } + if(col != 0) + { + fputs("\n", stdout); + } + } + return result; + } + break; + + case CodeGenTarget::DXBytecodeAssembly: + { + String code = EmitDXBytecodeAssemblyForEntryPoint(context, entryPoint); + result.outputSource = code; + } + break; + + case CodeGenTarget::SPIRVAssembly: + { + String code = emitSPIRVAssemblyForEntryPoint(context, entryPoint); + result.outputSource = code; + } + break; + + // Note(tfoley): We currently hit this case when compiling the stdlib + case CodeGenTarget::Unknown: + break; + + default: + throw "unimplemented"; + } + + return result; + + + } + + TranslationUnitResult emitTranslationUnitEntryPoints(ExtraContext& context) + { + TranslationUnitResult result; + + for (auto& entryPoint : context.getTranslationUnitOptions().entryPoints) + { + EntryPointResult entryPointResult = emitEntryPoint(context, entryPoint); + + result.entryPoints.Add(entryPointResult); + } + + // The result for the translation unit will just be the concatenation + // of the results for each entry point. This doesn't actually make + // much sense, but it is good enough for now. + StringBuilder sb; + for (auto& entryPointResult : result.entryPoints) + { + sb << entryPointResult.outputSource; + } + + result.outputSource = sb.ProduceString(); + + return result; + } + + // Do emit logic for an entire translation unit, which might + // have zero or more entry points + TranslationUnitResult emitTranslationUnit(ExtraContext& context) + { + // Most of our code generation targets will require us + // to proceed through one entry point at a time, but + // in some cases we can emit an entire translation unit + // in one go. + + switch (context.getOptions().Target) + { + default: + // The default behavior is going to loop over all the entry + // points, and then collect an aggregate result. + return emitTranslationUnitEntryPoints(context); + + case CodeGenTarget::HLSL: + // When targetting HLSL, we can emit the entire translation unit + // as a single HLSL program, and include all the entry points. + { + + String hlsl = EmitHLSL(context); + + TranslationUnitResult result; + result.outputSource = hlsl; + + // Because the user might ask for per-entry-point source, + // we will just attach the same string as the result for + // each entry point. + for( auto& entryPoint : context.getTranslationUnitOptions().entryPoints ) + { + (void)entryPoint; + + EntryPointResult entryPointResult; + entryPointResult.outputSource = hlsl; + result.entryPoints.Add(entryPointResult); + } + + return result; + } + break; + } + } + + TranslationUnitResult DoNewEmitLogic(ExtraContext& context) + { + TranslationUnitResult result = emitTranslationUnit(context); + + // As a bit of a hack, we include a mode where we just + // print things to standard output, so that we can see them + // + // TODO(tfoley): Is this path ever needed/used? + if( !context.compileResult ) + { + fprintf(stdout, "%s", result.outputSource.Buffer()); + } + + return result; + } + + void DoNewEmitLogic( + ExtraContext& context, + CollectionOfTranslationUnits* collectionOfTranslationUnits) + { + switch (context.getOptions().Target) + { + default: + // For most targets, we will do things per-translation-unit + for( auto translationUnit : collectionOfTranslationUnits->translationUnits ) + { + ExtraContext innerContext = context; + innerContext.translationUnitOptions = &translationUnit.options; + innerContext.programSyntax = translationUnit.SyntaxNode; + innerContext.sourcePath = "slang"; // don't have this any more! + innerContext.sourceText = ""; + + TranslationUnitResult translationUnitResult = DoNewEmitLogic(innerContext); + context.compileResult->translationUnits.Add(translationUnitResult); + } + break; + + case CodeGenTarget::ReflectionJSON: + { + String reflectionJSON = emitReflectionJSON(context.programLayout); + + // HACK(tfoley): just print it out since that is what people probably expect. + // TODO: need a way to control where output gets routed across all possible targets. + fprintf(stdout, "%s", reflectionJSON.begin()); + } + break; + } + } + + virtual void Compile( + CompileResult& result, + CollectionOfTranslationUnits* collectionOfTranslationUnits, + const CompileOptions& options) override + { + RefPtr<SyntaxVisitor> visitor = CreateSemanticsVisitor(result.GetErrorWriter(), options); + try + { + for( auto& translationUnit : collectionOfTranslationUnits->translationUnits ) + { + visitor->setSourceLanguage(translationUnit.options.sourceLanguage); + + translationUnit.SyntaxNode->Accept(visitor.Ptr()); + } + if (result.GetErrorCount() > 0) + return; + + // Do binding generation, and then reflection (globally) + // before we move on to any code-generation activites. + GenerateParameterBindings(collectionOfTranslationUnits); + + + // HACK(tfoley): for right now I just want to pretty-print an AST + // into another language, so the whole compiler back-end is just + // getting in the way. + // + // I'm going to bypass it for now and see what I can do: + + ExtraContext extra; + extra.options = &options; + extra.programLayout = collectionOfTranslationUnits->layout.Ptr(); + extra.compileResult = &result; + DoNewEmitLogic(extra, collectionOfTranslationUnits); + } + catch (int) + { + } + catch (...) + { + throw; + } + return; + } + + ShaderCompilerImpl() + { + if (compilerInstances == 0) + { + BasicExpressionType::Init(); + } + compilerInstances++; + } + + ~ShaderCompilerImpl() + { + compilerInstances--; + if (compilerInstances == 0) + { + BasicExpressionType::Finalize(); + SlangStdLib::Finalize(); + } + } + + virtual TranslationUnitResult PassThrough( + String const& sourceText, + String const& sourcePath, + const CompileOptions & options, + TranslationUnitOptions const& translationUnitOptions) override + { + ExtraContext extra; + extra.options = &options; + extra.translationUnitOptions = &translationUnitOptions; + extra.sourcePath = sourcePath; + extra.sourceText = sourceText; + + return DoNewEmitLogic(extra); + } + + }; + + ShaderCompiler * CreateShaderCompiler() + { + return new ShaderCompilerImpl(); + } + + } +} diff --git a/source/slang/compiler.h b/source/slang/compiler.h new file mode 100644 index 000000000..f35d6603c --- /dev/null +++ b/source/slang/compiler.h @@ -0,0 +1,156 @@ +#ifndef RASTER_SHADER_COMPILER_H +#define RASTER_SHADER_COMPILER_H + +#include "../core/basic.h" + +#include "compiled-program.h" +#include "diagnostics.h" +#include "profile.h" +#include "syntax.h" +#include "type-layout.h" + +#include "../../slang.h" + +namespace Slang +{ + namespace Compiler + { + class ILConstOperand; + struct IncludeHandler; + + enum class CompilerMode + { + ProduceLibrary, + ProduceShader, + GenerateChoice + }; + + enum class StageTarget + { + Unknown, + VertexShader, + HullShader, + DomainShader, + GeometryShader, + FragmentShader, + ComputeShader, + }; + + enum class CodeGenTarget + { + Unknown = SLANG_TARGET_UNKNOWN, + GLSL = SLANG_GLSL, + GLSL_Vulkan = SLANG_GLSL_VULKAN, + GLSL_Vulkan_OneDesc = SLANG_GLSL_VULKAN_ONE_DESC, + HLSL = SLANG_HLSL, + SPIRV = SLANG_SPIRV, + SPIRVAssembly = SLANG_SPIRV_ASM, + DXBytecode = SLANG_DXBC, + DXBytecodeAssembly = SLANG_DXBC_ASM, + ReflectionJSON = SLANG_REFLECTION_JSON, + }; + + // Describes an entry point that we've been requested to compile + struct EntryPointOption + { + String name; + Profile profile; + }; + + enum class PassThroughMode : SlangPassThrough + { + None = SLANG_PASS_THROUGH_NONE, // don't pass through: use Slang compiler + HLSL = SLANG_PASS_THROUGH_FXC, // pass through HLSL to `D3DCompile` API +// GLSL, // pass through GLSL to `glslang` library + }; + + // Represents a single source file (either an on-disk file, or a + // "virtual" file passed in as a string) + class SourceFile : public RefObject + { + public: + // The file path for a real file, or the nominal path for a virtual file + String path; + + // The actual contents of the file + String content; + }; + + // Options for a single translation unit being requested by the user + class TranslationUnitOptions + { + public: + SourceLanguage sourceLanguage = SourceLanguage::Unknown; + + // All entry points we've been asked to compile for this translation unit + List<EntryPointOption> entryPoints; + + // The source file(s) that will be compiled to form this translation unit + List<RefPtr<SourceFile> > sourceFiles; + }; + + class CompileOptions + { + public: + CompilerMode Mode = CompilerMode::ProduceShader; + CodeGenTarget Target = CodeGenTarget::Unknown; + StageTarget stage = StageTarget::Unknown; + EnumerableDictionary<String, String> BackendArguments; + + String SymbolToCompile; + String outputName; + List<String> TemplateShaderArguments; + List<String> SearchDirectories; + Dictionary<String, String> PreprocessorDefinitions; + + List<TranslationUnitOptions> translationUnits; + + // the code generation profile we've been asked to use + Profile profile; + + // should we just pass the input to another compiler? + PassThroughMode passThrough = PassThroughMode::None; + + // Flags supplied through the API + SlangCompileFlags flags = 0; + }; + + // This is the representation of a given translation unit + class CompileUnit + { + public: + TranslationUnitOptions options; + RefPtr<ProgramSyntaxNode> SyntaxNode; + }; + + // TODO: pick an appropriate name for this... + class CollectionOfTranslationUnits : public RefObject + { + public: + List<CompileUnit> translationUnits; + + // TODO: this is more output-oriented, but maybe okay to have here... + RefPtr<ProgramLayout> layout; + }; + + class ShaderCompiler : public CoreLib::Basic::Object + { + public: + virtual void Compile( + CompileResult& result, + CollectionOfTranslationUnits* collectionOfTranslationUnits, + const CompileOptions& options) = 0; + + virtual TranslationUnitResult PassThrough( + String const& sourceText, + String const& sourcePath, + const CompileOptions & options, + TranslationUnitOptions const& translationUnitOptions) = 0; + + }; + + ShaderCompiler * CreateShaderCompiler(); + } +} + +#endif
\ No newline at end of file diff --git a/source/slang/diagnostic-defs.h b/source/slang/diagnostic-defs.h new file mode 100644 index 000000000..3f690a5da --- /dev/null +++ b/source/slang/diagnostic-defs.h @@ -0,0 +1,338 @@ +// + +// The file is meant to be included multiple times, to produce different +// pieces of declaration/definition code related to diagnostic messages +// +// Each diagnostic is declared here with: +// +// DIAGNOSTIC(id, severity, name, messageFormat) +// +// Where `id` is the unique diagnostic ID, `severity` is the default +// severity (from the `Severity` enum), `name` is a name used to refer +// to this diagnostic from code, and `messageFormat` is the default +// (non-localized) message for the diagnostic, with placeholders +// for any arguments. + +#ifndef DIAGNOSTIC +#error Need to #define DIAGNOSTIC(...) before including "DiagnosticDefs.h" +#define DIAGNOSTIC(id, severity, name, messageFormat) /* */ +#endif + +// +// -1 - Notes that decorate another diagnostic. +// + +DIAGNOSTIC(-1, Note, alsoSeePipelineDefinition, "also see pipeline definition"); +DIAGNOSTIC(-1, Note, implicitParameterMatchingFailedBecauseNameNotAccessible, "implicit parameter matching failed because the component of the same name is not accessible from '$0'.\ncheck if you have declared necessary requirements and properly used the 'public' qualifier.") +DIAGNOSTIC(-1, Note, implicitParameterMatchingFailedBecauseShaderDoesNotDefineComponent, "implicit parameter matching failed because shader '$0' does not define component '$1'.") +DIAGNOSTIC(-1, Note, implicitParameterMatchingFailedBecauseTypeMismatch, "implicit parameter matching failed because the component of the same name does not match parameter type '$0'.") +DIAGNOSTIC(-1, Note, noteShaderIsTargetingPipeine, "shader '$0' is targeting pipeline '$1'") +DIAGNOSTIC(-1, Note, seeDefinitionOf, "see definition of '$0'") +DIAGNOSTIC(-1, Note, seeInterfaceDefinitionOf, "see interface definition of '$0'") +DIAGNOSTIC(-1, Note, seeUsingOf, "see using of '$0'") +DIAGNOSTIC(-1, Note, seeDefinitionOfShader, "see definition of shader '$0'") +DIAGNOSTIC(-1, Note, seeInclusionOf, "see inclusion of '$0'") +DIAGNOSTIC(-1, Note, seeModuleBeingUsedIn, "see module '$0' being used in '$1'") +DIAGNOSTIC(-1, Note, seePipelineRequirementDefinition, "see pipeline requirement definition") +DIAGNOSTIC(-1, Note, seePotentialDefinitionOfComponent, "see potential definition of component '$0'") +DIAGNOSTIC(-1, Note, seePreviousDefinition, "see previous definition") +DIAGNOSTIC(-1, Note, seePreviousDefinitionOf, "see previous definition of '$0'") +DIAGNOSTIC(-1, Note, seeRequirementDeclaration, "see requirement declaration") +DIAGNOSTIC(-1, Note, doYouForgetToMakeComponentAccessible, "do you forget to make component '$0' acessible from '$1' (missing public qualifier)?") +// +// 0xxxx - Command line and interaction with host platform APIs. +// + +DIAGNOSTIC( 1, Error, cannotOpenFile, "cannot open file '$0'.") +DIAGNOSTIC( 2, Error, cannotFindFile, "cannot find file '$0'.") +DIAGNOSTIC( 2, Error, unsupportedCompilerMode, "unsupported compiler mode.") +DIAGNOSTIC( 4, Error, cannotWriteOutputFile, "cannot write output file '$0'.") + +// +// 1xxxx - Lexical anaylsis +// + +DIAGNOSTIC(10000, Error, illegalCharacterPrint, "illegal character '$0'"); +DIAGNOSTIC(10000, Error, illegalCharacterHex, "illegal character (0x$0)"); +DIAGNOSTIC(10001, Error, illegalCharacterLiteral, "illegal character literal"); + +DIAGNOSTIC(10002, Warning, octalLiteral, "'0' prefix indicates octal literal") +DIAGNOSTIC(10003, Error, invalidDigitForBase, "invalid digit for base-$1 literal: '$0'") + +DIAGNOSTIC(10004, Error, endOfFileInLiteral, "end of file in literal"); +DIAGNOSTIC(10005, Error, newlineInLiteral, "newline in literal"); + +// +// 15xxx - Preprocessing +// + +// 150xx - conditionals +DIAGNOSTIC(15000, Error, endOfFileInPreprocessorConditional, "end of file encountered during preprocessor conditional") +DIAGNOSTIC(15001, Error, directiveWithoutIf, "'$0' directive without '#if'") +DIAGNOSTIC(15002, Error, directiveAfterElse , "'$0' directive without '#if'") + +DIAGNOSTIC(-1, Note, seeDirective, "see '$0' directive") + +// 151xx - directive parsing +DIAGNOSTIC(15100, Error, expectedPreprocessorDirectiveName, "expected preprocessor directive name") +DIAGNOSTIC(15101, Error, unknownPreprocessorDirective, "unknown preprocessor directive '$0'") +DIAGNOSTIC(15102, Error, expectedTokenInPreprocessorDirective, "expected '$0' in '$1' directive") +DIAGNOSTIC(15102, Error, expected2TokensInPreprocessorDirective, "expected '$0' or '$1' in '$2' directive") +DIAGNOSTIC(15103, Error, unexpectedTokensAfterDirective, "unexpected tokens following '$0' directive") + + +// 152xx - preprocessor expressions +DIAGNOSTIC(15200, Error, expectedTokenInPreprocessorExpression, "expected '$0' in preprocessor expression"); +DIAGNOSTIC(15201, Error, syntaxErrorInPreprocessorExpression, "syntax error in preprocessor expression"); +DIAGNOSTIC(15202, Error, divideByZeroInPreprocessorExpression, "division by zero in preprocessor expression"); +DIAGNOSTIC(15203, Error, expectedTokenInDefinedExpression, "expected '$0' in 'defined' expression"); +DIAGNOSTIC(15204, Warning, directiveExpectsExpression, "'$0' directive requires an expression"); + +DIAGNOSTIC(-1, Note, seeOpeningToken, "see opening '$0'") + +// 153xx - #include +DIAGNOSTIC(15300, Error, includeFailed, "failed to find include file '$0'") +DIAGNOSTIC(-1, Error, noIncludeHandlerSpecified, "no `#include` handler was specified") + +// 154xx - macro definition +DIAGNOSTIC(15400, Warning, macroRedefinition, "redefinition of macro '$0'") +DIAGNOSTIC(15401, Warning, macroNotDefined, "macro '$0' is not defined") +DIAGNOSTIC(15403, Error, expectedTokenInMacroParameters, "expected '$0' in macro parameters") + +// 155xx - macro expansion +DIAGNOSTIC(15500, Warning, expectedTokenInMacroArguments, "expected '$0' in macro invocation") + +// 159xx - user-defined error/warning +DIAGNOSTIC(15900, Error, userDefinedError, "#error: $0") +DIAGNOSTIC(15901, Warning, userDefinedWarning, "#warning: $0") + +// +// 2xxxx - Parsing +// + +DIAGNOSTIC(20003, Error, unexpectedToken, "unexpected $0"); +DIAGNOSTIC(20001, Error, unexpectedTokenExpectedTokenType, "unexpected $0, expected $1"); +DIAGNOSTIC(20001, Error, unexpectedTokenExpectedTokenName, "unexpected $0, expected '$1'"); + +DIAGNOSTIC(0, Error, tokenNameExpectedButEOF, "\"$0\" expected but end of file encountered."); +DIAGNOSTIC(0, Error, tokenTypeExpectedButEOF, "$0 expected but end of file encountered."); +DIAGNOSTIC(20001, Error, tokenNameExpected, "\"$0\" expected"); +DIAGNOSTIC(20001, Error, tokenNameExpectedButEOF2, "\"$0\" expected but end of file encountered."); +DIAGNOSTIC(20001, Error, tokenTypeExpected, "$0 expected"); +DIAGNOSTIC(20001, Error, tokenTypeExpectedButEOF2, "$0 expected but end of file encountered."); +DIAGNOSTIC(20001, Error, typeNameExpectedBut, "unexpected $0, expected type name"); +DIAGNOSTIC(20001, Error, typeNameExpectedButEOF, "type name expected but end of file encountered."); +DIAGNOSTIC(20001, Error, unexpectedEOF, " Unexpected end of file."); +DIAGNOSTIC(20002, Error, syntaxError, "syntax error."); +DIAGNOSTIC(20004, Error, unexpectedTokenExpectedComponentDefinition, "unexpected token '$0', only component definitions are allowed in a shader scope.") +DIAGNOSTIC(20008, Error, invalidOperator, "invalid operator '$0'."); +DIAGNOSTIC(20011, Error, unexpectedColon, "unexpected ':'.") + +// +// 3xxxx - Semantic analysis +// + +DIAGNOSTIC(30001, Error, functionRedefinitionWithArgList, "'$0$1': function redefinition.") +DIAGNOSTIC(30002, Error, parameterAlreadyDefined, "parameter '$0' already defined.") +DIAGNOSTIC(30003, Error, breakOutsideLoop, "'break' must appear inside loop constructs.") +DIAGNOSTIC(30004, Error, continueOutsideLoop, "'continue' must appear inside loop constructs.") +DIAGNOSTIC(30005, Error, whilePredicateTypeError, "'while': expression must evaluate to int.") +DIAGNOSTIC(30006, Error, ifPredicateTypeError, "'if': expression must evaluate to int.") +DIAGNOSTIC(30006, Error, returnNeedsExpression, "'return' should have an expression.") +DIAGNOSTIC(30007, Error, componentReturnTypeMismatch, "expression type '$0' does not match component's type '$1'") +DIAGNOSTIC(30007, Error, functionReturnTypeMismatch, "expression type '$0' does not match function's return type '$1'") +DIAGNOSTIC(30008, Error, variableNameAlreadyDefined, "variable $0 already defined.") +DIAGNOSTIC(30009, Error, invalidTypeVoid, "invalid type 'void'.") +DIAGNOSTIC(30010, Error, whilePredicateTypeError2, "'while': expression must evaluate to int.") +DIAGNOSTIC(30011, Error, assignNonLValue, "left of '=' is not an l-value.") +DIAGNOSTIC(30012, Error, noApplicationUnaryOperator, "no overload found for operator $0 ($1).") +DIAGNOSTIC(30012, Error, noOverloadFoundForBinOperatorOnTypes, "no overload found for operator $0 ($1, $2).") +DIAGNOSTIC(30013, Error, subscriptNonArray, "no subscript operation found for type '$0'") +DIAGNOSTIC(30014, Error, subscriptIndexNonInteger, "index expression must evaluate to int.") +DIAGNOSTIC(30015, Error, undefinedIdentifier, "'$0': undefined identifier.") +DIAGNOSTIC(30015, Error, undefinedIdentifier2, "undefined identifier '$0'.") +DIAGNOSTIC(30016, Error, parameterCannotBeVoid, "'void' can not be parameter type.") +DIAGNOSTIC(30017, Error, componentNotAccessibleFromShader, "component '$0' is not accessible from shader '$1'.") +DIAGNOSTIC(30019, Error, typeMismatch, "expected an expression of type '$0', got '$1'") +DIAGNOSTIC(30020, Error, importOperatorReturnTypeMismatch, "import operator should return '$1', but the expression has type '$0''. do you forget 'project'?") +DIAGNOSTIC(30021, Error, noApplicationFunction, "$0: no overload takes arguments ($1)") +DIAGNOSTIC(30022, Error, invalidTypeCast, "invalid type cast between \"$0\" and \"$1\".") +DIAGNOSTIC(30023, Error, typeHasNoPublicMemberOfName, "\"$0\" does not have public member \"$1\"."); +DIAGNOSTIC(30025, Error, invalidArraySize, "array size must be larger than zero.") +DIAGNOSTIC(30026, Error, returnInComponentMustComeLast, "'return' can only appear as the last statement in component definition.") +DIAGNOSTIC(30027, Error, noMemberOfNameInType, "'$0' is not a member of '$1'."); +DIAGNOSTIC(30028, Error, forPredicateTypeError, "'for': predicate expression must evaluate to bool.") +DIAGNOSTIC(30030, Error, projectionOutsideImportOperator, "'project': invalid use outside import operator.") +DIAGNOSTIC(30031, Error, projectTypeMismatch, "'project': expression must evaluate to record type '$0'.") +DIAGNOSTIC(30033, Error, invalidTypeForLocalVariable, "cannot declare a local variable of this type.") +DIAGNOSTIC(30035, Error, componentOverloadTypeMismatch, "'$0': type of overloaded component mismatches previous definition.") +DIAGNOSTIC(30041, Error, bitOperationNonIntegral, "bit operation: operand must be integral type.") +DIAGNOSTIC(30047, Error, argumentExpectedLValue, "argument passed to parameter '$0' must be l-value.") +DIAGNOSTIC(30051, Error, invalidValueForArgument, "invalid value for argument '$0'") +DIAGNOSTIC(30052, Error, ordinaryFunctionAsModuleArgument, "ordinary functions not allowed as argument to function-typed module parameter.") +DIAGNOSTIC(30079, Error, selectPrdicateTypeMismatch, "selector must evaluate to bool."); +DIAGNOSTIC(30080, Error, selectValuesTypeMismatch, "the two value expressions in a select clause must have same type."); +DIAGNOSTIC(31040, Error, undefinedTypeName, "undefined type name: '$0'.") +DIAGNOSTIC(32013, Error, circularReferenceNotAllowed, "'$0': circular reference is not allowed."); +DIAGNOSTIC(32014, Error, shaderDoesProvideRequirement, "shader '$0' does not provide '$1' as required by '$2'.") +DIAGNOSTIC(32015, Error, argumentNotAvilableInWorld, "argument '$0' is not available in world '$1' as required by '$2'.") +DIAGNOSTIC(32015, Error, componentNotAvilableInWorld, "component '$0' is not available in world '$1' as required by '$2'.") +DIAGNOSTIC(32047, Error, firstArgumentToImportNotComponent, "first argument of an import operator call does not resolve to a component."); +DIAGNOSTIC(32051, Error, componentTypeNotWhatPipelineRequires, "component '$0' has type '$1', but pipeline '$2' requires it to be '$3'.") +DIAGNOSTIC(32052, Error, shaderDoesNotDefineComponentAsRequiredByPipeline, "shader '$0' does not define '$1' as required by pipeline '$2''.") +DIAGNOSTIC(33001, Error, worldNameAlreadyDefined, "world '$0' is already defined.") +DIAGNOSTIC(33002, Error, explicitPipelineSpecificationRequiredForShader, "explicit pipeline specification required for shader '$0' because multiple pipelines are defined in current context.") +DIAGNOSTIC(33003, Error, cannotDefineComponentsInAPipeline, "cannot define components in a pipeline.") +DIAGNOSTIC(33004, Error, undefinedWorldName, "undefined world name '$0'.") +DIAGNOSTIC(33005, Error, abstractWorldAsTargetOfImport, "abstract world cannot appear as target as an import operator.") + +// Note(tfoley): This is a duplicate of 33004 above. +DIAGNOSTIC(33006, Error, undefinedWorldName2, "undefined world name '$0'.") + +DIAGNOSTIC(33007, Error, importOperatorCircularity, "import operator '$0' creates a circular dependency between world '$1' and '$2'") +DIAGNOSTIC(33009, Error, parametersOnlyAllowedInModules, "parameters can only be defined in modules.") +DIAGNOSTIC(33010, Error, undefinedPipelineName, "pipeline '$0' is undefined.") +DIAGNOSTIC(33011, Error, shaderCircularity, "shader '$0' involves circular reference.") +DIAGNOSTIC(33012, Error, worldIsNotDefinedInPipeline, "'$0' is not a defined world in '$1'.") +DIAGNOSTIC(33013, Error, abstractWorldCannotAppearWithOthers, "abstract world cannot appear with other worlds.") +DIAGNOSTIC(33014, Error, nonAbstractComponentMustHaveImplementation, "non-abstract component must have an implementation.") +DIAGNOSTIC(33016, Error, usingInComponentDefinition, "'using': importing not allowed in component definition.") +DIAGNOSTIC(33018, Error, nameAlreadyDefined, "'$0' is already defined.") +DIAGNOSTIC(33018, Error, shaderAlreadyDefined, "shader '$0' has already been defined.") +DIAGNOSTIC(33019, Error, componentMarkedExportMustHaveWorld, "component '$0': definition marked as 'export' must have an explicitly specified world.") +DIAGNOSTIC(33020, Error, componentIsAlreadyDefined, "'$0' is already defined.") +DIAGNOSTIC(33020, Error, componentIsAlreadyDefinedInThatWorld, "'$0' is already defined at '$1'.") +DIAGNOSTIC(33021, Error, inconsistentSignatureForComponent, "'$0': inconsistent signature.") +DIAGNOSTIC(33022, Error, nameAlreadyDefinedInCurrentScope, "'$0' is already defined in current scope.") +DIAGNOSTIC(33022, Error, parameterNameConflictsWithExistingDefinition, "'$0': parameter name conflicts with existing definition.") +DIAGNOSTIC(33023, Error, parameterOfModuleIsUnassigned, "parameter '$0' of module '$1' is unassigned.") +DIAGNOSTIC(33027, Error, argumentTypeDoesNotMatchParameterType, "argument type ($0) does not match parameter type ($1)") +DIAGNOSTIC(33028, Error, nameIsNotAParameterOfCallee, "'$0' is not a parameter of '$1'.") +DIAGNOSTIC(33029, Error, requirementsClashWithPreviousDef, "'$0': requirement clash with previous definition.") +DIAGNOSTIC(33030, Error, positionArgumentAfterNamed, "positional argument cannot appear after a named argument.") +DIAGNOSTIC(33032, Error, functionRedefinition, "'$0': function redefinition.") +DIAGNOSTIC(33034, Error, recordTypeVariableInImportOperator, "cannot declare a record-typed variable in an import operator.") +DIAGNOSTIC(33037, Error, componetMarkedExportCannotHaveParameters, "component '$0': definition marked as 'export' cannot have parameters.") +DIAGNOSTIC(33039, Error, componentInInputWorldCantHaveCode, "'$0': no code allowed for component defined in input world.") +DIAGNOSTIC(33040, Error, requireWithComputation, "'require': cannot define computation on component requirements.") +DIAGNOSTIC(33042, Error, paramWithComputation, "'param': cannot define computation on parameters.") +DIAGNOSTIC(33041, Error, pipelineOfModuleIncompatibleWithPipelineOfShader, "pipeline '$0' targeted by module '$1' is incompatible with pipeline '$2' targeted by shader '$3'.") +DIAGNOSTIC(33070, Error, expectedFunction, "expression preceding parenthesis of apparent call must have function type.") +DIAGNOSTIC(33071, Error, importOperatorCalledFromAutoPlacedComponent, "cannot call an import operator from an auto-placed component '$0'. try qualify the component with explicit worlds.") +DIAGNOSTIC(33072, Error, noApplicableImportOperator, "'$0' is an import operator defined in pipeline '$1', but none of the import operator overloads converting to world '$2' matches argument list ($3).") +DIAGNOSTIC(33073, Error, importOperatorCalledFromMultiWorldComponent, "cannot call an import operator from a multi-world component definition. consider qualify the component with only one explicit world.") +DIAGNOSTIC(33080, Error, componentTypeDoesNotMatchInterface, "'$0': component type does not match definition in interface '$1'.") +DIAGNOSTIC(33081, Error, shaderDidNotDefineComponentFunction, "shader '$0' did not define component function $1 as required by interface '$2'.") +DIAGNOSTIC(33082, Error, shaderDidNotDefineComponent, "shader '$0' did not define component '$1' as required by interface '$2'.") +DIAGNOSTIC(33083, Error, interfaceImplMustBePublic, "'$0': component fulfilling interface '$1' must be declared as 'public'.") +DIAGNOSTIC(33084, Error, defaultParamNotAllowedInInterface, "'$0': default parameter value not allowed in interface definition.") + +DIAGNOSTIC(33100, Error, componentCantBeComputedAtWorldBecauseDependentNotAvailable, "'$0' cannot be computed at '$1' because the dependent component '$2' is not accessible.") +DIAGNOSTIC(33101, Warning, worldIsNotAValidChoiceForKey, "'$0' is not a valid choice for '$1'.") +DIAGNOSTIC(33102, Error, componentDefinitionCircularity, "component definition '$0' involves circular reference.") +DIAGNOSTIC(34024, Error, componentAlreadyDefinedWhenCompiling, "component named '$0' is already defined when compiling '$1'.") +DIAGNOSTIC(34025, Error, globalComponentConflictWithPreviousDeclaration, "'$0': global component conflicts with previous declaration.") +DIAGNOSTIC(34026, Warning, componentIsAlreadyDefinedUseRequire, "'$0': component is already defined when compiling shader '$1'. use 'require' to declare it as a parameter.") +DIAGNOSTIC(34062, Error, cylicReference, "cyclic reference: $0"); +DIAGNOSTIC(34064, Error, noApplicableImplicitImportOperator, "cannot find import operator to import component '$0' to world '$1' when compiling '$2'.") +DIAGNOSTIC(34065, Error, resourceTypeMustBeParamOrRequire, "'$0': resource typed component must be declared as 'param' or 'require'."); +DIAGNOSTIC(34066, Error, cannotDefineComputationOnResourceType, "'$0': cannot define computation on resource typed component."); + +DIAGNOSTIC(35001, Error, fragDepthAttributeCanOnlyApplyToOutput, "FragDepth attribute can only apply to an output component."); +DIAGNOSTIC(35002, Error, fragDepthAttributeCanOnlyApplyToFloatComponent, "FragDepth attribute can only apply to a float component."); + + +DIAGNOSTIC(36001, Error, insufficientTemplateShaderArguments, "instantiating template shader '$0': insufficient arguments."); +DIAGNOSTIC(36002, Error, tooManyTemplateShaderArguments, "instantiating template shader '$0': too many arguments."); +DIAGNOSTIC(36003, Error, templateShaderArgumentIsNotDefined, "'$0' provided as template shader argument to '$1' is not a defined module."); +DIAGNOSTIC(36004, Error, templateShaderArgumentDidNotImplementRequiredInterface, "module '$0' provided as template shader argument to '$1' did not implement required interface '$2'."); + +// TODO: need to assign numbers to all these extra diagnostics... + +DIAGNOSTIC(39999, Error, expectedIntegerConstantWrongType, "expected integer constant (found: '$0')") +DIAGNOSTIC(39999, Error, expectedIntegerConstantNotConstant, "expression does not evaluate to a compile-time constant") +DIAGNOSTIC(39999, Error, expectedIntegerConstantNotLiteral, "could not extract value from integer constant") + +DIAGNOSTIC(39999, Error, noApplicableOverloadForNameWithArgs, "no overload for '$0' applicable to arguments of type $1") +DIAGNOSTIC(39999, Error, noApplicableWithArgs, "no overload applicable to arguments of type $0") + +DIAGNOSTIC(39999, Error, ambiguousOverloadForNameWithArgs, "ambiguous call to '$0' operation with arguments of type $1") +DIAGNOSTIC(39999, Error, ambiguousOverloadWithArgs, "ambiguous call to overloaded operation with arguments of type $0") + +DIAGNOSTIC(39999, Note, overloadCandidate, "candidate: $0") +DIAGNOSTIC(39999, Note, moreOverloadCandidates, "$0 more overload candidates") + +DIAGNOSTIC(39999, Error, caseOutsideSwitch, "'case' not allowed outside of a 'switch' statement") +DIAGNOSTIC(39999, Error, defaultOutsideSwitch, "'default' not allowed outside of a 'switch' statement") + +DIAGNOSTIC(39999, Error, expectedAGeneric, "expected a generic when using '<...>' (found: '$0')") + +DIAGNOSTIC(39999, Error, genericArgumentInferenceFailed, "could not specialize generic for arguments of type $0") +DIAGNOSTIC(39999, Note, genericSignatureTried, "see declaration of $0") + +DIAGNOSTIC(39999, Error, expectedATraitGot, "expected a trait, got '$0'") + +DIAGNOSTIC(39999, Error, ambiguousReference, "amiguous reference to '$0'"); + +DIAGNOSTIC(39999, Error, declarationDidntDeclareAnything, "declaration does not declare anything"); + + +DIAGNOSTIC(39999, Error, expectedPrefixOperator, "function called as prefix operator was not declared `__prefix`") +DIAGNOSTIC(39999, Error, expectedPostfixOperator, "function called as postfix operator was not declared `__postfix`") + +DIAGNOSTIC(39999, Error, notEnoughArguments, "not enough arguments to call (got $0, expected $1)") +DIAGNOSTIC(39999, Error, tooManyArguments, "too many arguments to call (got $0, expected $1)") + +// +// 4xxxx - IL code generation. +// +DIAGNOSTIC(40001, Error, bindingAlreadyOccupiedByComponent, "resource binding location '$0' is already occupied by component '$1'.") +DIAGNOSTIC(40002, Error, invalidBindingValue, "binding location '$0' is out of valid range.") +DIAGNOSTIC(40003, Error, bindingExceedsLimit, "binding location '$0' assigned to component '$1' exceeds maximum limit.") +DIAGNOSTIC(40004, Error, bindingAlreadyOccupiedByModule, "DescriptorSet ID '$0' is already occupied by module instance '$1'.") +DIAGNOSTIC(40005, Error, topLevelModuleUsedWithoutSpecifyingBinding, "top level module '$0' is being used without specifying binding location. Use [Binding: \"index\"] attribute to provide a binding location.") +// +// 5xxxx - Target code generation. +// + +DIAGNOSTIC(50020, Error, unknownStageType, "Unknown stage type '$0'.") +DIAGNOSTIC(50020, Error, invalidTessCoordType, "TessCoord must have vec2 or vec3 type.") +DIAGNOSTIC(50020, Error, invalidFragCoordType, "FragCoord must be a vec4.") +DIAGNOSTIC(50020, Error, invalidInvocationIdType, "InvocationId must have int type.") +DIAGNOSTIC(50020, Error, invalidThreadIdType, "ThreadId must have int type.") +DIAGNOSTIC(50020, Error, invalidPrimitiveIdType, "PrimitiveId must have int type.") +DIAGNOSTIC(50020, Error, invalidPatchVertexCountType, "PatchVertexCount must have int type.") +DIAGNOSTIC(50022, Error, worldIsNotDefined, "world '$0' is not defined."); +DIAGNOSTIC(50023, Error, stageShouldProvideWorldAttribute, "'$0' should provide 'World' attribute."); +DIAGNOSTIC(50040, Error, componentHasInvalidTypeForPositionOutput, "'$0': component used as 'Position' output must be of vec4 type.") +DIAGNOSTIC(50041, Error, componentNotDefined, "'$0': component not defined.") + +DIAGNOSTIC(50052, Error, domainShaderRequiresControlPointCount, "'DomainShader' requires attribute 'ControlPointCount'."); +DIAGNOSTIC(50052, Error, hullShaderRequiresControlPointCount, "'HullShader' requires attribute 'ControlPointCount'.") +DIAGNOSTIC(50052, Error, hullShaderRequiresControlPointWorld, "'HullShader' requires attribute 'ControlPointWorld'."); +DIAGNOSTIC(50052, Error, hullShaderRequiresCornerPointWorld, "'HullShader' requires attribute 'CornerPointWorld'."); +DIAGNOSTIC(50052, Error, hullShaderRequiresDomain, "'HullShader' requires attribute 'Domain'."); +DIAGNOSTIC(50052, Error, hullShaderRequiresInputControlPointCount, "'HullShader' requires attribute 'InputControlPointCount'.") +DIAGNOSTIC(50052, Error, hullShaderRequiresOutputTopology, "'HullShader' requires attribute 'OutputTopology'.") +DIAGNOSTIC(50052, Error, hullShaderRequiresPartitioning, "'HullShader' requires attribute 'Partitioning'.") +DIAGNOSTIC(50052, Error, hullShaderRequiresPatchWorld, "'HullShader' requires attribute 'PatchWorld'."); +DIAGNOSTIC(50052, Error, hullShaderRequiresTessLevelInner, "'HullShader' requires attribute 'TessLevelInner'.") +DIAGNOSTIC(50052, Error, hullShaderRequiresTessLevelOuter, "'HullShader' requires attribute 'TessLevelOuter'.") + +DIAGNOSTIC(50053, Error, invalidTessellationDomian, "'Domain' should be either 'triangles' or 'quads'."); +DIAGNOSTIC(50053, Error, invalidTessellationOutputTopology, "'OutputTopology' must be one of: 'point', 'line', 'triangle_cw', or 'triangle_ccw'."); +DIAGNOSTIC(50053, Error, invalidTessellationPartitioning, "'Partitioning' must be one of: 'integer', 'pow2', 'fractional_even', or 'fractional_odd'.") +DIAGNOSTIC(50053, Error, invalidTessellationDomain, "'Domain' should be either 'triangles' or 'quads'.") + +DIAGNOSTIC(50082, Error, importingFromPackedBufferUnsupported, "importing type '$0' from PackedBuffer is not supported by the GLSL backend.") +DIAGNOSTIC(51090, Error, cannotGenerateCodeForExternComponentType, "cannot generate code for extern component type '$0'.") +DIAGNOSTIC(51091, Error, typeCannotBePlacedInATexture, "type '$0' cannot be placed in a texture.") +DIAGNOSTIC(51092, Error, stageDoesntHaveInputWorld, "'$0' doesn't appear to have any input world"); + + +// 99999 - Internal compiler errors, and not-yet-classified diagnostics. + +DIAGNOSTIC(99999, Internal, internalCompilerError, "internal compiler error") +DIAGNOSTIC(99999, Internal, unimplemented, "unimplemented feature: $0") + +#undef DIAGNOSTIC diff --git a/source/slang/diagnostics.cpp b/source/slang/diagnostics.cpp new file mode 100644 index 000000000..d8527466b --- /dev/null +++ b/source/slang/diagnostics.cpp @@ -0,0 +1,204 @@ +// Diagnostics.cpp +#include "Diagnostics.h" + +#include "Syntax.h" + +#include <assert.h> + +#ifdef _WIN32 +#define WIN32_LEAN_AND_MEAN +#define NOMINMAX +#include <Windows.h> +#undef WIN32_LEAN_AND_MEAN +#undef NOMINMAX +#include <d3dcompiler.h> +#endif + +namespace Slang { +namespace Compiler { + +void printDiagnosticArg(StringBuilder& sb, char const* str) +{ + sb << str; +} + +void printDiagnosticArg(StringBuilder& sb, int str) +{ + sb << str; +} + +void printDiagnosticArg(StringBuilder& sb, CoreLib::Basic::String const& str) +{ + sb << str; +} + +void printDiagnosticArg(StringBuilder& sb, Decl* decl) +{ + sb << decl->Name.Content; +} + +void printDiagnosticArg(StringBuilder& sb, Type* type) +{ + sb << type->DataType->ToString(); +} + +void printDiagnosticArg(StringBuilder& sb, ExpressionType* type) +{ + sb << type->ToString(); +} + +void printDiagnosticArg(StringBuilder& sb, TypeExp const& type) +{ + sb << type.type->ToString(); +} + +void printDiagnosticArg(StringBuilder& sb, QualType const& type) +{ + sb << type.type->ToString(); +} + +void printDiagnosticArg(StringBuilder& sb, TokenType tokenType) +{ + sb << TokenTypeToString(tokenType); +} + +void printDiagnosticArg(StringBuilder& sb, Token const& token) +{ + sb << token.Content; +} + +CodePosition const& getDiagnosticPos(SyntaxNode const* syntax) +{ + return syntax->Position; +} + +CodePosition const& getDiagnosticPos(Token const& token) +{ + return token.Position; +} + +CodePosition const& getDiagnosticPos(TypeExp const& typeExp) +{ + return typeExp.exp->Position; +} + +// Take the format string for a diagnostic message, along with its arguments, and turn it into a +static void formatDiagnosticMessage(StringBuilder& sb, char const* format, int argCount, DiagnosticArg const* const* args) +{ + char const* spanBegin = format; + for(;;) + { + char const* spanEnd = spanBegin; + while (int c = *spanEnd) + { + if (c == '$') + break; + spanEnd++; + } + + sb.Append(spanBegin, int(spanEnd - spanBegin)); + if (!*spanEnd) + return; + + assert(*spanEnd == '$'); + spanEnd++; + int d = *spanEnd++; + switch (d) + { + // A double dollar sign `$$` is used to emit a single `$` + case '$': + sb.Append('$'); + break; + + // A single digit means to emit the corresponding argument. + // TODO: support more than 10 arguments, and add options + // to control formatting, etc. + case '0': case '1': case '2': case '3': case '4': + case '5': case '6': case '7': case '8': case '9': + { + int index = d - '0'; + if (index >= argCount) + { + // TODO(tfoley): figure out what a good policy will be for "panic" situations like this + throw InvalidOperationException("too few arguments for diagnostic message"); + } + else + { + DiagnosticArg const* arg = args[index]; + arg->printFunc(sb, arg->data); + } + } + break; + + default: + throw InvalidOperationException("invalid diagnostic message format"); + break; + } + + spanBegin = spanEnd; + } +} + +static void formatDiagnostic( + StringBuilder& sb, + Diagnostic const& diagnostic) +{ + sb << diagnostic.Position.FileName; + sb << "("; + sb << diagnostic.Position.Line; + sb << "): "; + sb << getSeverityName(diagnostic.severity); + sb << " "; + sb << diagnostic.ErrorID; + sb << ": "; + sb << diagnostic.Message; + sb << "\n"; +} + +void DiagnosticSink::diagnoseImpl(CodePosition const& pos, DiagnosticInfo const& info, int argCount, DiagnosticArg const* const* args) +{ + StringBuilder sb; + formatDiagnosticMessage(sb, info.messageFormat, argCount, args); + + Diagnostic diagnostic; + diagnostic.ErrorID = info.id; + diagnostic.Message = sb.ProduceString(); + diagnostic.Position = pos; + diagnostic.severity = info.severity; + + if (diagnostic.severity >= Severity::Error) + { + errorCount++; + } + + // Did the client supply a callback for us to use? + if( callback ) + { + // If so, pass the error string along to them + StringBuilder sb; + formatDiagnostic(sb, diagnostic); + + callback(sb.ProduceString().begin(), callbackUserData); + } + else + { + // If the user doesn't have a callback, then just + // collect our diagnostic messages into a buffer + formatDiagnostic(outputBuffer, diagnostic); + } + + if (diagnostic.severity >= Severity::Fatal) + { + // TODO: figure out a better policy for aborting compilation + throw InvalidOperationException(); + } +} + +namespace Diagnostics +{ +#define DIAGNOSTIC(id, severity, name, messageFormat) const DiagnosticInfo name = { id, Severity::severity, messageFormat }; +#include "diagnostic-defs.h" +} + + +}} // namespace Slang::Compiler diff --git a/source/slang/diagnostics.h b/source/slang/diagnostics.h new file mode 100644 index 000000000..c1559df5d --- /dev/null +++ b/source/slang/diagnostics.h @@ -0,0 +1,218 @@ +#ifndef RASTER_RENDERER_COMPILE_ERROR_H +#define RASTER_RENDERER_COMPILE_ERROR_H + +#include "../core/basic.h" + +#include "source-loc.h" +#include "token.h" + +#include "../../slang.h" + +namespace Slang +{ + namespace Compiler + { + using namespace CoreLib::Basic; + + enum class Severity + { + Note, + Warning, + Error, + Fatal, + Internal, + }; + + // TODO(tfoley): move this into a source file... + inline const char* getSeverityName(Severity severity) + { + switch (severity) + { + case Severity::Note: return "note"; + case Severity::Warning: return "warning"; + case Severity::Error: return "error"; + case Severity::Fatal: return "fatal error"; + case Severity::Internal: return "internal error"; + default: return "unknown error"; + } + } + + // A structure to be used in static data describing different + // diagnostic messages. + struct DiagnosticInfo + { + int id; + Severity severity; + char const* messageFormat; + }; + + class Diagnostic + { + public: + String Message; + CodePosition Position; + int ErrorID; + Severity severity; + + Diagnostic() + { + ErrorID = -1; + } + Diagnostic( + const String & msg, + int id, + const CodePosition & pos, + Severity severity) + : severity(severity) + { + Message = msg; + ErrorID = id; + Position = pos; + } + }; + + class Decl; + class Type; + class ExpressionType; + class ILType; + class StageAttribute; + struct TypeExp; + struct QualType; + + void printDiagnosticArg(StringBuilder& sb, char const* str); + void printDiagnosticArg(StringBuilder& sb, int val); + void printDiagnosticArg(StringBuilder& sb, CoreLib::Basic::String const& str); + void printDiagnosticArg(StringBuilder& sb, Decl* decl); + void printDiagnosticArg(StringBuilder& sb, Type* type); + void printDiagnosticArg(StringBuilder& sb, ExpressionType* type); + void printDiagnosticArg(StringBuilder& sb, TypeExp const& type); + void printDiagnosticArg(StringBuilder& sb, QualType const& type); + void printDiagnosticArg(StringBuilder& sb, TokenType tokenType); + void printDiagnosticArg(StringBuilder& sb, Token const& token); + + template<typename T> + void printDiagnosticArg(StringBuilder& sb, RefPtr<T> ptr) + { + printDiagnosticArg(sb, ptr.Ptr()); + } + + inline CodePosition const& getDiagnosticPos(CodePosition const& pos) { return pos; } + + class SyntaxNode; + class ShaderClosure; + CodePosition const& getDiagnosticPos(SyntaxNode const* syntax); + CodePosition const& getDiagnosticPos(Token const& token); + CodePosition const& getDiagnosticPos(TypeExp const& typeExp); + + template<typename T> + CodePosition getDiagnosticPos(RefPtr<T> const& ptr) + { + return getDiagnosticPos(ptr.Ptr()); + } + + struct DiagnosticArg + { + void* data; + void (*printFunc)(StringBuilder&, void*); + + template<typename T> + struct Helper + { + static void printFunc(StringBuilder& sb, void* data) { printDiagnosticArg(sb, *(T*)data); } + }; + + template<typename T> + DiagnosticArg(T const& arg) + : data((void*)&arg) + , printFunc(&Helper<T>::printFunc) + {} + }; + + class DiagnosticSink + { + public: + StringBuilder outputBuffer; +// List<Diagnostic> diagnostics; + int errorCount = 0; + + SlangDiagnosticCallback callback = nullptr; + void* callbackUserData = nullptr; + +/* + void Error(int id, const String & msg, const CodePosition & pos) + { + diagnostics.Add(Diagnostic(msg, id, pos, Severity::Error)); + errorCount++; + } + + void Warning(int id, const String & msg, const CodePosition & pos) + { + diagnostics.Add(Diagnostic(msg, id, pos, Severity::Warning)); + } +*/ + int GetErrorCount() { return errorCount; } + + void diagnoseDispatch(CodePosition const& pos, DiagnosticInfo const& info) + { + diagnoseImpl(pos, info, 0, NULL); + } + + void diagnoseDispatch(CodePosition const& pos, DiagnosticInfo const& info, DiagnosticArg const& arg0) + { + DiagnosticArg const* args[] = { &arg0 }; + diagnoseImpl(pos, info, 1, args); + } + + void diagnoseDispatch(CodePosition const& pos, DiagnosticInfo const& info, DiagnosticArg const& arg0, DiagnosticArg const& arg1) + { + DiagnosticArg const* args[] = { &arg0, &arg1 }; + diagnoseImpl(pos, info, 2, args); + } + + void diagnoseDispatch(CodePosition const& pos, DiagnosticInfo const& info, DiagnosticArg const& arg0, DiagnosticArg const& arg1, DiagnosticArg const& arg2) + { + DiagnosticArg const* args[] = { &arg0, &arg1, &arg2 }; + diagnoseImpl(pos, info, 3, args); + } + + void diagnoseDispatch(CodePosition const& pos, DiagnosticInfo const& info, DiagnosticArg const& arg0, DiagnosticArg const& arg1, DiagnosticArg const& arg2, DiagnosticArg const& arg3) + { + DiagnosticArg const* args[] = { &arg0, &arg1, &arg2, &arg3 }; + diagnoseImpl(pos, info, 4, args); + } + + template<typename P, typename... Args> + void diagnose(P const& pos, DiagnosticInfo const& info, Args const&... args ) + { + diagnoseDispatch(getDiagnosticPos(pos), info, args...); + } + + void diagnoseImpl(CodePosition const& pos, DiagnosticInfo const& info, int argCount, DiagnosticArg const* const* args); + }; + + namespace Diagnostics + { +#define DIAGNOSTIC(id, severity, name, messageFormat) extern const DiagnosticInfo name; +#include "diagnostic-defs.h" + } + } +} + +#ifdef _DEBUG +#define SLANG_INTERNAL_ERROR(sink, pos) \ + (sink)->diagnose(Slang::Compiler::CodePosition(__LINE__, 0, 0, __FILE__), Slang::Compiler::Diagnostics::internalCompilerError) +#define SLANG_UNIMPLEMENTED(sink, pos, what) \ + (sink)->diagnose(Slang::Compiler::CodePosition(__LINE__, 0, 0, __FILE__), Slang::Compiler::Diagnostics::unimplemented, what) + +#define SLANG_UNREACHABLE(msg) do { assert(!"ureachable code:" msg); exit(1); } while(0) +#else +#define SLANG_INTERNAL_ERROR(sink, pos) \ + (sink)->diagnose(pos, Slang::Compiler::Diagnostics::internalCompilerError) +#define SLANG_UNIMPLEMENTED(sink, pos, what) \ + (sink)->diagnose(pos, Slang::Compiler::Diagnostics::unimplemented, what) + +// TODO: find something that will perform better +#define SLANG_UNREACHABLE(msg) exit(1) +#endif + +#endif diff --git a/source/slang/emit.cpp b/source/slang/emit.cpp new file mode 100644 index 000000000..6b488a68f --- /dev/null +++ b/source/slang/emit.cpp @@ -0,0 +1,2537 @@ +// emit.cpp +#include "emit.h" + +#include "syntax.h" +#include "type-layout.h" + +#include <assert.h> + +#ifdef _WIN32 +#include <d3dcompiler.h> +#pragma warning(disable:4996) +#endif + +namespace Slang { namespace Compiler { + +struct EmitContext +{ + StringBuilder sb; + + // Current source position for tracking purposes... + CodePosition loc; + + // The target language we want to generate code for + CodeGenTarget target; + + // A set of words reserved by the target + Dictionary<String, String> reservedWords; +}; + +// + +static void EmitDecl(EmitContext* context, RefPtr<Decl> decl); +static void EmitDecl(EmitContext* context, RefPtr<DeclBase> declBase); +static void EmitDeclUsingLayout(EmitContext* context, RefPtr<Decl> decl, RefPtr<VarLayout> layout); +static void EmitType(EmitContext* context, RefPtr<ExpressionType> type, String const& name); +static void EmitType(EmitContext* context, RefPtr<ExpressionType> type); +static void EmitExpr(EmitContext* context, RefPtr<ExpressionSyntaxNode> expr); +static void EmitStmt(EmitContext* context, RefPtr<StatementSyntaxNode> stmt); +static void EmitDeclRef(EmitContext* context, DeclRef declRef); + +// Low-level emit logic + +static void emitRawTextSpan(EmitContext* context, char const* textBegin, char const* textEnd) +{ + // TODO(tfoley): Need to make "corelib" not use `int` for pointer-sized things... + auto len = int(textEnd - textBegin); + + context->sb.Append(textBegin, len); +} + +static void emitRawText(EmitContext* context, char const* text) +{ + emitRawTextSpan(context, text, text + strlen(text)); +} + +static void emitTextSpan(EmitContext* context, char const* textBegin, char const* textEnd) +{ + // Emit the raw text + emitRawTextSpan(context, textBegin, textEnd); + + // Update our logical position + // TODO(tfoley): Need to make "corelib" not use `int` for pointer-sized things... + auto len = int(textEnd - textBegin); + context->loc.Col += len; +} + +static void Emit(EmitContext* context, char const* textBegin, char const* textEnd) +{ + char const* spanBegin = textBegin; + + char const* spanEnd = spanBegin; + for(;;) + { + if(spanEnd == textEnd) + { + // We have a whole range of text waiting to be flushed + emitTextSpan(context, spanBegin, spanEnd); + return; + } + + auto c = *spanEnd++; + + if( c == '\n' ) + { + // At the end of a line, we need to update our tracking + // information on code positions + emitTextSpan(context, spanBegin, spanEnd); + context->loc.Line++; + context->loc.Col = 1; + + // Start a new span for emit purposes + spanBegin = spanEnd; + } + } +} + +static void Emit(EmitContext* context, char const* text) +{ + Emit(context, text, text + strlen(text)); +} + +static void emit(EmitContext* context, String const& text) +{ + Emit(context, text.begin(), text.end()); +} + +static bool isReservedWord(EmitContext* context, String const& name) +{ + return context->reservedWords.TryGetValue(name) != nullptr; +} + +static void emitName(EmitContext* context, String const& inName) +{ + String name = inName; + + // By default, we would like to emit a name in the generated + // code exactly as it appeared in the soriginal program. + // When that isn't possible, we'd like to emit a name as + // close to the original as possible (to ensure that existing + // debugging tools still work reasonably well). + // + // One reason why a name might not be allowed as-is is that + // it could collide with a reserved word in the target language. + // Another reason is that it might not follow a naming convention + // imposed by the target (e.g., in GLSL names starting with + // `gl_` or containing `__` are reserved). + // + // Given a name that should not be allowed, we want to + // change it to a name that *is* allowed. e.g., by adding + // `_` to the end of a reserved word. + // + // The next problem this creates is that the modified name + // could not collide with an existing use of the same + // (valid) name. + // + // For now we are going to solve this problem in a simple + // and ad hoc fashion, but longer term we'll want to do + // something sytematic. + + if (isReservedWord(context, name)) + { + name = name + "_"; + } + + emit(context, name); +} + +static void Emit(EmitContext* context, UInt value) +{ + char buffer[32]; + sprintf(buffer, "%llu", (unsigned long long)(value)); + Emit(context, buffer); +} + +static void Emit(EmitContext* context, int value) +{ + char buffer[16]; + sprintf(buffer, "%d", value); + Emit(context, buffer); +} + +static void Emit(EmitContext* context, double value) +{ + // TODO(tfoley): need to print things in a way that can round-trip + char buffer[128]; + sprintf(buffer, "%.20ff", value); + Emit(context, buffer); +} + +// Expressions + +// Determine if an expression should not be emitted when it is the base of +// a member reference expression. +static bool IsBaseExpressionImplicit(EmitContext* /*context*/, RefPtr<ExpressionSyntaxNode> expr) +{ + // HACK(tfoley): For now, anything with a constant-buffer type should be + // left implicit. + + // Look through any dereferencing that took place + RefPtr<ExpressionSyntaxNode> e = expr; + while (auto derefExpr = e.As<DerefExpr>()) + { + e = derefExpr->base; + } + // Is the expression referencing a constant buffer? + if (auto cbufferType = e->Type->As<ConstantBufferType>()) + { + return true; + } + + return false; +} + +enum +{ + kPrecedence_None, + kPrecedence_Comma, + + kPrecedence_Assign, + kPrecedence_AddAssign = kPrecedence_Assign, + kPrecedence_SubAssign = kPrecedence_Assign, + kPrecedence_MulAssign = kPrecedence_Assign, + kPrecedence_DivAssign = kPrecedence_Assign, + kPrecedence_ModAssign = kPrecedence_Assign, + kPrecedence_LshAssign = kPrecedence_Assign, + kPrecedence_RshAssign = kPrecedence_Assign, + kPrecedence_OrAssign = kPrecedence_Assign, + kPrecedence_AndAssign = kPrecedence_Assign, + kPrecedence_XorAssign = kPrecedence_Assign, + + kPrecedence_General = kPrecedence_Assign, + + kPrecedence_Conditional, // "ternary" + kPrecedence_Or, + kPrecedence_And, + kPrecedence_BitOr, + kPrecedence_BitXor, + kPrecedence_BitAnd, + + kPrecedence_Eql, + kPrecedence_Neq = kPrecedence_Eql, + + kPrecedence_Less, + kPrecedence_Greater = kPrecedence_Less, + kPrecedence_Leq = kPrecedence_Less, + kPrecedence_Geq = kPrecedence_Less, + + kPrecedence_Lsh, + kPrecedence_Rsh = kPrecedence_Lsh, + + kPrecedence_Add, + kPrecedence_Sub = kPrecedence_Add, + + kPrecedence_Mul, + kPrecedence_Div = kPrecedence_Mul, + kPrecedence_Mod = kPrecedence_Mul, + + kPrecedence_Prefix, + kPrecedence_Postfix, + kPrecedence_Atomic = kPrecedence_Postfix +}; + +static void EmitExprWithPrecedence(EmitContext* context, RefPtr<ExpressionSyntaxNode> expr, int outerPrec); + +static void EmitPostfixExpr(EmitContext* context, RefPtr<ExpressionSyntaxNode> expr) +{ + EmitExprWithPrecedence(context, expr, kPrecedence_Postfix); +} + +static void EmitExpr(EmitContext* context, RefPtr<ExpressionSyntaxNode> expr) +{ + EmitExprWithPrecedence(context, expr, kPrecedence_General); +} + +static bool MaybeEmitParens(EmitContext* context, int outerPrec, int prec) +{ + if (prec <= outerPrec) + { + Emit(context, "("); + return true; + } + return false; +} + +// When we are going to emit an expression in an l-value context, +// we may need to ignore certain constructs that the type-checker +// might have introduced, but which interfere with our ability +// to use it effectively in the target language +static RefPtr<ExpressionSyntaxNode> prepareLValueExpr( + EmitContext* context, + RefPtr<ExpressionSyntaxNode> expr) +{ + for(;;) + { + if(auto typeCastExpr = expr.As<TypeCastExpressionSyntaxNode>()) + { + expr = typeCastExpr->Expression; + } + // TODO: any other cases? + else + { + return expr; + } + } + +} + +static void emitInfixExprImpl( + EmitContext* context, + int outerPrec, + int prec, + char const* op, + RefPtr<InvokeExpressionSyntaxNode> binExpr, + bool isAssign) +{ + bool needsClose = MaybeEmitParens(context, outerPrec, prec); + + auto left = binExpr->Arguments[0]; + if(isAssign) + { + left = prepareLValueExpr(context, left); + } + + EmitExprWithPrecedence(context, left, prec); + Emit(context, " "); + Emit(context, op); + Emit(context, " "); + EmitExprWithPrecedence(context, binExpr->Arguments[1], prec); + if (needsClose) + { + Emit(context, ")"); + } +} + +static void EmitBinExpr(EmitContext* context, int outerPrec, int prec, char const* op, RefPtr<InvokeExpressionSyntaxNode> binExpr) +{ + emitInfixExprImpl(context, outerPrec, prec, op, binExpr, false); +} + +static void EmitBinAssignExpr(EmitContext* context, int outerPrec, int prec, char const* op, RefPtr<InvokeExpressionSyntaxNode> binExpr) +{ + emitInfixExprImpl(context, outerPrec, prec, op, binExpr, true); +} + +static void emitUnaryExprImpl( + EmitContext* context, + int outerPrec, + int prec, + char const* preOp, + char const* postOp, + RefPtr<InvokeExpressionSyntaxNode> expr, + bool isAssign) +{ + bool needsClose = MaybeEmitParens(context, outerPrec, prec); + Emit(context, preOp); + + auto arg = expr->Arguments[0]; + if(isAssign) + { + arg = prepareLValueExpr(context, arg); + } + + EmitExprWithPrecedence(context, arg, prec); + Emit(context, postOp); + if (needsClose) + { + Emit(context, ")"); + } +} + +static void EmitUnaryExpr( + EmitContext* context, + int outerPrec, + int prec, + char const* preOp, + char const* postOp, + RefPtr<InvokeExpressionSyntaxNode> expr) +{ + emitUnaryExprImpl(context, outerPrec, prec, preOp, postOp, expr, false); +} + +static void EmitUnaryAssignExpr( + EmitContext* context, + int outerPrec, + int prec, + char const* preOp, + char const* postOp, + RefPtr<InvokeExpressionSyntaxNode> expr) +{ + emitUnaryExprImpl(context, outerPrec, prec, preOp, postOp, expr, true); +} + +static void emitCallExpr( + EmitContext* context, + RefPtr<InvokeExpressionSyntaxNode> callExpr, + int outerPrec) +{ + auto funcExpr = callExpr->FunctionExpr; + if (auto funcDeclRefExpr = funcExpr.As<DeclRefExpr>()) + { + auto funcDeclRef = funcDeclRefExpr->declRef; + auto funcDecl = funcDeclRef.GetDecl(); + if (auto intrinsicModifier = funcDecl->FindModifier<IntrinsicModifier>()) + { + switch (intrinsicModifier->op) + { +#define CASE(NAME, OP) case IntrinsicOp::NAME: EmitBinExpr(context, outerPrec, kPrecedence_##NAME, #OP, callExpr); return + CASE(Mul, *); + CASE(Div, / ); + CASE(Mod, %); + CASE(Add, +); + CASE(Sub, -); + CASE(Lsh, << ); + CASE(Rsh, >> ); + CASE(Eql, == ); + CASE(Neq, != ); + CASE(Greater, >); + CASE(Less, <); + CASE(Geq, >= ); + CASE(Leq, <= ); + CASE(BitAnd, &); + CASE(BitXor, ^); + CASE(BitOr, | ); + CASE(And, &&); + CASE(Or, || ); +#undef CASE + +#define CASE(NAME, OP) case IntrinsicOp::NAME: EmitBinAssignExpr(context, outerPrec, kPrecedence_##NAME, #OP, callExpr); return + CASE(Assign, =); + CASE(AddAssign, +=); + CASE(SubAssign, -=); + CASE(MulAssign, *=); + CASE(DivAssign, /=); + CASE(ModAssign, %=); + CASE(LshAssign, <<=); + CASE(RshAssign, >>=); + CASE(OrAssign, |=); + CASE(AndAssign, &=); + CASE(XorAssign, ^=); +#undef CASE + + case IntrinsicOp::Sequence: EmitBinExpr(context, outerPrec, kPrecedence_Comma, ",", callExpr); return; + +#define CASE(NAME, OP) case IntrinsicOp::NAME: EmitUnaryExpr(context, outerPrec, kPrecedence_Prefix, #OP, "", callExpr); return + CASE(Neg, -); + CASE(Not, !); + CASE(BitNot, ~); +#undef CASE + +#define CASE(NAME, OP) case IntrinsicOp::NAME: EmitUnaryAssignExpr(context, outerPrec, kPrecedence_Prefix, #OP, "", callExpr); return + CASE(PreInc, ++); + CASE(PreDec, --); +#undef CASE + +#define CASE(NAME, OP) case IntrinsicOp::NAME: EmitUnaryAssignExpr(context, outerPrec, kPrecedence_Postfix, "", #OP, callExpr); return + CASE(PostInc, ++); + CASE(PostDec, --); +#undef CASE + + case IntrinsicOp::InnerProduct_Vector_Vector: + // HLSL allows `mul()` to be used as a synonym for `dot()`, + // so we need to translate to `dot` for GLSL + if (context->target == CodeGenTarget::GLSL) + { + Emit(context, "dot("); + EmitExpr(context, callExpr->Arguments[0]); + Emit(context, ", "); + EmitExpr(context, callExpr->Arguments[1]); + Emit(context, ")"); + return; + } + break; + + case IntrinsicOp::InnerProduct_Matrix_Matrix: + case IntrinsicOp::InnerProduct_Matrix_Vector: + case IntrinsicOp::InnerProduct_Vector_Matrix: + // HLSL exposes these with the `mul()` function, while GLSL uses ordinary + // `operator*`. + // + // The other critical detail here is that the way we handle matrix + // conventions requires that the operands to the product be swapped. + if (context->target == CodeGenTarget::GLSL) + { + Emit(context, "(("); + EmitExpr(context, callExpr->Arguments[1]); + Emit(context, ") * ("); + EmitExpr(context, callExpr->Arguments[0]); + Emit(context, "))"); + return; + } + break; + + default: + break; + } + + + // We might be calling an intrinsic subscript operation, + // and should desugar it accordingly + if(auto subscriptDeclRef = funcDeclRef.As<SubscriptDeclRef>()) + { + // We expect any subscript operation to be invoked as a member, + // so the function expression had better be in the correct form. + if(auto memberExpr = funcExpr.As<MemberExpressionSyntaxNode>()) + { + + Emit(context, "("); + EmitExpr(context, memberExpr->BaseExpression); + Emit(context, ")["); + int argCount = callExpr->Arguments.Count(); + for (int aa = 0; aa < argCount; ++aa) + { + if (aa != 0) Emit(context, ", "); + EmitExpr(context, callExpr->Arguments[aa]); + } + Emit(context, "]"); + return; + } + } + } + } + + // Fall through to default handling... + + bool needClose = MaybeEmitParens(context, outerPrec, kPrecedence_Postfix); + + if (auto funcDeclRefExpr = funcExpr.As<DeclRefExpr>()) + { + auto declRef = funcDeclRefExpr->declRef; + if (auto ctorDeclRef = declRef.As<ConstructorDeclRef>()) + { + // We really want to emit a reference to the type begin constructed + EmitType(context, callExpr->Type); + } + else + { + // default case: just emit the decl ref + EmitExpr(context, funcExpr); + } + } + else + { + // default case: just emit the expression + EmitPostfixExpr(context, funcExpr); + } + + Emit(context, "("); + int argCount = callExpr->Arguments.Count(); + for (int aa = 0; aa < argCount; ++aa) + { + if (aa != 0) Emit(context, ", "); + EmitExpr(context, callExpr->Arguments[aa]); + } + Emit(context, ")"); + + if (needClose) + { + Emit(context, ")"); + } +} + +static void EmitExprWithPrecedence(EmitContext* context, RefPtr<ExpressionSyntaxNode> expr, int outerPrec) +{ + bool needClose = false; + if (auto selectExpr = expr.As<SelectExpressionSyntaxNode>()) + { + needClose = MaybeEmitParens(context, outerPrec, kPrecedence_Conditional); + + EmitExprWithPrecedence(context, selectExpr->Arguments[0], kPrecedence_Conditional); + Emit(context, " ? "); + EmitExprWithPrecedence(context, selectExpr->Arguments[1], kPrecedence_Conditional); + Emit(context, " : "); + EmitExprWithPrecedence(context, selectExpr->Arguments[2], kPrecedence_Conditional); + } + else if (auto callExpr = expr.As<InvokeExpressionSyntaxNode>()) + { + emitCallExpr(context, callExpr, outerPrec); + } + else if (auto memberExpr = expr.As<MemberExpressionSyntaxNode>()) + { + needClose = MaybeEmitParens(context, outerPrec, kPrecedence_Postfix); + + // TODO(tfoley): figure out a good way to reference + // declarations that might be generic and/or might + // not be generated as lexically nested declarations... + + // TODO(tfoley): also, probably need to special case + // this for places where we are using a built-in... + + auto base = memberExpr->BaseExpression; + if (IsBaseExpressionImplicit(context, base)) + { + // don't emit the base expression + } + else + { + EmitExprWithPrecedence(context, memberExpr->BaseExpression, kPrecedence_Postfix); + Emit(context, "."); + } + + emitName(context, memberExpr->declRef.GetName()); + } + else if (auto swizExpr = expr.As<SwizzleExpr>()) + { + needClose = MaybeEmitParens(context, outerPrec, kPrecedence_Postfix); + + EmitExprWithPrecedence(context, swizExpr->base, kPrecedence_Postfix); + Emit(context, "."); + static const char* kComponentNames[] = { "x", "y", "z", "w" }; + int elementCount = swizExpr->elementCount; + for (int ee = 0; ee < elementCount; ++ee) + { + Emit(context, kComponentNames[swizExpr->elementIndices[ee]]); + } + } + else if (auto indexExpr = expr.As<IndexExpressionSyntaxNode>()) + { + needClose = MaybeEmitParens(context, outerPrec, kPrecedence_Postfix); + + EmitExprWithPrecedence(context, indexExpr->BaseExpression, kPrecedence_Postfix); + Emit(context, "["); + EmitExpr(context, indexExpr->IndexExpression); + Emit(context, "]"); + } + else if (auto varExpr = expr.As<VarExpressionSyntaxNode>()) + { + needClose = MaybeEmitParens(context, outerPrec, kPrecedence_Atomic); + + // Because of the "rewriter" use case, it is possible that we will + // be trying to emit an expression that hasn't been wired up to + // any associated declaration. In that case, we will just emit + // the variable name. + // + // TODO: A better long-term solution here is to have a distinct + // case for an "unchecked" `NameExpr` that doesn't include + // a declaration reference. + + if(varExpr->declRef) + { + EmitDeclRef(context, varExpr->declRef); + } + else + { + emitName(context, varExpr->Variable); + } + } + else if (auto derefExpr = expr.As<DerefExpr>()) + { + // TODO(tfoley): dereference shouldn't always be implicit + EmitExprWithPrecedence(context, derefExpr->base, outerPrec); + } + else if (auto litExpr = expr.As<ConstantExpressionSyntaxNode>()) + { + needClose = MaybeEmitParens(context, outerPrec, kPrecedence_Atomic); + + switch (litExpr->ConstType) + { + case ConstantExpressionSyntaxNode::ConstantType::Int: + Emit(context, litExpr->IntValue); + break; + case ConstantExpressionSyntaxNode::ConstantType::Float: + Emit(context, litExpr->FloatValue); + break; + case ConstantExpressionSyntaxNode::ConstantType::Bool: + Emit(context, litExpr->IntValue ? "true" : "false"); + break; + default: + assert(!"unreachable"); + break; + } + } + else if (auto castExpr = expr.As<TypeCastExpressionSyntaxNode>()) + { + switch(context->target) + { + case CodeGenTarget::GLSL: + // GLSL requires constructor syntax for all conversions + EmitType(context, castExpr->Type); + Emit(context, "("); + EmitExpr(context, castExpr->Expression); + Emit(context, ")"); + break; + + default: + // HLSL (and C/C++) prefer cast syntax + // (In fact, HLSL doesn't allow constructor syntax for some conversions it allows as a cast) + needClose = MaybeEmitParens(context, outerPrec, kPrecedence_Prefix); + + Emit(context, "("); + EmitType(context, castExpr->Type); + Emit(context, ")("); + EmitExpr(context, castExpr->Expression); + Emit(context, ")"); + break; + } + + } + else if(auto initExpr = expr.As<InitializerListExpr>()) + { + Emit(context, "{ "); + for(auto& arg : initExpr->args) + { + EmitExpr(context, arg); + Emit(context, ", "); + } + Emit(context, "}"); + } + else + { + throw "unimplemented"; + } + if (needClose) + { + Emit(context, ")"); + } +} + +// Types + +void Emit(EmitContext* context, RefPtr<IntVal> val) +{ + if(auto constantIntVal = val.As<ConstantIntVal>()) + { + Emit(context, constantIntVal->value); + } + else if(auto varRefVal = val.As<GenericParamIntVal>()) + { + EmitDeclRef(context, varRefVal->declRef); + } + else + { + assert(!"unimplemented"); + } +} + +// represents a declarator for use in emitting types +struct EDeclarator +{ + enum class Flavor + { + Name, + Array, + UnsizedArray, + }; + Flavor flavor; + EDeclarator* next = nullptr; + + // Used for `Flavor::Name` + String name; + + // Used for `Flavor::Array` + IntVal* elementCount; +}; + +static void EmitDeclarator(EmitContext* context, EDeclarator* declarator) +{ + if (!declarator) return; + + Emit(context, " "); + + switch (declarator->flavor) + { + case EDeclarator::Flavor::Name: + emitName(context, declarator->name); + break; + + case EDeclarator::Flavor::Array: + EmitDeclarator(context, declarator->next); + Emit(context, "["); + if(auto elementCount = declarator->elementCount) + { + Emit(context, elementCount); + } + Emit(context, "]"); + break; + + case EDeclarator::Flavor::UnsizedArray: + EmitDeclarator(context, declarator->next); + Emit(context, "[]"); + break; + + default: + assert(!"unreachable"); + break; + } +} + +static void emitGLSLTypePrefix( + EmitContext* context, + RefPtr<ExpressionType> type) +{ + if(auto basicElementType = type->As<BasicExpressionType>()) + { + switch (basicElementType->BaseType) + { + case BaseType::Float: + // no prefix + break; + + case BaseType::Int: Emit(context, "i"); break; + case BaseType::UInt: Emit(context, "u"); break; + case BaseType::Bool: Emit(context, "b"); break; + default: + assert(!"unreachable"); + break; + } + } + else if(auto vectorType = type->As<VectorExpressionType>()) + { + emitGLSLTypePrefix(context, vectorType->elementType); + } + else if(auto matrixType = type->As<MatrixExpressionType>()) + { + emitGLSLTypePrefix(context, matrixType->getElementType()); + } + else + { + assert(!"unreachable"); + } +} + +static void emitHLSLTextureType( + EmitContext* context, + RefPtr<TextureTypeBase> texType) +{ + switch(texType->getAccess()) + { + case SLANG_RESOURCE_ACCESS_READ: + break; + + case SLANG_RESOURCE_ACCESS_READ_WRITE: + Emit(context, "RW"); + break; + + case SLANG_RESOURCE_ACCESS_RASTER_ORDERED: + Emit(context, "RasterizerOrdered"); + break; + + case SLANG_RESOURCE_ACCESS_APPEND: + Emit(context, "Append"); + break; + + case SLANG_RESOURCE_ACCESS_CONSUME: + Emit(context, "Consume"); + break; + + default: + assert(!"unreachable"); + break; + } + + switch (texType->GetBaseShape()) + { + case TextureType::Shape1D: Emit(context, "Texture1D"); break; + case TextureType::Shape2D: Emit(context, "Texture2D"); break; + case TextureType::Shape3D: Emit(context, "Texture3D"); break; + case TextureType::ShapeCube: Emit(context, "TextureCube"); break; + default: + assert(!"unreachable"); + break; + } + + if (texType->isMultisample()) + { + Emit(context, "MS"); + } + if (texType->isArray()) + { + Emit(context, "Array"); + } + Emit(context, "<"); + EmitType(context, texType->elementType); + Emit(context, ">"); +} + +static void emitGLSLTextureOrTextureSamplerType( + EmitContext* context, + RefPtr<TextureTypeBase> type, + char const* baseName) +{ + emitGLSLTypePrefix(context, type->elementType); + + Emit(context, baseName); + switch (type->GetBaseShape()) + { + case TextureType::Shape1D: Emit(context, "1D"); break; + case TextureType::Shape2D: Emit(context, "2D"); break; + case TextureType::Shape3D: Emit(context, "3D"); break; + case TextureType::ShapeCube: Emit(context, "Cube"); break; + default: + assert(!"unreachable"); + break; + } + + if (type->isMultisample()) + { + Emit(context, "MS"); + } + if (type->isArray()) + { + Emit(context, "Array"); + } +} + +static void emitGLSLTextureType( + EmitContext* context, + RefPtr<TextureType> texType) +{ + emitGLSLTextureOrTextureSamplerType(context, texType, "texture"); +} + +static void emitGLSLTextureSamplerType( + EmitContext* context, + RefPtr<TextureSamplerType> type) +{ + emitGLSLTextureOrTextureSamplerType(context, type, "sampler"); +} + +static void emitGLSLImageType( + EmitContext* context, + RefPtr<GLSLImageType> type) +{ + emitGLSLTextureOrTextureSamplerType(context, type, "image"); +} + +static void emitTextureType( + EmitContext* context, + RefPtr<TextureType> texType) +{ + switch(context->target) + { + case CodeGenTarget::HLSL: + emitHLSLTextureType(context, texType); + break; + + case CodeGenTarget::GLSL: + emitGLSLTextureType(context, texType); + break; + + default: + assert(!"unreachable"); + break; + } +} + +static void emitTextureSamplerType( + EmitContext* context, + RefPtr<TextureSamplerType> type) +{ + switch(context->target) + { + case CodeGenTarget::GLSL: + emitGLSLTextureSamplerType(context, type); + break; + + default: + assert(!"unreachable"); + break; + } +} + +static void emitImageType( + EmitContext* context, + RefPtr<GLSLImageType> type) +{ + switch(context->target) + { + case CodeGenTarget::HLSL: + emitHLSLTextureType(context, type); + break; + + case CodeGenTarget::GLSL: + emitGLSLImageType(context, type); + break; + + default: + assert(!"unreachable"); + break; + } +} + +static void EmitType(EmitContext* context, RefPtr<ExpressionType> type, EDeclarator* declarator) +{ + if (auto basicType = type->As<BasicExpressionType>()) + { + switch (basicType->BaseType) + { + case BaseType::Void: Emit(context, "void"); break; + case BaseType::Int: Emit(context, "int"); break; + case BaseType::Float: Emit(context, "float"); break; + case BaseType::UInt: Emit(context, "uint"); break; + case BaseType::Bool: Emit(context, "bool"); break; + default: + assert(!"unreachable"); + break; + } + + EmitDeclarator(context, declarator); + return; + } + else if (auto vecType = type->As<VectorExpressionType>()) + { + switch(context->target) + { + case CodeGenTarget::GLSL: + case CodeGenTarget::GLSL_Vulkan: + case CodeGenTarget::GLSL_Vulkan_OneDesc: + { + emitGLSLTypePrefix(context, vecType->elementType); + Emit(context, "vec"); + Emit(context, vecType->elementCount); + } + break; + + case CodeGenTarget::HLSL: + // TODO(tfoley): should really emit these with sugar + Emit(context, "vector<"); + EmitType(context, vecType->elementType); + Emit(context, ","); + Emit(context, vecType->elementCount); + Emit(context, ">"); + break; + + default: + assert(!"unreachable"); + break; + } + + Emit(context, " "); + EmitDeclarator(context, declarator); + return; + } + else if (auto matType = type->As<MatrixExpressionType>()) + { + switch(context->target) + { + case CodeGenTarget::GLSL: + case CodeGenTarget::GLSL_Vulkan: + case CodeGenTarget::GLSL_Vulkan_OneDesc: + { + emitGLSLTypePrefix(context, matType->getElementType()); + Emit(context, "mat"); + Emit(context, matType->getRowCount()); + // TODO(tfoley): only emit the next bit + // for non-square matrix + Emit(context, "x"); + Emit(context, matType->getColumnCount()); + } + break; + + case CodeGenTarget::HLSL: + // TODO(tfoley): should really emit these with sugar + Emit(context, "matrix<"); + EmitType(context, matType->getElementType()); + Emit(context, ","); + Emit(context, matType->getRowCount()); + Emit(context, ","); + Emit(context, matType->getColumnCount()); + Emit(context, "> "); + break; + + default: + assert(!"unreachable"); + break; + } + + Emit(context, " "); + EmitDeclarator(context, declarator); + return; + } + else if (auto texType = type->As<TextureType>()) + { + emitTextureType(context, texType); + Emit(context, " "); + EmitDeclarator(context, declarator); + return; + } + else if (auto textureSamplerType = type->As<TextureSamplerType>()) + { + emitTextureSamplerType(context, textureSamplerType); + Emit(context, " "); + EmitDeclarator(context, declarator); + return; + } + else if (auto imageType = type->As<GLSLImageType>()) + { + emitImageType(context, imageType); + Emit(context, " "); + EmitDeclarator(context, declarator); + return; + } + else if (auto samplerStateType = type->As<SamplerStateType>()) + { + switch(context->target) + { + case CodeGenTarget::HLSL: + default: + switch (samplerStateType->flavor) + { + case SamplerStateType::Flavor::SamplerState: Emit(context, "SamplerState"); break; + case SamplerStateType::Flavor::SamplerComparisonState: Emit(context, "SamplerComparisonState"); break; + default: + assert(!"unreachable"); + break; + } + break; + + case CodeGenTarget::GLSL: + Emit(context, "sampler"); + break; + } + + + EmitDeclarator(context, declarator); + return; + } + else if (auto declRefType = type->As<DeclRefType>()) + { + EmitDeclRef(context, declRefType->declRef); + + EmitDeclarator(context, declarator); + return; + } + else if( auto arrayType = type->As<ArrayExpressionType>() ) + { + EDeclarator arrayDeclarator; + arrayDeclarator.next = declarator; + + if(arrayType->ArrayLength) + { + arrayDeclarator.flavor = EDeclarator::Flavor::Array; + arrayDeclarator.elementCount = arrayType->ArrayLength.Ptr(); + } + else + { + arrayDeclarator.flavor = EDeclarator::Flavor::UnsizedArray; + } + + + EmitType(context, arrayType->BaseType, &arrayDeclarator); + return; + } + + throw "unimplemented"; +} + +static void EmitType(EmitContext* context, RefPtr<ExpressionType> type, String const& name) +{ + EDeclarator nameDeclarator; + nameDeclarator.flavor = EDeclarator::Flavor::Name; + nameDeclarator.name = name; + EmitType(context, type, &nameDeclarator); +} + +static void EmitType(EmitContext* context, RefPtr<ExpressionType> type) +{ + EmitType(context, type, nullptr); +} + +// Statements + +// Emit a statement as a `{}`-enclosed block statement, but avoid adding redundant +// curly braces if the statement is itself a block statement. +static void EmitBlockStmt(EmitContext* context, RefPtr<StatementSyntaxNode> stmt) +{ + // TODO(tfoley): support indenting + Emit(context, "{\n"); + if( auto blockStmt = stmt.As<BlockStatementSyntaxNode>() ) + { + for (auto s : blockStmt->Statements) + { + EmitStmt(context, s); + } + } + else + { + EmitStmt(context, stmt); + } + Emit(context, "}\n"); +} + +static void EmitLoopAttributes(EmitContext* context, RefPtr<StatementSyntaxNode> decl) +{ + // TODO(tfoley): There really ought to be a semantic checking step for attributes, + // that turns abstract syntax into a concrete hierarchy of attribute types (e.g., + // a specific `LoopModifier` or `UnrollModifier`). + + for(auto attr : decl->GetModifiersOfType<HLSLUncheckedAttribute>()) + { + if(attr->nameToken.Content == "loop") + { + Emit(context, "[loop]"); + } + else if(attr->nameToken.Content == "unroll") + { + Emit(context, "[unroll]"); + } + } +} + +static void advanceToSourceLocation( + EmitContext* context, + CodePosition const& sourceLocation) +{ + // If we are currently emitting code at a source location with + // a differnet file or line, *or* if the source location is + // somehow later on the line than what we want to emit, + // then we need to emit a new `#line` directive. + if(sourceLocation.FileName != context->loc.FileName + || sourceLocation.Line != context->loc.Line + || sourceLocation.Col < context->loc.Col) + { + emitRawText(context, "\n#line "); + + char buffer[16]; + sprintf(buffer, "%d", sourceLocation.Line); + emitRawText(context, buffer); + + emitRawText(context, "\""); + for(auto c : sourceLocation.FileName) + { + char charBuffer[] = { c, 0 }; + switch(c) + { + default: + emitRawText(context, charBuffer); + break; + + // TODO: should probably canonicalize paths to not use backslash somewhere else + // in the compilation pipeline... + case '\\': + emitRawText(context, "/"); + break; + } + } + emitRawText(context, "\"\n"); + + context->loc.FileName = sourceLocation.FileName; + context->loc.Line = sourceLocation.Line; + context->loc.Col = 1; + } + + // Now indent up to the appropriate column, so that error messages + // that reference columns will be correct. + // + // TODO: This logic does not take into account whether indentation + // came in as spaces or tabs, so there is necessarily going to be + // coupling between how the downstream compiler counts columns, + // and how we do. + if(sourceLocation.Col > context->loc.Col) + { + int delta = sourceLocation.Col - context->loc.Col; + for( int ii = 0; ii < delta; ++ii ) + { + emitRawText(context, " "); + } + context->loc.Col = sourceLocation.Col; + } +} + +static void emitTokenWithLocation(EmitContext* context, Token const& token) +{ + if( token.Position.FileName.Length() != 0 ) + { + advanceToSourceLocation(context, token.Position); + } + else + { + // If we don't have the original position info, we need to play + // it safe and emit whitespace to line things up nicely + + if(token.flags & TokenFlag::AtStartOfLine) + Emit(context, "\n"); + // TODO(tfoley): macro expansion can currently lead to whitespace getting dropped, + // so we will just insert it aggressively, to play it safe. + else // if(token.flags & TokenFlag::AfterWhitespace) + Emit(context, " "); + } + + // Emit the raw textual content of the token + emit(context, token.Content); +} + +static void EmitUnparsedStmt(EmitContext* context, RefPtr<UnparsedStmt> stmt) +{ + // TODO: actually emit the tokens that made up the statement... + Emit(context, "{\n"); + for( auto& token : stmt->tokens ) + { + emitTokenWithLocation(context, token); + } + Emit(context, "}\n"); +} + +static void EmitStmt(EmitContext* context, RefPtr<StatementSyntaxNode> stmt) +{ + if (auto blockStmt = stmt.As<BlockStatementSyntaxNode>()) + { + EmitBlockStmt(context, blockStmt); + return; + } + else if( auto unparsedStmt = stmt.As<UnparsedStmt>() ) + { + EmitUnparsedStmt(context, unparsedStmt); + return; + } + else if (auto exprStmt = stmt.As<ExpressionStatementSyntaxNode>()) + { + EmitExpr(context, exprStmt->Expression); + Emit(context, ";\n"); + return; + } + else if (auto returnStmt = stmt.As<ReturnStatementSyntaxNode>()) + { + Emit(context, "return"); + if (auto expr = returnStmt->Expression) + { + Emit(context, " "); + EmitExpr(context, expr); + } + Emit(context, ";\n"); + return; + } + else if (auto declStmt = stmt.As<VarDeclrStatementSyntaxNode>()) + { + EmitDecl(context, declStmt->decl); + return; + } + else if (auto ifStmt = stmt.As<IfStatementSyntaxNode>()) + { + Emit(context, "if("); + EmitExpr(context, ifStmt->Predicate); + Emit(context, ")\n"); + EmitBlockStmt(context, ifStmt->PositiveStatement); + if(auto elseStmt = ifStmt->NegativeStatement) + { + Emit(context, "\nelse\n"); + EmitBlockStmt(context, elseStmt); + } + return; + } + else if (auto forStmt = stmt.As<ForStatementSyntaxNode>()) + { + EmitLoopAttributes(context, forStmt); + + Emit(context, "for("); + if (auto initStmt = forStmt->InitialStatement) + { + EmitStmt(context, initStmt); + } + else + { + Emit(context, ";"); + } + if (auto testExp = forStmt->PredicateExpression) + { + EmitExpr(context, testExp); + } + Emit(context, ";"); + if (auto incrExpr = forStmt->SideEffectExpression) + { + EmitExpr(context, incrExpr); + } + Emit(context, ")\n"); + EmitBlockStmt(context, forStmt->Statement); + return; + } + else if (auto discardStmt = stmt.As<DiscardStatementSyntaxNode>()) + { + Emit(context, "discard;\n"); + return; + } + else if (auto emptyStmt = stmt.As<EmptyStatementSyntaxNode>()) + { + return; + } + else if (auto switchStmt = stmt.As<SwitchStmt>()) + { + Emit(context, "switch("); + EmitExpr(context, switchStmt->condition); + Emit(context, ")\n"); + EmitBlockStmt(context, switchStmt->body); + return; + } + else if (auto caseStmt = stmt.As<CaseStmt>()) + { + Emit(context, "case "); + EmitExpr(context, caseStmt->expr); + Emit(context, ":\n"); + return; + } + else if (auto defaultStmt = stmt.As<DefaultStmt>()) + { + Emit(context, "default:{}\n"); + return; + } + else if (auto breakStmt = stmt.As<BreakStatementSyntaxNode>()) + { + Emit(context, "break;\n"); + return; + } + else if (auto continueStmt = stmt.As<ContinueStatementSyntaxNode>()) + { + Emit(context, "continue;\n"); + return; + } + + throw "unimplemented"; + +} + +// Declaration References + +static void EmitVal(EmitContext* context, RefPtr<Val> val) +{ + if (auto type = val.As<ExpressionType>()) + { + EmitType(context, type); + } + else if (auto intVal = val.As<IntVal>()) + { + Emit(context, intVal); + } + else + { + // Note(tfoley): ignore unhandled cases for semantics for now... +// assert(!"unimplemented"); + } +} + +static void EmitDeclRef(EmitContext* context, DeclRef declRef) +{ + // TODO: need to qualify a declaration name based on parent scopes/declarations + + // Emit the name for the declaration itself + emitName(context, declRef.GetName()); + + // If the declaration is nested directly in a generic, then + // we need to output the generic arguments here + auto parentDeclRef = declRef.GetParent(); + if (auto genericDeclRef = parentDeclRef.As<GenericDeclRef>()) + { + // Only do this for declarations of appropriate flavors + if(auto funcDeclRef = declRef.As<FuncDeclBaseRef>()) + { + // Don't emit generic arguments for functions, because HLSL doesn't allow them + return; + } + + Substitutions* subst = declRef.substitutions.Ptr(); + Emit(context, "<"); + int argCount = subst->args.Count(); + for (int aa = 0; aa < argCount; ++aa) + { + if (aa != 0) Emit(context, ","); + EmitVal(context, subst->args[aa]); + } + Emit(context, ">"); + } + +} + +// Declarations + +// Emit any modifiers that should go in front of a declaration +static void EmitModifiers(EmitContext* context, RefPtr<Decl> decl) +{ + // Emit any GLSL `layout` modifiers first + bool anyLayout = false; + for( auto mod : decl->GetModifiersOfType<GLSLUnparsedLayoutModifier>()) + { + if(!anyLayout) + { + Emit(context, "layout("); + anyLayout = true; + } + else + { + Emit(context, ", "); + } + + emit(context, mod->nameToken.Content); + if(mod->valToken.Type != TokenType::Unknown) + { + Emit(context, " = "); + emit(context, mod->valToken.Content); + } + } + if(anyLayout) + { + Emit(context, ")\n"); + } + + for (auto mod = decl->modifiers.first; mod; mod = mod->next) + { + if (0) {} + + #define CASE(TYPE, KEYWORD) \ + else if(auto mod_##TYPE = mod.As<TYPE>()) Emit(context, #KEYWORD " ") + + CASE(RowMajorLayoutModifier, row_major); + CASE(ColumnMajorLayoutModifier, column_major); + CASE(HLSLNoInterpolationModifier, nointerpolation); + CASE(HLSLPreciseModifier, precise); + CASE(HLSLEffectSharedModifier, shared); + CASE(HLSLGroupSharedModifier, groupshared); + CASE(HLSLStaticModifier, static); + CASE(HLSLUniformModifier, uniform); + CASE(HLSLVolatileModifier, volatile); + + CASE(InOutModifier, inout); + CASE(InModifier, in); + CASE(OutModifier, out); + + CASE(HLSLPointModifier, point); + CASE(HLSLLineModifier, line); + CASE(HLSLTriangleModifier, triangle); + CASE(HLSLLineAdjModifier, lineadj); + CASE(HLSLTriangleAdjModifier, triangleadj); + + CASE(HLSLLinearModifier, linear); + CASE(HLSLSampleModifier, sample); + CASE(HLSLCentroidModifier, centroid); + + CASE(ConstModifier, const); + + #undef CASE + + // TODO: eventually we should be checked these modifiers, but for + // now we can emit them unchecked, I guess + else if (auto uncheckedAttr = mod.As<HLSLAttribute>()) + { + Emit(context, "["); + emit(context, uncheckedAttr->nameToken.Content); + auto& args = uncheckedAttr->args; + auto argCount = args.Count(); + if (argCount != 0) + { + Emit(context, "("); + for (int aa = 0; aa < argCount; ++aa) + { + if (aa != 0) Emit(context, ", "); + EmitExpr(context, args[aa]); + } + Emit(context, ")"); + } + Emit(context, "]"); + } + + else if(auto simpleModifier = mod.As<SimpleModifier>()) + { + emit(context, simpleModifier->nameToken.Content); + Emit(context, " "); + } + + else + { + // skip any extra modifiers + } + } +} + + +typedef unsigned int ESemanticMask; +enum +{ + kESemanticMask_None = 0, + + kESemanticMask_NoPackOffset = 1 << 0, + + kESemanticMask_Default = kESemanticMask_NoPackOffset, +}; + +static void EmitSemantic(EmitContext* context, RefPtr<HLSLSemantic> semantic, ESemanticMask /*mask*/) +{ + if (auto simple = semantic.As<HLSLSimpleSemantic>()) + { + Emit(context, ": "); + emit(context, simple->name.Content); + } + else if(auto registerSemantic = semantic.As<HLSLRegisterSemantic>()) + { + // Don't print out semantic from the user, since we are going to print the same thing our own way... +#if 0 + Emit(context, ": register("); + Emit(context, registerSemantic->registerName.Content); + if(registerSemantic->componentMask.Type != TokenType::Unknown) + { + Emit(context, "."); + Emit(context, registerSemantic->componentMask.Content); + } + Emit(context, ")"); +#endif + } + else if(auto packOffsetSemantic = semantic.As<HLSLPackOffsetSemantic>()) + { + // Don't print out semantic from the user, since we are going to print the same thing our own way... +#if 0 + if(mask & kESemanticMask_NoPackOffset) + return; + + Emit(context, ": packoffset("); + Emit(context, packOffsetSemantic->registerName.Content); + if(packOffsetSemantic->componentMask.Type != TokenType::Unknown) + { + Emit(context, "."); + Emit(context, packOffsetSemantic->componentMask.Content); + } + Emit(context, ")"); +#endif + } + else + { + assert(!"unimplemented"); + } +} + + +static void EmitSemantics(EmitContext* context, RefPtr<Decl> decl, ESemanticMask mask = kESemanticMask_Default ) +{ + // Don't emit semantics if we aren't translating down to HLSL + switch (context->target) + { + case CodeGenTarget::HLSL: + break; + + default: + return; + } + + for (auto mod = decl->modifiers.first; mod; mod = mod->next) + { + auto semantic = mod.As<HLSLSemantic>(); + if (!semantic) + continue; + + EmitSemantic(context, semantic, mask); + } +} + +static void EmitDeclsInContainer(EmitContext* context, RefPtr<ContainerDecl> container) +{ + for (auto member : container->Members) + { + EmitDecl(context, member); + } +} + +static void EmitDeclsInContainerUsingLayout( + EmitContext* context, + RefPtr<ContainerDecl> container, + RefPtr<StructTypeLayout> containerLayout) +{ + for (auto member : container->Members) + { + RefPtr<VarLayout> memberLayout; + if( containerLayout->mapVarToLayout.TryGetValue(member.Ptr(), memberLayout) ) + { + EmitDeclUsingLayout(context, member, memberLayout); + } + else + { + // No layout for this decl + EmitDecl(context, member); + } + } +} + +static void EmitTypeDefDecl(EmitContext* context, RefPtr<TypeDefDecl> decl) +{ + // TODO(tfoley): check if current compilation target even supports typedefs + + Emit(context, "typedef "); + EmitType(context, decl->Type, decl->Name.Content); + Emit(context, ";\n"); +} + +static void EmitStructDecl(EmitContext* context, RefPtr<StructSyntaxNode> decl) +{ + // Don't emit a declaration that was only generated implicitly, for + // the purposes of semantic checking. + if(decl->HasModifier<ImplicitParameterBlockElementTypeModifier>()) + return; + + Emit(context, "struct "); + emitName(context, decl->Name.Content); + Emit(context, "\n{\n"); + + // TODO(tfoley): Need to hoist members functions, etc. out to global scope + EmitDeclsInContainer(context, decl); + + Emit(context, "};\n"); +} + +// Shared emit logic for variable declarations (used for parameters, locals, globals, fields) +static void EmitVarDeclCommon(EmitContext* context, VarDeclBaseRef declRef) +{ + EmitModifiers(context, declRef.GetDecl()); + + EmitType(context, declRef.GetType(), declRef.GetName()); + + EmitSemantics(context, declRef.GetDecl()); + + // TODO(tfoley): technically have to apply substitution here too... + if (auto initExpr = declRef.GetDecl()->Expr) + { + Emit(context, " = "); + EmitExpr(context, initExpr); + } +} + +// Shared emit logic for variable declarations (used for parameters, locals, globals, fields) +static void EmitVarDeclCommon(EmitContext* context, RefPtr<VarDeclBase> decl) +{ + EmitVarDeclCommon(context, DeclRef(decl.Ptr(), nullptr).As<VarDeclBaseRef>()); +} + +// Emit a single `regsiter` semantic, as appropriate for a given resource-type-specific layout info +static void emitHLSLRegisterSemantic( + EmitContext* context, + VarLayout::ResourceInfo const& info) +{ + if( info.kind == LayoutResourceKind::Uniform ) + { + size_t offset = info.index; + + // The HLSL `c` register space is logically grouped in 16-byte registers, + // while we try to traffic in byte offsets. That means we need to pick + // a register number, based on the starting offset in 16-byte register + // units, and then a "component" within that register, based on 4-byte + // offsets from there. We cannot support more fine-grained offsets than that. + + Emit(context, ": packoffset(c"); + + // Size of a logical `c` register in bytes + auto registerSize = 16; + + // Size of each component of a logical `c` register, in bytes + auto componentSize = 4; + + size_t startRegister = offset / registerSize; + Emit(context, int(startRegister)); + + size_t byteOffsetInRegister = offset % registerSize; + + // If this field doesn't start on an even register boundary, + // then we need to emit additional information to pick the + // right component to start from + if (byteOffsetInRegister != 0) + { + // The value had better occupy a whole number of components. + assert(byteOffsetInRegister % componentSize == 0); + + size_t startComponent = byteOffsetInRegister / componentSize; + + static const char* kComponentNames[] = {"x", "y", "z", "w"}; + Emit(context, "."); + Emit(context, kComponentNames[startComponent]); + } + Emit(context, ")"); + } + else + { + Emit(context, ": register("); + switch( info.kind ) + { + case LayoutResourceKind::ConstantBuffer: + Emit(context, "b"); + break; + case LayoutResourceKind::ShaderResource: + Emit(context, "t"); + break; + case LayoutResourceKind::UnorderedAccess: + Emit(context, "u"); + break; + case LayoutResourceKind::SamplerState: + Emit(context, "s"); + break; + default: + assert(!"unexpected"); + break; + } + Emit(context, info.index); + if(info.space) + { + Emit(context, ", space"); + Emit(context, info.space); + } + Emit(context, ")"); + } +} + +// Emit all the `register` semantics that are appropriate for a particular variable layout +static void emitHLSLRegisterSemantics( + EmitContext* context, + RefPtr<VarLayout> layout) +{ + if (!layout) return; + + switch( context->target ) + { + default: + return; + + case CodeGenTarget::HLSL: + break; + } + + for( auto rr : layout->resourceInfos ) + { + emitHLSLRegisterSemantic(context, rr); + } +} + +static void emitHLSLParameterBlockDecl( + EmitContext* context, + RefPtr<VarDeclBase> varDecl, + RefPtr<ParameterBlockType> parameterBlockType, + RefPtr<VarLayout> layout) +{ + // The data type that describes where stuff in the constant buffer should go + RefPtr<ExpressionType> dataType = parameterBlockType->elementType; + + // We expect/require the data type to be a user-defined `struct` type + auto declRefType = dataType->As<DeclRefType>(); + assert(declRefType); + + // We expect to always have layout information + assert(layout); + + // We expect the layout to be for a structured type... + RefPtr<ParameterBlockTypeLayout> bufferLayout = layout->typeLayout.As<ParameterBlockTypeLayout>(); + assert(bufferLayout); + + RefPtr<StructTypeLayout> structTypeLayout = bufferLayout->elementTypeLayout.As<StructTypeLayout>(); + assert(structTypeLayout); + + if( auto constantBufferType = parameterBlockType->As<ConstantBufferType>() ) + { + Emit(context, "cbuffer "); + } + else if( auto textureBufferType = parameterBlockType->As<TextureBufferType>() ) + { + Emit(context, "tbuffer "); + } + + if( auto reflectionNameModifier = varDecl->FindModifier<ParameterBlockReflectionName>() ) + { + Emit(context, " "); + emitName(context, reflectionNameModifier->nameToken.Content); + } + + EmitSemantics(context, varDecl, kESemanticMask_None); + + auto info = layout->FindResourceInfo(LayoutResourceKind::ConstantBuffer); + assert(info); + emitHLSLRegisterSemantic(context, *info); + + Emit(context, "\n{\n"); + if (auto structRef = declRefType->declRef.As<StructDeclRef>()) + { + for (auto field : structRef.GetMembersOfType<FieldDeclRef>()) + { + EmitVarDeclCommon(context, field); + + RefPtr<VarLayout> fieldLayout; + structTypeLayout->mapVarToLayout.TryGetValue(field.GetDecl(), fieldLayout); + assert(fieldLayout); + + // Emit explicit layout annotations for every field + for( auto rr : fieldLayout->resourceInfos ) + { + auto kind = rr.kind; + + auto offsetResource = rr; + + if(kind != LayoutResourceKind::Uniform) + { + // Add the base index from the cbuffer into the index of the field + // + // TODO(tfoley): consider maybe not doing this, since it actually + // complicates logic around constant buffers... + + // If the member of the cbuffer uses a resource, it had better + // appear as part of the cubffer layout as well. + auto cbufferResource = layout->FindResourceInfo(kind); + assert(cbufferResource); + + offsetResource.index += cbufferResource->index; + offsetResource.space += cbufferResource->space; + } + + emitHLSLRegisterSemantic(context, offsetResource); + } + + Emit(context, ";\n"); + } + } + Emit(context, "}\n"); +} + +static void +emitGLSLLayoutQualifier( + EmitContext* context, + VarLayout::ResourceInfo const& info) +{ + switch(info.kind) + { + case LayoutResourceKind::Uniform: + Emit(context, "layout(offset = "); + Emit(context, info.index); + Emit(context, ")\n"); + break; + + case LayoutResourceKind::VertexInput: + case LayoutResourceKind::FragmentOutput: + Emit(context, "layout(location = "); + Emit(context, info.index); + Emit(context, ")\n"); + break; + + case LayoutResourceKind::SpecializationConstant: + Emit(context, "layout(constant_id = "); + Emit(context, info.index); + Emit(context, ")\n"); + break; + + case LayoutResourceKind::ConstantBuffer: + case LayoutResourceKind::ShaderResource: + case LayoutResourceKind::UnorderedAccess: + case LayoutResourceKind::SamplerState: + case LayoutResourceKind::DescriptorTableSlot: + Emit(context, "layout(binding = "); + Emit(context, info.index); + if(info.space) + { + Emit(context, ", set = "); + Emit(context, info.space); + } + Emit(context, ")\n"); + break; + } +} + +static void +emitGLSLLayoutQualifiers( + EmitContext* context, + RefPtr<VarLayout> layout) +{ + if(!layout) return; + + switch( context->target ) + { + default: + return; + + case CodeGenTarget::GLSL: + break; + } + + for( auto info : layout->resourceInfos ) + { + emitGLSLLayoutQualifier(context, info); + } +} + +static void emitGLSLParameterBlockDecl( + EmitContext* context, + RefPtr<VarDeclBase> varDecl, + RefPtr<ParameterBlockType> parameterBlockType, + RefPtr<VarLayout> layout) +{ + // The data type that describes where stuff in the constant buffer should go + RefPtr<ExpressionType> dataType = parameterBlockType->elementType; + + // We expect/require the data type to be a user-defined `struct` type + auto declRefType = dataType->As<DeclRefType>(); + assert(declRefType); + + // We expect to always have layout information + assert(layout); + + // We expect the layout to be for a structured type... + RefPtr<ParameterBlockTypeLayout> bufferLayout = layout->typeLayout.As<ParameterBlockTypeLayout>(); + assert(bufferLayout); + + RefPtr<StructTypeLayout> structTypeLayout = bufferLayout->elementTypeLayout.As<StructTypeLayout>(); + assert(structTypeLayout); + + emitGLSLLayoutQualifiers(context, layout); + + EmitModifiers(context, varDecl); + + // Emit an apprpriate declaration keyword based on the kind of block + if (parameterBlockType->As<ConstantBufferType>()) + { + Emit(context, "uniform"); + } + else if (parameterBlockType->As<GLSLInputParameterBlockType>()) + { + Emit(context, "in"); + } + else if (parameterBlockType->As<GLSLOutputParameterBlockType>()) + { + Emit(context, "out"); + } + else if (parameterBlockType->As<GLSLShaderStorageBufferType>()) + { + Emit(context, "buffer"); + } + else + { + assert(!"unexpected"); + Emit(context, "uniform"); + } + + if( auto reflectionNameModifier = varDecl->FindModifier<ParameterBlockReflectionName>() ) + { + Emit(context, " "); + emitName(context, reflectionNameModifier->nameToken.Content); + } + + Emit(context, "\n{\n"); + if (auto structRef = declRefType->declRef.As<StructDeclRef>()) + { + for (auto field : structRef.GetMembersOfType<FieldDeclRef>()) + { + RefPtr<VarLayout> fieldLayout; + structTypeLayout->mapVarToLayout.TryGetValue(field.GetDecl(), fieldLayout); + assert(fieldLayout); + + // TODO(tfoley): We may want to emit *some* of these, + // some of the time... +// emitGLSLLayoutQualifiers(context, fieldLayout); + + EmitVarDeclCommon(context, field); + + Emit(context, ";\n"); + } + } + Emit(context, "}"); + + if( varDecl->Name.Type != TokenType::Unknown ) + { + Emit(context, " "); + emitName(context, varDecl->Name.Content); + } + + Emit(context, ";\n"); +} + +static void emitParameterBlockDecl( + EmitContext* context, + RefPtr<VarDeclBase> varDecl, + RefPtr<ParameterBlockType> parameterBlockType, + RefPtr<VarLayout> layout) +{ + switch(context->target) + { + case CodeGenTarget::HLSL: + emitHLSLParameterBlockDecl(context, varDecl, parameterBlockType, layout); + break; + + case CodeGenTarget::GLSL: + emitGLSLParameterBlockDecl(context, varDecl, parameterBlockType, layout); + break; + + default: + assert(!"unexpected"); + break; + } +} + +static void EmitVarDecl(EmitContext* context, RefPtr<VarDeclBase> decl, RefPtr<VarLayout> layout) +{ + // As a special case, a variable using a parameter block type + // will be translated into a declaration using the more primitive + // language syntax. + // + // TODO(tfoley): Be sure to unwrap arrays here, in the GLSL case. + // + // TODO(tfoley): Detect cases where we need to fall back to + // ordinary variable declaration syntax in HLSL. + // + // TODO(tfoley): there might be a better way to detect this, e.g., + // with an attribute that gets attached to the variable declaration. + if (auto parameterBlockType = decl->Type->As<ParameterBlockType>()) + { + emitParameterBlockDecl(context, decl, parameterBlockType, layout); + return; + } + + emitGLSLLayoutQualifiers(context, layout); + + EmitVarDeclCommon(context, decl); + + emitHLSLRegisterSemantics(context, layout); + + Emit(context, ";\n"); +} + +static void EmitParamDecl(EmitContext* context, RefPtr<ParameterSyntaxNode> decl) +{ + EmitVarDeclCommon(context, decl); +} + +static void EmitFuncDecl(EmitContext* context, RefPtr<FunctionSyntaxNode> decl) +{ + EmitModifiers(context, decl); + + // TODO: if a function returns an array type, or something similar that + // isn't allowed by declarator syntax and/or language rules, we could + // hypothetically wrap things in a `typedef` and work around it. + + EmitType(context, decl->ReturnType, decl->Name.Content); + + Emit(context, "("); + bool first = true; + for (auto paramDecl : decl->GetMembersOfType<ParameterSyntaxNode>()) + { + if (!first) Emit(context, ", "); + EmitParamDecl(context, paramDecl); + first = false; + } + Emit(context, ")"); + + EmitSemantics(context, decl); + + if (auto bodyStmt = decl->Body) + { + EmitBlockStmt(context, bodyStmt); + } + else + { + Emit(context, ";\n"); + } +} + +static void emitGLSLPreprocessorDirectives( + EmitContext* context, + RefPtr<ProgramSyntaxNode> program) +{ + switch(context->target) + { + // Don't emit this stuff unless we are targetting GLSL + default: + return; + + case CodeGenTarget::GLSL: + break; + } + + if( auto versionDirective = program->FindModifier<GLSLVersionDirective>() ) + { + // TODO(tfoley): Emit an appropriate `#line` directive... + + Emit(context, "#version "); + emit(context, versionDirective->versionNumberToken.Content); + if(versionDirective->glslProfileToken.Type != TokenType::Unknown) + { + Emit(context, " "); + emit(context, versionDirective->glslProfileToken.Content); + } + Emit(context, "\n"); + } + else + { + // No explicit version was given (probably because we are cross-compiling). + // + // We need to pick an appropriate version, ideally based on the features + // that the shader ends up using. + // + // For now we just fall back to a reasonably recent version. + + Emit(context, "#version 420\n"); + } + + // TODO: when cross-compiling we may need to output additional `#extension` directives + // based on the features that we have used. + + for( auto extensionDirective : program->GetModifiersOfType<GLSLExtensionDirective>() ) + { + // TODO(tfoley): Emit an appropriate `#line` directive... + + Emit(context, "#extension "); + emit(context, extensionDirective->extensionNameToken.Content); + Emit(context, " : "); + emit(context, extensionDirective->dispositionToken.Content); + Emit(context, "\n"); + } + + // TODO: handle other cases... +} + +static void EmitProgram( + EmitContext* context, + RefPtr<ProgramSyntaxNode> program, + RefPtr<ProgramLayout> programLayout) +{ + // There may be global-scope modifiers that we should emit now + emitGLSLPreprocessorDirectives(context, program); + + switch(context->target) + { + case CodeGenTarget::GLSL: + { + // TODO(tfoley): Need a plan for how to enable/disable these as needed... +// Emit(context, "#extension GL_GOOGLE_cpp_style_line_directive : require\n"); + } + break; + + default: + break; + } + + + // Layout information for the global scope is either an ordinary + // `struct` in the common case, or a constant buffer in the case + // where there were global-scope uniforms. + auto globalScopeLayout = programLayout->globalScopeLayout; + if( auto globalStructLayout = globalScopeLayout.As<StructTypeLayout>() ) + { + // The `struct` case is easy enough to handle: we just + // emit all the declarations directly, using their layout + // information as a guideline. + EmitDeclsInContainerUsingLayout(context, program, globalStructLayout); + } + else if(auto globalConstantBufferLayout = globalScopeLayout.As<ParameterBlockTypeLayout>()) + { + // TODO: the `cbuffer` case really needs to be emitted very + // carefully, but that is beyond the scope of what a simple rewriter + // can easily do (without semantic analysis, etc.). + // + // The crux of the problem is that we need to collect all the + // global-scope uniforms (but not declarations that don't involve + // uniform storage...) and put them in a single `cbuffer` declaration, + // so that we can give it an explicit location. The fields in that + // declaration might use various type declarations, so we'd really + // need to emit all the type declarations first, and that involves + // some large scale reorderings. + // + // For now we will punt and just emit the declarations normally, + // and hope that the global-scope block (`$Globals`) gets auto-assigned + // the same location that we manually asigned it. + + auto elementTypeLayout = globalConstantBufferLayout->elementTypeLayout; + auto elementTypeStructLayout = elementTypeLayout.As<StructTypeLayout>(); + + // We expect all constant buffers to contain `struct` types for now + assert(elementTypeStructLayout); + + EmitDeclsInContainerUsingLayout( + context, + program, + elementTypeStructLayout); + } + else + { + assert(!"unexpected"); + } +} + +static void EmitDeclImpl(EmitContext* context, RefPtr<Decl> decl, RefPtr<VarLayout> layout) +{ + // Don't emit code for declarations that came from the stdlib. + // + // TODO(tfoley): We probably need to relax this eventually, + // since different targets might have different sets of builtins. + if (decl->HasModifier<FromStdLibModifier>()) + return; + + if (auto typeDefDecl = decl.As<TypeDefDecl>()) + { + EmitTypeDefDecl(context, typeDefDecl); + return; + } + else if (auto structDecl = decl.As<StructSyntaxNode>()) + { + EmitStructDecl(context, structDecl); + return; + } + else if (auto varDecl = decl.As<VarDeclBase>()) + { + EmitVarDecl(context, varDecl, layout); + return; + } + else if (auto funcDecl = decl.As<FunctionSyntaxNode>()) + { + EmitFuncDecl(context, funcDecl); + return; + } + else if (auto genericDecl = decl.As<GenericDecl>()) + { + // Don't emit generic decls directly; we will only + // ever emit particular instantiations of them. + return; + } + else if (auto classDecl = decl.As<ClassSyntaxNode>()) + { + return; + } + else if( auto emptyDecl = decl.As<EmptyDecl>() ) + { + EmitModifiers(context, emptyDecl); + Emit(context, ";\n"); + return; + } + throw "unimplemented"; +} + +static void EmitDecl(EmitContext* context, RefPtr<Decl> decl) +{ + EmitDeclImpl(context, decl, nullptr); +} + +static void EmitDeclUsingLayout(EmitContext* context, RefPtr<Decl> decl, RefPtr<VarLayout> layout) +{ + EmitDeclImpl(context, decl, layout); +} + +static void EmitDecl(EmitContext* context, RefPtr<DeclBase> declBase) +{ + if( auto decl = declBase.As<Decl>() ) + { + EmitDecl(context, decl); + } + else if(auto declGroup = declBase.As<DeclGroup>()) + { + for(auto d : declGroup->decls) + EmitDecl(context, d); + } + else + { + throw "unimplemented"; + } +} + +static void registerReservedWord( + EmitContext* context, + String const& name) +{ + context->reservedWords.Add(name, name); +} + +static void registerReservedWords( + EmitContext* context) +{ +#define WORD(NAME) registerReservedWord(context, #NAME) + + switch (context->target) + { + case CodeGenTarget::GLSL: + WORD(attribute); + WORD(const); + WORD(uniform); + WORD(varying); + WORD(buffer); + + WORD(shared); + WORD(coherent); + WORD(volatile); + WORD(restrict); + WORD(readonly); + WORD(writeonly); + WORD(atomic_unit); + WORD(layout); + WORD(centroid); + WORD(flat); + WORD(smooth); + WORD(noperspective); + WORD(patch); + WORD(sample); + WORD(break); + WORD(continue); + WORD(do); + WORD(for); + WORD(while); + WORD(switch); + WORD(case); + WORD(default); + WORD(if); + WORD(else); + WORD(subroutine); + WORD(in); + WORD(out); + WORD(inout); + WORD(float); + WORD(double); + WORD(int); + WORD(void); + WORD(bool); + WORD(true); + WORD(false); + WORD(invariant); + WORD(precise); + WORD(discard); + WORD(return); + + WORD(lowp); + WORD(mediump); + WORD(highp); + WORD(precision); + WORD(struct); + WORD(uint); + + WORD(common); + WORD(partition); + WORD(active); + WORD(asm); + WORD(class); + WORD(union); + WORD(enum); + WORD(typedef); + WORD(template); + WORD(this); + WORD(resource); + + WORD(goto); + WORD(inline); + WORD(noinline); + WORD(public); + WORD(static); + WORD(extern); + WORD(external); + WORD(interface); + WORD(long); + WORD(short); + WORD(half); + WORD(fixed); + WORD(unsigned); + WORD(superp); + WORD(input); + WORD(output); + WORD(filter); + WORD(sizeof); + WORD(cast); + WORD(namespace); + WORD(using); + +#define CASE(NAME) \ + WORD(NAME ## 2); WORD(NAME ## 3); WORD(NAME ## 4) + + CASE(mat); + CASE(dmat); + CASE(mat2x); + CASE(mat3x); + CASE(mat4x); + CASE(dmat2x); + CASE(dmat3x); + CASE(dmat4x); + CASE(vec); + CASE(ivec); + CASE(bvec); + CASE(dvec); + CASE(uvec); + CASE(hvec); + CASE(fvec); + +#undef CASE + +#define CASE(NAME) \ + WORD(NAME ## 1D); \ + WORD(NAME ## 2D); \ + WORD(NAME ## 3D); \ + WORD(NAME ## Cube); \ + WORD(NAME ## 1DArray); \ + WORD(NAME ## 2DArray); \ + WORD(NAME ## 3DArray); \ + WORD(NAME ## CubeArray);\ + WORD(NAME ## 2DMS); \ + WORD(NAME ## 2DMSArray) \ + /* end */ + +#define CASE2(NAME) \ + CASE(NAME); \ + CASE(i ## NAME); \ + CASE(u ## NAME) \ + /* end */ + + CASE2(sampler); + CASE2(image); + CASE2(texture); + +#undef CASE2 +#undef CASE + break; + + default: + break; + } +} + +String emitProgram( + ProgramSyntaxNode* program, + ProgramLayout* programLayout, + CodeGenTarget target) +{ + // TODO(tfoley): only emit symbols on-demand, as needed by a particular entry point + + EmitContext context; + context.target = target; + + registerReservedWords(&context); + + EmitProgram(&context, program, programLayout); + + String code = context.sb.ProduceString(); + + return code; + +#if 0 + // HACK(tfoley): Invoke the D3D HLSL compiler on the result, to validate it + +#ifdef _WIN32 + { + HMODULE d3dCompiler = LoadLibraryA("d3dcompiler_47"); + assert(d3dCompiler); + + pD3DCompile D3DCompile_ = (pD3DCompile)GetProcAddress(d3dCompiler, "D3DCompile"); + assert(D3DCompile_); + + ID3DBlob* codeBlob; + ID3DBlob* diagnosticsBlob; + HRESULT hr = D3DCompile_( + code.begin(), + code.Length(), + "slang", + nullptr, + nullptr, + "main", + "ps_5_0", + 0, + 0, + &codeBlob, + &diagnosticsBlob); + if (codeBlob) codeBlob->Release(); + if (diagnosticsBlob) + { + String diagnostics = (char const*) diagnosticsBlob->GetBufferPointer(); + fprintf(stderr, "%s", diagnostics.begin()); + OutputDebugStringA(diagnostics.begin()); + diagnosticsBlob->Release(); + } + if (FAILED(hr)) + { + int f = 9; + } + } + + #include <d3dcompiler.h> +#endif +#endif + +} + + +}} // Slang::Compiler diff --git a/source/slang/emit.h b/source/slang/emit.h new file mode 100644 index 000000000..05ea1550f --- /dev/null +++ b/source/slang/emit.h @@ -0,0 +1,24 @@ +// Emit.h +#ifndef SLANG_EMIT_H_INCLUDED +#define SLANG_EMIT_H_INCLUDED + +#include "../core/basic.h" + +#include "compiler.h" + +namespace Slang +{ + namespace Compiler + { + using namespace CoreLib::Basic; + + class ProgramSyntaxNode; + class ProgramLayout; + + String emitProgram( + ProgramSyntaxNode* program, + ProgramLayout* programLayout, + CodeGenTarget target); + } +} +#endif diff --git a/source/slang/intrinsic-defs.h b/source/slang/intrinsic-defs.h new file mode 100644 index 000000000..19a3899a3 --- /dev/null +++ b/source/slang/intrinsic-defs.h @@ -0,0 +1,94 @@ +// intrinsic-defs.h + +// The file is meant to be included multiple times, to produce different +// pieces of code related to intrinsic operations +// +// Each intrinsic op is declared here with: +// +// INTRINSIC(name) +// + +#ifndef INTRINSIC +#error Need to define INTRINSIC(NAME) before including "intrinsic-defs.h" +#endif + +INTRINSIC(Add) +INTRINSIC(Sub) +INTRINSIC(Mul) +INTRINSIC(Div) +INTRINSIC(Mod) + +INTRINSIC(Lsh) +INTRINSIC(Rsh) + +INTRINSIC(Eql) +INTRINSIC(Neq) +INTRINSIC(Greater) +INTRINSIC(Less) +INTRINSIC(Geq) +INTRINSIC(Leq) +INTRINSIC(BitAnd) +INTRINSIC(BitXor) +INTRINSIC(BitOr) + +// TODO(tfoley): need to distinguish short-circuiting and not... +INTRINSIC(And) +INTRINSIC(Or) + +INTRINSIC(Assign) +INTRINSIC(AddAssign) +INTRINSIC(SubAssign) +INTRINSIC(MulAssign) +INTRINSIC(DivAssign) +INTRINSIC(ModAssign) +INTRINSIC(LshAssign) +INTRINSIC(RshAssign) +INTRINSIC(OrAssign) +INTRINSIC(AndAssign) +INTRINSIC(XorAssign) +INTRINSIC(Neg) +INTRINSIC(Not) +INTRINSIC(BitNot) +INTRINSIC(PreInc) +INTRINSIC(PreDec) +INTRINSIC(PostInc) +INTRINSIC(PostDec) + +INTRINSIC(Sequence) +INTRINSIC(Select) + +INTRINSIC(Mul_Scalar_Scalar) +INTRINSIC(Mul_Vector_Scalar) +INTRINSIC(Mul_Scalar_Vector) +INTRINSIC(Mul_Matrix_Scalar) +INTRINSIC(Mul_Scalar_Matrix) +INTRINSIC(InnerProduct_Vector_Vector) +INTRINSIC(InnerProduct_Vector_Matrix) +INTRINSIC(InnerProduct_Matrix_Vector) +INTRINSIC(InnerProduct_Matrix_Matrix) + + + + + + + + + + + + + + + + + + + + + + + + +// Un-deefine the macor here, so that the client does not have to. +#undef INTRINSIC diff --git a/source/slang/lexer.cpp b/source/slang/lexer.cpp new file mode 100644 index 000000000..7234c4983 --- /dev/null +++ b/source/slang/lexer.cpp @@ -0,0 +1,1012 @@ +#include "Lexer.h" + +#include <assert.h> + +namespace Slang +{ + namespace Compiler + { + static Token GetEndOfFileToken() + { + return Token(TokenType::EndOfFile, "", 0, 0, 0, ""); + } + + Token* TokenList::begin() const + { + assert(mTokens.Count()); + return &mTokens[0]; + } + + Token* TokenList::end() const + { + assert(mTokens.Count()); + assert(mTokens[mTokens.Count()-1].Type == TokenType::EndOfFile); + return &mTokens[mTokens.Count() - 1]; + } + + TokenSpan::TokenSpan() + : mBegin(NULL) + , mEnd (NULL) + {} + + TokenReader::TokenReader() + : mCursor(NULL) + , mEnd (NULL) + {} + + + Token TokenReader::PeekToken() const + { + if (!mCursor) + return GetEndOfFileToken(); + + Token token = *mCursor; + if (mCursor == mEnd) + token.Type = TokenType::EndOfFile; + return token; + } + + TokenType TokenReader::PeekTokenType() const + { + if (mCursor == mEnd) + return TokenType::EndOfFile; + assert(mCursor); + return mCursor->Type; + } + + CodePosition TokenReader::PeekLoc() const + { + if (!mCursor) + return CodePosition(); + assert(mCursor); + return mCursor->Position; + } + + Token TokenReader::AdvanceToken() + { + if (!mCursor) + return GetEndOfFileToken(); + + Token token = *mCursor; + if (mCursor == mEnd) + token.Type = TokenType::EndOfFile; + else + mCursor++; + return token; + } + + // Lexer + + Lexer::Lexer( + String const& path, + String const& content, + DiagnosticSink* sink) + : path(path) + , content(content) + , sink(sink) + { + cursor = content.begin(); + end = content.end(); + + loc = CodePosition(1, 1, 0, path); + tokenFlags = TokenFlag::AtStartOfLine | TokenFlag::AfterWhitespace; + lexerFlags = 0; + } + + Lexer::~Lexer() + { + } + + enum { kEOF = -1 }; + + static int peek(Lexer* lexer) + { + if(lexer->cursor == lexer->end) + return kEOF; + + return *lexer->cursor; + } + + static int advance(Lexer* lexer) + { + if(lexer->cursor == lexer->end) + return kEOF; + + lexer->loc.Col++; + lexer->loc.Pos++; + + return *lexer->cursor++; + } + + static void handleNewLine(Lexer* lexer) + { + int c = advance(lexer); + assert(c == '\n' || c == '\r'); + + int d = peek(lexer); + if( (c ^ d) == ('\n' ^ '\r') ) + { + advance(lexer); + } + + lexer->loc.Line++; + lexer->loc.Col = 1; + } + + static void lexLineComment(Lexer* lexer) + { + for(;;) + { + switch(peek(lexer)) + { + case '\n': case '\r': case kEOF: + return; + + default: + advance(lexer); + continue; + } + } + } + + static void lexBlockComment(Lexer* lexer) + { + for(;;) + { + switch(peek(lexer)) + { + case kEOF: + // TODO(tfoley) diagnostic! + return; + + case '\n': case '\r': + handleNewLine(lexer); + continue; + + case '*': + advance(lexer); + switch( peek(lexer) ) + { + case '/': + advance(lexer); + return; + + default: + continue; + } + + default: + advance(lexer); + continue; + } + } + } + + static void lexHorizontalSpace(Lexer* lexer) + { + for(;;) + { + switch(peek(lexer)) + { + case ' ': case '\t': + advance(lexer); + continue; + + default: + return; + } + } + } + + static void lexIdentifier(Lexer* lexer) + { + for(;;) + { + int c = peek(lexer); + if(('a' <= c ) && (c <= 'z') + || ('A' <= c) && (c <= 'Z') + || ('0' <= c) && (c <= '9') + || (c == '_')) + { + advance(lexer); + continue; + } + + return; + } + } + + static void lexDigits(Lexer* lexer, int base) + { + for(;;) + { + int c = peek(lexer); + + int digitVal = 0; + switch(c) + { + case '0': case '1': case '2': case '3': case '4': + case '5': case '6': case '7': case '8': case '9': + digitVal = c - '0'; + break; + + case 'a': case 'b': case 'c': case 'd': case 'e': case 'f': + if(base <= 10) return; + digitVal = 10 + c - 'a'; + break; + + case 'A': case 'B': case 'C': case 'D': case 'E': case 'F': + if(base <= 10) return; + digitVal = 10 + c - 'A'; + break; + + default: + // Not more digits! + return; + } + + if(digitVal >= base) + { + char buffer[] = { (char) c, 0 }; + lexer->sink->diagnose(lexer->loc, Diagnostics::invalidDigitForBase, buffer, base); + } + + advance(lexer); + } + } + + static TokenType maybeLexNumberSuffix(Lexer* lexer, TokenType tokenType) + { + // First check for suffixes that + // indicate a floating-point number + switch(peek(lexer)) + { + case 'f': case 'F': + advance(lexer); + return TokenType::DoubleLiterial; + + default: + break; + } + + // Once we've ruled out floating-point + // suffixes, we can check for the inter cases + + // TODO: allow integer suffixes in any order... + + // Leading `u` or `U` for unsigned + switch(peek(lexer)) + { + default: + break; + + case 'u': case 'U': + advance(lexer); + break; + } + + // Optional `l`, `L`, `ll`, or `LL` + switch(peek(lexer)) + { + default: + break; + + case 'l': case 'L': + advance(lexer); + switch(peek(lexer)) + { + default: + break; + + case 'l': case 'L': + advance(lexer); + break; + } + break; + } + + return tokenType; + } + + static bool maybeLexNumberExponent(Lexer* lexer, int base) + { + switch( peek(lexer) ) + { + default: + return false; + + case 'e': case 'E': + if(base != 10) return false; + advance(lexer); + break; + + case 'p': case 'P': + if(base != 16) return false; + advance(lexer); + break; + } + + // we saw an exponent marker, so we must + switch( peek(lexer) ) + { + case '+': case '-': + advance(lexer); + break; + } + + // TODO(tfoley): it would be an error to not see digits here... + + lexDigits(lexer, 10); + + return true; + } + + static TokenType lexNumberAfterDecimalPoint(Lexer* lexer, int base) + { + lexDigits(lexer, base); + maybeLexNumberExponent(lexer, base); + + return maybeLexNumberSuffix(lexer, TokenType::DoubleLiterial); + } + + static TokenType lexNumber(Lexer* lexer, int base) + { + // TODO(tfoley): Need to consider whehter to allow any kind of digit separator character. + + TokenType tokenType = TokenType::IntLiterial; + + // At the start of things, we just concern ourselves with digits + lexDigits(lexer, base); + + if( peek(lexer) == '.' ) + { + tokenType = TokenType::DoubleLiterial; + + advance(lexer); + lexDigits(lexer, base); + } + + if( maybeLexNumberExponent(lexer, base)) + { + tokenType = TokenType::DoubleLiterial; + } + + maybeLexNumberSuffix(lexer, tokenType); + return tokenType; + } + + static void lexStringLiteralBody(Lexer* lexer, char quote) + { + for(;;) + { + int c = peek(lexer); + if(c == quote) + { + advance(lexer); + return; + } + + switch(c) + { + case kEOF: + lexer->sink->diagnose(lexer->loc, Diagnostics::endOfFileInLiteral); + return; + + case '\n': case '\r': + lexer->sink->diagnose(lexer->loc, Diagnostics::newlineInLiteral); + return; + + case '\\': + // Need to handle various escape sequence cases + advance(lexer); + switch(peek(lexer)) + { + case '\'': + case '\"': + case '\\': + case '?': + case 'a': + case 'b': + case 'f': + case 'n': + case 'r': + case 't': + case 'v': + advance(lexer); + break; + + case '0': case '1': case '2': case '3': case '4': + case '5': case '6': case '7': + // octal escape: up to 3 characters + advance(lexer); + for(int ii = 0; ii < 3; ++ii) + { + int d = peek(lexer); + if(('0' <= d) && (d <= '7')) + { + advance(lexer); + continue; + } + else + { + break; + } + } + break; + + case 'x': + // hexadecimal escape: any number of characters + advance(lexer); + for(;;) + { + int d = peek(lexer); + if(('0' <= d) && (d <= '9') + || ('a' <= d) && (d <= 'f') + || ('A' <= d) && (d <= 'F')) + { + advance(lexer); + continue; + } + else + { + break; + } + } + break; + + // TODO: Unicode escape sequences + + } + break; + + default: + advance(lexer); + continue; + } + } + } + + String getStringLiteralTokenValue(Token const& token) + { + assert(token.Type == TokenType::StringLiterial + || token.Type == TokenType::CharLiterial); + + char const* cursor = token.Content.begin(); + char const* end = token.Content.end(); + + auto quote = *cursor++; + assert(quote == '\'' || quote == '"'); + + StringBuilder valueBuilder; + for(;;) + { + assert(cursor != end); + + auto c = *cursor++; + + // If we see a closing quote, then we are at the end of the string literal + if(c == quote) + { + assert(cursor == end); + return valueBuilder.ProduceString(); + } + + // Charcters that don't being escape sequences are easy; + // just append them to the buffer and move on. + if(c != '\\') + { + valueBuilder.Append(c); + continue; + } + + // Now we look at another character to figure out the kind of + // escape sequence we are dealing with: + + int d = *cursor++; + + switch(d) + { + // Simple characters that just needed to be escaped + case '\'': + case '\"': + case '\\': + case '?': + valueBuilder.Append(d); + continue; + + // Traditional escape sequences for special characters + case 'a': valueBuilder.Append('\a'); continue; + case 'b': valueBuilder.Append('\b'); continue; + case 'f': valueBuilder.Append('\f'); continue; + case 'n': valueBuilder.Append('\n'); continue; + case 'r': valueBuilder.Append('\r'); continue; + case 't': valueBuilder.Append('\t'); continue; + case 'v': valueBuilder.Append('\v'); continue; + + // Octal escape: up to 3 characterws + case '0': case '1': case '2': case '3': case '4': + case '5': case '6': case '7': + { + cursor--; + int value = 0; + for(int ii = 0; ii < 3; ++ii) + { + d = *cursor; + if(('0' <= d) && (d <= '7')) + { + value = value*8 + (d - '0'); + + cursor++; + continue; + } + else + { + break; + } + } + + // TODO: add support for appending an arbitrary code point? + valueBuilder.Append((char) value); + } + continue; + + // Hexadecimal escape: any number of characters + case 'x': + { + cursor--; + int value = 0; + for(;;) + { + d = *cursor++; + int digitValue = 0; + if(('0' <= d) && (d <= '9')) + { + digitValue = d - '0'; + } + else if( ('a' <= d) && (d <= 'f') ) + { + digitValue = d - 'a'; + } + else if( ('A' <= d) && (d <= 'F') ) + { + digitValue = d - 'A'; + } + else + { + cursor--; + break; + } + + value = value*16 + digitValue; + } + + // TODO: add support for appending an arbitrary code point? + valueBuilder.Append((char) value); + } + continue; + + // TODO: Unicode escape sequences + + } + } + } + + String getFileNameTokenValue(Token const& token) + { + // A file name usually doesn't process escape sequences + // (this is import on Windows, where `\\` is a valid + // path separator cahracter). + + // Just trim off the first and last characters to remove the quotes + // (whether they were `""` or `<>`. + return token.Content.SubString(1, token.Content.Length()-2); + } + + + + static TokenType lexTokenImpl(Lexer* lexer) + { + switch(peek(lexer)) + { + default: + break; + + case kEOF: + if((lexer->lexerFlags & kLexerFlag_InDirective) != 0) + return TokenType::EndOfDirective; + return TokenType::EndOfFile; + + case '\r': case '\n': + if((lexer->lexerFlags & kLexerFlag_InDirective) != 0) + return TokenType::EndOfDirective; + handleNewLine(lexer); + return TokenType::NewLine; + + case ' ': case '\t': + lexHorizontalSpace(lexer); + return TokenType::WhiteSpace; + + case '.': + advance(lexer); + switch(peek(lexer)) + { + case '0': case '1': case '2': case '3': case '4': + case '5': case '6': case '7': case '8': case '9': + return lexNumberAfterDecimalPoint(lexer, 10); + + // TODO(tfoley): handle ellipsis (`...`) + + default: + return TokenType::Dot; + } + + case '1': case '2': case '3': case '4': + case '5': case '6': case '7': case '8': case '9': + return lexNumber(lexer, 10); + + case '0': + { + auto loc = lexer->loc; + advance(lexer); + switch(peek(lexer)) + { + default: + return TokenType::IntLiterial; + + case '.': + advance(lexer); + return lexNumberAfterDecimalPoint(lexer, 10); + + case 'x': case 'X': + advance(lexer); + return lexNumber(lexer, 16); + + case 'b': case 'B': + advance(lexer); + return lexNumber(lexer, 2); + + case '0': case '1': case '2': case '3': case '4': + case '5': case '6': case '7': case '8': case '9': + lexer->sink->diagnose(loc, Diagnostics::octalLiteral); + return lexNumber(lexer, 8); + } + } + + case 'a': case 'b': case 'c': case 'd': case 'e': + case 'f': case 'g': case 'h': case 'i': case 'j': + case 'k': case 'l': case 'm': case 'n': case 'o': + case 'p': case 'q': case 'r': case 's': case 't': + case 'u': case 'v': case 'w': case 'x': case 'y': + case 'z': + case 'A': case 'B': case 'C': case 'D': case 'E': + case 'F': case 'G': case 'H': case 'I': case 'J': + case 'K': case 'L': case 'M': case 'N': case 'O': + case 'P': case 'Q': case 'R': case 'S': case 'T': + case 'U': case 'V': case 'W': case 'X': case 'Y': + case 'Z': + case '_': + lexIdentifier(lexer); + return TokenType::Identifier; + + case '\"': + advance(lexer); + lexStringLiteralBody(lexer, '\"'); + return TokenType::StringLiterial; + + case '\'': + advance(lexer); + lexStringLiteralBody(lexer, '\''); + return TokenType::CharLiterial; + + case '+': + advance(lexer); + switch(peek(lexer)) + { + case '+': advance(lexer); return TokenType::OpInc; + case '=': advance(lexer); return TokenType::OpAddAssign; + default: + return TokenType::OpAdd; + } + + case '-': + advance(lexer); + switch(peek(lexer)) + { + case '-': advance(lexer); return TokenType::OpDec; + case '=': advance(lexer); return TokenType::OpSubAssign; + case '>': advance(lexer); return TokenType::RightArrow; + default: + return TokenType::OpSub; + } + + case '*': + advance(lexer); + switch(peek(lexer)) + { + case '=': advance(lexer); return TokenType::OpMulAssign; + default: + return TokenType::OpMul; + } + + case '/': + advance(lexer); + switch(peek(lexer)) + { + case '=': advance(lexer); return TokenType::OpDivAssign; + case '/': advance(lexer); lexLineComment(lexer); return TokenType::LineComment; + case '*': advance(lexer); lexBlockComment(lexer); return TokenType::BlockComment; + default: + return TokenType::OpDiv; + } + + case '%': + advance(lexer); + switch(peek(lexer)) + { + case '=': advance(lexer); return TokenType::OpModAssign; + default: + return TokenType::OpMod; + } + + case '|': + advance(lexer); + switch(peek(lexer)) + { + case '|': advance(lexer); return TokenType::OpOr; + case '=': advance(lexer); return TokenType::OpOrAssign; + default: + return TokenType::OpBitOr; + } + + case '&': + advance(lexer); + switch(peek(lexer)) + { + case '&': advance(lexer); return TokenType::OpAnd; + case '=': advance(lexer); return TokenType::OpAndAssign; + default: + return TokenType::OpBitAnd; + } + + case '^': + advance(lexer); + switch(peek(lexer)) + { + case '=': advance(lexer); return TokenType::OpXorAssign; + default: + return TokenType::OpBitXor; + } + + case '>': + advance(lexer); + switch(peek(lexer)) + { + case '>': + advance(lexer); + switch(peek(lexer)) + { + case '=': advance(lexer); return TokenType::OpShrAssign; + default: return TokenType::OpRsh; + } + case '=': advance(lexer); return TokenType::OpGeq; + default: + return TokenType::OpGreater; + } + + case '<': + advance(lexer); + switch(peek(lexer)) + { + case '<': + advance(lexer); + switch(peek(lexer)) + { + case '=': advance(lexer); return TokenType::OpShlAssign; + default: return TokenType::OpLsh; + } + case '=': advance(lexer); return TokenType::OpLeq; + default: + return TokenType::OpLess; + } + + case '=': + advance(lexer); + switch(peek(lexer)) + { + case '=': advance(lexer); return TokenType::OpEql; + default: + return TokenType::OpAssign; + } + + case '!': + advance(lexer); + switch(peek(lexer)) + { + case '=': advance(lexer); return TokenType::OpNeq; + default: + return TokenType::OpNot; + } + + case '#': + advance(lexer); + switch(peek(lexer)) + { + case '#': advance(lexer); return TokenType::PoundPound; + default: + return TokenType::Pound; + } + + case '~': advance(lexer); return TokenType::OpBitNot; + + case ':': advance(lexer); return TokenType::Colon; + case ';': advance(lexer); return TokenType::Semicolon; + case ',': advance(lexer); return TokenType::Comma; + + case '{': advance(lexer); return TokenType::LBrace; + case '}': advance(lexer); return TokenType::RBrace; + case '[': advance(lexer); return TokenType::LBracket; + case ']': advance(lexer); return TokenType::RBracket; + case '(': advance(lexer); return TokenType::LParent; + case ')': advance(lexer); return TokenType::RParent; + + case '?': advance(lexer); return TokenType::QuestionMark; + case '@': advance(lexer); return TokenType::At; + case '$': advance(lexer); return TokenType::Dollar; + + } + + // TODO(tfoley): If we ever wanted to support proper Unicode + // in identifiers, etc., then this would be the right place + // to perform a more expensive dispatch based on the actual + // code point (and not just the first byte). + + { + // If none of the above cases matched, then we have an + // unexpected/invalid character. + + auto loc = lexer->loc; + auto sink = lexer->sink; + int c = advance(lexer); + if(c >= 0x20 && c <= 0x7E) + { + char buffer[] = { (char) c, 0 }; + sink->diagnose(loc, Diagnostics::illegalCharacterPrint, buffer); + } + else + { + // Fallback: print as hexadecimal + sink->diagnose(loc, Diagnostics::illegalCharacterHex, String((unsigned char)c, 16)); + } + + return TokenType::Invalid; + } + } + + Token Lexer::lexToken() + { + auto flags = this->tokenFlags; + for(;;) + { + Token token; + token.Position = loc; + + char const* textBegin = cursor; + + auto tokenType = lexTokenImpl(this); + + // The low-level lexer produces tokens for things we want + // to ignore, such as white space, so we skip them here. + switch(tokenType) + { + case TokenType::Invalid: + flags = 0; + continue; + + case TokenType::NewLine: + flags = TokenFlag::AtStartOfLine | TokenFlag::AfterWhitespace; + continue; + + case TokenType::WhiteSpace: + case TokenType::LineComment: + case TokenType::BlockComment: + flags |= TokenFlag::AfterWhitespace; + continue; + + // We don't want to skip the end-of-file token, but we *do* + // want to make sure it has appropriate flags to make our life easier + case TokenType::EndOfFile: + flags = TokenFlag::AtStartOfLine | TokenFlag::AfterWhitespace; + break; + + // We will also do some book-keeping around preprocessor directives here: + // + // If we see a `#` at the start of a line, then we are entering a + // preprocessor directive. + case TokenType::Pound: + if((flags & TokenFlag::AtStartOfLine) != 0) + lexerFlags |= kLexerFlag_InDirective; + break; + // + // And if we saw an end-of-line during a directive, then we are + // now leaving that directive. + // + case TokenType::EndOfDirective: + lexerFlags &= ~kLexerFlag_InDirective; + break; + + default: + break; + } + + token.Type = tokenType; + + char const* textEnd = cursor; + + // Note(tfoley): `StringBuilder::Append()` seems to crash when appending zero bytes + if(textEnd != textBegin) + { + StringBuilder valueBuilder; + valueBuilder.Append(textBegin, int(textEnd - textBegin)); + token.Content = valueBuilder.ProduceString(); + } + + token.flags = flags; + + this->tokenFlags = 0; + + return token; + } + } + + TokenList Lexer::lexAllTokens() + { + TokenList tokenList; + for(;;) + { + Token token = lexToken(); + tokenList.mTokens.Add(token); + + if(token.Type == TokenType::EndOfFile) + return tokenList; + } + } + + + +#if 0 + TokenList Lexer::Parse(const String & fileName, const String & str, DiagnosticSink * sink) + { + TokenList tokenList; + tokenList.mTokens = TokenizeText(fileName, str, [&](TokenizeErrorType errType, CodePosition pos) + { + auto curChar = str[pos.Pos]; + switch (errType) + { + case TokenizeErrorType::InvalidCharacter: + // Check if inside the ASCII "printable" range + if(curChar >= 0x20 && curChar <= 0x7E) + { + char buffer[] = { curChar, 0 }; + sink->diagnose(pos, Diagnostics::illegalCharacterPrint, buffer); + } + else + { + // Fallback: print as hexadecimal + sink->diagnose(pos, Diagnostics::illegalCharacterHex, String((unsigned char)curChar, 16)); + } + break; + case TokenizeErrorType::InvalidEscapeSequence: + sink->diagnose(pos, Diagnostics::illegalCharacterLiteral); + break; + default: + break; + } + }); + + // Add an end-of-file token so that we can reference it in diagnostic messages + tokenList.mTokens.Add(Token(TokenType::EndOfFile, "", 0, 0, 0, fileName, TokenFlag::AtStartOfLine | TokenFlag::AfterWhitespace)); + return tokenList; + } +#endif + } +}
\ No newline at end of file diff --git a/source/slang/lexer.h b/source/slang/lexer.h new file mode 100644 index 000000000..d11e92d84 --- /dev/null +++ b/source/slang/lexer.h @@ -0,0 +1,101 @@ +#ifndef RASTER_RENDERER_LEXER_H +#define RASTER_RENDERER_LEXER_H + +#include "../core/basic.h" +#include "diagnostics.h" + +namespace Slang +{ + namespace Compiler + { + using namespace CoreLib::Basic; + + struct TokenList + { + Token* begin() const; + Token* end() const; + + List<Token> mTokens; + }; + + struct TokenSpan + { + TokenSpan(); + TokenSpan( + TokenList const& tokenList) + : mBegin(tokenList.begin()) + , mEnd (tokenList.end ()) + {} + + Token* begin() const { return mBegin; } + Token* end () const { return mEnd ; } + + int GetCount() { return (int)(mEnd - mBegin); } + + Token* mBegin; + Token* mEnd; + }; + + struct TokenReader + { + TokenReader(); + explicit TokenReader(TokenSpan const& tokens) + : mCursor(tokens.begin()) + , mEnd (tokens.end ()) + {} + explicit TokenReader(TokenList const& tokens) + : mCursor(tokens.begin()) + , mEnd (tokens.end ()) + {} + + bool IsAtEnd() const { return mCursor == mEnd; } + Token PeekToken() const; + TokenType PeekTokenType() const; + CodePosition PeekLoc() const; + + Token AdvanceToken(); + + int GetCount() { return (int)(mEnd - mCursor); } + + Token* mCursor; + Token* mEnd; + }; + + typedef unsigned int LexerFlags; + enum + { + kLexerFlag_InDirective = 1 << 0, + kLexerFlag_ExpectFileName = 2 << 0, + }; + + struct Lexer + { + Lexer( + String const& path, + String const& content, + DiagnosticSink* sink); + + ~Lexer(); + + Token lexToken(); + + TokenList lexAllTokens(); + + String path; + String content; + DiagnosticSink* sink; + + char const* cursor; + char const* end; + CodePosition loc; + TokenFlags tokenFlags; + LexerFlags lexerFlags; + }; + + // Helper routines for extracting values from tokens + String getStringLiteralTokenValue(Token const& token); + String getFileNameTokenValue(Token const& token); + } +} + +#endif
\ No newline at end of file diff --git a/source/slang/lookup.cpp b/source/slang/lookup.cpp new file mode 100644 index 000000000..9731b1c8a --- /dev/null +++ b/source/slang/lookup.cpp @@ -0,0 +1,311 @@ +// lookup.cpp +#include "lookup.h" + +namespace Slang { +namespace Compiler { + +// + +// Helper for constructing breadcrumb trails during lookup, without unnecessary heap allocaiton +struct BreadcrumbInfo +{ + LookupResultItem::Breadcrumb::Kind kind; + DeclRef declRef; + BreadcrumbInfo* prev = nullptr; +}; + +void DoLocalLookupImpl( + String const& name, + ContainerDeclRef containerDeclRef, + LookupRequest const& request, + LookupResult& result, + BreadcrumbInfo* inBreadcrumbs); + +// + +void buildMemberDictionary(ContainerDecl* decl) +{ + // Don't rebuild if already built + if (decl->memberDictionaryIsValid) + return; + + decl->memberDictionary.Clear(); + decl->transparentMembers.Clear(); + + for (auto m : decl->Members) + { + auto name = m->Name.Content; + + // Add any transparent members to a separate list for lookup + if (m->HasModifier<TransparentModifier>()) + { + TransparentMemberInfo info; + info.decl = m.Ptr(); + decl->transparentMembers.Add(info); + } + + // Ignore members with an empty name + if (name.Length() == 0) + continue; + + m->nextInContainerWithSameName = nullptr; + + Decl* next = nullptr; + if (decl->memberDictionary.TryGetValue(name, next)) + m->nextInContainerWithSameName = next; + + decl->memberDictionary[name] = m.Ptr(); + + } + decl->memberDictionaryIsValid = true; +} + + +bool DeclPassesLookupMask(Decl* decl, LookupMask mask) +{ + // type declarations + if(auto aggTypeDecl = dynamic_cast<AggTypeDecl*>(decl)) + { + return int(mask) & int(LookupMask::Type); + } + else if(auto simpleTypeDecl = dynamic_cast<SimpleTypeDecl*>(decl)) + { + return int(mask) & int(LookupMask::Type); + } + // function declarations + else if(auto funcDecl = dynamic_cast<FunctionDeclBase*>(decl)) + { + return (int(mask) & int(LookupMask::Function)) != 0; + } + + // default behavior is to assume a value declaration + // (no overloading allowed) + + return (int(mask) & int(LookupMask::Value)) != 0; +} + +void AddToLookupResult( + LookupResult& result, + LookupResultItem item) +{ + if (!result.isValid()) + { + // If we hadn't found a hit before, we have one now + result.item = item; + } + else if (!result.isOverloaded()) + { + // We are about to make this overloaded + result.items.Add(result.item); + result.items.Add(item); + } + else + { + // The result was already overloaded, so we pile on + result.items.Add(item); + } +} + +LookupResult refineLookup(LookupResult const& inResult, LookupMask mask) +{ + if (!inResult.isValid()) return inResult; + if (!inResult.isOverloaded()) return inResult; + + LookupResult result; + for (auto item : inResult.items) + { + if (!DeclPassesLookupMask(item.declRef.GetDecl(), mask)) + continue; + + AddToLookupResult(result, item); + } + return result; +} + +LookupResultItem CreateLookupResultItem( + DeclRef declRef, + BreadcrumbInfo* breadcrumbInfos) +{ + LookupResultItem item; + item.declRef = declRef; + + // breadcrumbs were constructed "backwards" on the stack, so we + // reverse them here by building a linked list the other way + RefPtr<LookupResultItem::Breadcrumb> breadcrumbs; + for (auto bb = breadcrumbInfos; bb; bb = bb->prev) + { + breadcrumbs = new LookupResultItem::Breadcrumb( + bb->kind, + bb->declRef, + breadcrumbs); + } + item.breadcrumbs = breadcrumbs; + return item; +} + +void DoMemberLookupImpl( + String const& name, + RefPtr<ExpressionType> baseType, + LookupRequest const& request, + LookupResult& ioResult, + BreadcrumbInfo* breadcrumbs) +{ + // If the type was pointer-like, then dereference it + // automatically here. + if (auto pointerLikeType = baseType->As<PointerLikeType>()) + { + // Need to leave a breadcrumb to indicate that we + // did an implicit dereference here + BreadcrumbInfo derefBreacrumb; + derefBreacrumb.kind = LookupResultItem::Breadcrumb::Kind::Deref; + derefBreacrumb.prev = breadcrumbs; + + // Recursively perform lookup on the result of deref + return DoMemberLookupImpl(name, pointerLikeType->elementType, request, ioResult, &derefBreacrumb); + } + + // Default case: no dereference needed + + if (auto baseDeclRefType = baseType->As<DeclRefType>()) + { + if (auto baseAggTypeDeclRef = baseDeclRefType->declRef.As<AggTypeDeclRef>()) + { + DoLocalLookupImpl(name, baseAggTypeDeclRef, request, ioResult, breadcrumbs); + } + } + + // TODO(tfoley): any other cases to handle here? +} + +void DoMemberLookupImpl( + String const& name, + DeclRef baseDeclRef, + LookupRequest const& request, + LookupResult& ioResult, + BreadcrumbInfo* breadcrumbs) +{ + auto baseType = getTypeForDeclRef(baseDeclRef); + return DoMemberLookupImpl(name, baseType, request, ioResult, breadcrumbs); +} + +// Look for members of the given name in the given container for declarations +void DoLocalLookupImpl( + String const& name, + ContainerDeclRef containerDeclRef, + LookupRequest const& request, + LookupResult& result, + BreadcrumbInfo* inBreadcrumbs) +{ + ContainerDecl* containerDecl = containerDeclRef.GetDecl(); + + // Ensure that the lookup dictionary in the container is up to date + if (!containerDecl->memberDictionaryIsValid) + { + buildMemberDictionary(containerDecl); + } + + // Look up the declarations with the chosen name in the container. + Decl* firstDecl = nullptr; + containerDecl->memberDictionary.TryGetValue(name, firstDecl); + + // Now iterate over those declarations (if any) and see if + // we find any that meet our filtering criteria. + // For example, we might be filtering so that we only consider + // type declarations. + for (auto m = firstDecl; m; m = m->nextInContainerWithSameName) + { + if (!DeclPassesLookupMask(m, request.mask)) + continue; + + // The declaration passed the test, so add it! + AddToLookupResult(result, CreateLookupResultItem(DeclRef(m, containerDeclRef.substitutions), inBreadcrumbs)); + } + + + // TODO(tfoley): should we look up in the transparent decls + // if we already has a hit in the current container? + + for(auto transparentInfo : containerDecl->transparentMembers) + { + // The reference to the transparent member should use whatever + // substitutions we used in referring to its outer container + DeclRef transparentMemberDeclRef(transparentInfo.decl, containerDeclRef.substitutions); + + // We need to leave a breadcrumb so that we know that the result + // of lookup involves a member lookup step here + + BreadcrumbInfo memberRefBreadcrumb; + memberRefBreadcrumb.kind = LookupResultItem::Breadcrumb::Kind::Member; + memberRefBreadcrumb.declRef = transparentMemberDeclRef; + memberRefBreadcrumb.prev = inBreadcrumbs; + + DoMemberLookupImpl(name, transparentMemberDeclRef, request, result, &memberRefBreadcrumb); + } + + // TODO(tfoley): need to consider lookup via extension here? +} + +void DoLookupImpl( + String const& name, + LookupRequest const& request, + LookupResult& result) +{ + auto scope = request.scope; + auto endScope = request.endScope; + for (;scope != endScope; scope = scope->parent) + { + // Note that we consider all "peer" scopes together, + // so that a hit in one of them does not proclude + // also finding a hit in another + for(auto link = scope; link; link = link->nextSibling) + { + if(!link->containerDecl) + continue; + + ContainerDeclRef containerRef = DeclRef(link->containerDecl, nullptr).As<ContainerDeclRef>(); + DoLocalLookupImpl(name, containerRef, request, result, nullptr); + } + + if (result.isValid()) + { + // If we've found a result in this scope, then there + // is no reason to look further up (for now). + return; + } + } + + // If we run out of scopes, then we are done. +} + +LookupResult DoLookup(String const& name, LookupRequest const& request) +{ + LookupResult result; + DoLookupImpl(name, request, result); + return result; +} + +LookupResult LookUp(String const& name, RefPtr<Scope> scope) +{ + LookupRequest request; + request.scope = scope; + return DoLookup(name, request); +} + +// perform lookup within the context of a particular container declaration, +// and do *not* look further up the chain +LookupResult LookUpLocal(String const& name, ContainerDeclRef containerDeclRef) +{ + LookupRequest request; + LookupResult result; + DoLocalLookupImpl(name, containerDeclRef, request, result, nullptr); + return result; +} + +LookupResult LookUpLocal(String const& name, ContainerDecl* containerDecl) +{ + ContainerDeclRef containerRef = DeclRef(containerDecl, nullptr).As<ContainerDeclRef>(); + return LookUpLocal(name, containerRef); +} + + +}} diff --git a/source/slang/lookup.h b/source/slang/lookup.h new file mode 100644 index 000000000..25b62738f --- /dev/null +++ b/source/slang/lookup.h @@ -0,0 +1,41 @@ +#ifndef SLANG_LOOKUP_H_INCLUDED +#define SLANG_LOOKUP_H_INCLUDED + +#include "Syntax.h" + +namespace Slang { +namespace Compiler { + +// Take an existing lookup result and refine it to only include +// results that pass the given `LookupMask`. +LookupResult refineLookup(LookupResult const& inResult, LookupMask mask); + +// Ensure that the dictionary for name-based member lookup has been +// built for the given container declaration. +void buildMemberDictionary(ContainerDecl* decl); + +// Look up a name in the given scope, proceeding up through +// parent scopes as needed. +LookupResult LookUp(String const& name, RefPtr<Scope> scope); + +// perform lookup within the context of a particular container declaration, +// and do *not* look further up the chain +LookupResult LookUpLocal(String const& name, ContainerDeclRef containerDeclRef); +LookupResult LookUpLocal(String const& name, ContainerDecl* containerDecl); + +// TODO: this belongs somewhere else + +class SemanticsVisitor; +QualType getTypeForDeclRef( + SemanticsVisitor* sema, + DiagnosticSink* sink, + DeclRef declRef, + RefPtr<ExpressionType>* outTypeResult); + +QualType getTypeForDeclRef( + DeclRef declRef); + + +}} + +#endif
\ No newline at end of file diff --git a/source/slang/parameter-binding.cpp b/source/slang/parameter-binding.cpp new file mode 100644 index 000000000..8bbb566af --- /dev/null +++ b/source/slang/parameter-binding.cpp @@ -0,0 +1,1252 @@ +// parameter-binding.cpp +#include "parameter-binding.h" + +#include "lookup.h" +#include "compiler.h" +#include "type-layout.h" + +#include "../../slang.h" + +#define SLANG_EXHAUSTIVE_SWITCH() default: assert(!"unexpected"); break; + +namespace Slang { +namespace Compiler { + +// Information on ranges of registers already claimed/used +struct UsedRange +{ + int begin; + int end; +}; +bool operator<(UsedRange left, UsedRange right) +{ + if (left.begin != right.begin) + return left.begin < right.begin; + if (left.end != right.end) + return left.end < right.end; + return false; +} + +struct UsedRanges +{ + List<UsedRange> ranges; + + // Add a range to the set, either by extending + // an existing range, or by adding a new one... + void Add(UsedRange const& range) + { + for (auto& rr : ranges) + { + if (rr.begin == range.end) + { + rr.begin = range.begin; + return; + } + else if (rr.end == range.begin) + { + rr.end = range.end; + return; + } + } + ranges.Add(range); + ranges.Sort(); + } + + void Add(int begin, int end) + { + UsedRange range; + range.begin = begin; + range.end = end; + Add(range); + } + + + // Try to find space for `count` entries + int Allocate(int count) + { + int begin = 0; + + int rangeCount = ranges.Count(); + for (int rr = 0; rr < rangeCount; ++rr) + { + // try to fit in before this range... + + int end = ranges[rr].begin; + + // If there is enough space... + if (end >= begin + count) + { + // ... then claim it and be done + Add(begin, begin + count); + return begin; + } + + // ... otherwise, we need to look at the + // space between this range and the next + begin = ranges[rr].end; + } + + // We've run out of ranges to check, so we + // can safely go after the last one! + Add(begin, begin + count); + return begin; + } +}; + +struct ParameterBindingInfo +{ + size_t space; + size_t index; + size_t count; +}; + +enum +{ + kLayoutResourceKindCount = SLANG_PARAMETER_CATEGORY_MIXED, +}; + +// Information on a single parameter +struct ParameterInfo : RefObject +{ + // Layout info for the concrete variables that will make up this parameter + List<RefPtr<VarLayout>> varLayouts; + + ParameterBindingInfo bindingInfo[kLayoutResourceKindCount]; + + // The next parameter that has the same name... + ParameterInfo* nextOfSameName; + + ParameterInfo() + { + // Make sure we aren't claiming any resources yet + for( int ii = 0; ii < kLayoutResourceKindCount; ++ii ) + { + bindingInfo[ii].count = 0; + } + } +}; + +// State that is shared during parameter binding, +// across all translation units +struct SharedParameterBindingContext +{ + LayoutRulesFamilyImpl* defaultLayoutRules; + + // All shader parameters we've discovered so far, and started to lay out... + List<RefPtr<ParameterInfo>> parameters; + + // A dictionary to accellerate looking up parameters by name + Dictionary<String, ParameterInfo*> mapNameToParameterInfo; + + // The program layout we are trying to construct + RefPtr<ProgramLayout> programLayout; + + // The source language we are trying to use + SourceLanguage sourceLanguage; + + // Information on what ranges of "registers" have already + // been claimed, for each resource type + UsedRanges usedResourceRanges[kLayoutResourceKindCount]; +}; + +// State that might be specific to a single translation unit +// or event to an entry point. +struct ParameterBindingContext +{ + // All the shared state needs to be available + SharedParameterBindingContext* shared; + + // The layout rules to use while computing usage... + LayoutRulesFamilyImpl* layoutRules; + + // What stage (if any) are we compiling for? + Stage stage; +}; + +struct LayoutSemanticInfo +{ + LayoutResourceKind kind; // the register kind + int space; + int index; + + // TODO: need to deal with component-granularity binding... +}; + +LayoutSemanticInfo ExtractLayoutSemanticInfo( + ParameterBindingContext* /*context*/, + HLSLLayoutSemantic* semantic) +{ + LayoutSemanticInfo info; + info.space = 0; + info.index = 0; + info.kind = LayoutResourceKind::None; + + auto registerName = semantic->registerName.Content; + if (registerName.Length() == 0) + return info; + + LayoutResourceKind kind = LayoutResourceKind::None; + switch (registerName[0]) + { + case 'b': + kind = LayoutResourceKind::ConstantBuffer; + break; + + case 't': + kind = LayoutResourceKind::ShaderResource; + break; + + case 'u': + kind = LayoutResourceKind::UnorderedAccess; + break; + + case 's': + kind = LayoutResourceKind::SamplerState; + break; + + default: + // TODO: issue an error here! + return info; + } + + // TODO: need to parse and handle `space` binding + int space = 0; + + int index = 0; + for (int ii = 1; ii < registerName.Length(); ++ii) + { + int c = registerName[ii]; + if (c >= '0' && c <= '9') + { + index = index * 10 + (c - '0'); + } + else + { + // TODO: issue an error here! + return info; + } + } + + // TODO: handle component mask part of things... + + info.kind = kind; + info.index = index; + info.space = space; + return info; +} + +static bool doesParameterMatch( + ParameterBindingContext* context, + RefPtr<VarLayout> varLayout, + ParameterInfo* parameterInfo) +{ + // TODO: need to implement this eventually + return true; +} + +// + +// Given a GLSL `layout` modifier, we need to be able to check for +// a particular sub-argument and extract its value if present. +template<typename T> +static bool findLayoutArg( + RefPtr<ModifiableSyntaxNode> syntax, + int* outVal) +{ + for( auto modifier : syntax->GetModifiersOfType<T>() ) + { + *outVal = (int) strtol(modifier->valToken.Content.Buffer(), nullptr, 10); + return true; + } + return false; +} + +template<typename T> +static bool findLayoutArg( + DeclRef declRef, + int* outVal) +{ + return findLayoutArg<T>(declRef.GetDecl(), outVal); +} + +// + +RefPtr<TypeLayout> +getTypeLayoutForGlobalShaderParameter_GLSL( + ParameterBindingContext* context, + VarDeclBase* varDecl) +{ + auto rules = context->layoutRules; + auto type = varDecl->getType(); + + // A GLSL shader parameter will be marked with + // a qualifier to match the boundary it uses + // + // In the case of a parameter block, we will have + // consumed this qualifier as part of parsing, + // so that it won't be present on the declaration + // any more. As such we also inspect the type + // of the variable. + + // TODO(tfoley): We have multiple variations of + // the `uniform` modifier right now, and that + // needs to get fixed... + if(varDecl->HasModifier<HLSLUniformModifier>() || type->As<ConstantBufferType>()) + return CreateTypeLayout(type, rules->getConstantBufferRules()); + + if(varDecl->HasModifier<GLSLBufferModifier>() || type->As<GLSLShaderStorageBufferType>()) + return CreateTypeLayout(type, rules->getShaderStorageBufferRules()); + + if( varDecl->HasModifier<InModifier>() || type->As<GLSLInputParameterBlockType>()) + { + // Special case to handle "arrayed" shader inputs, as used + // for Geometry and Hull input + switch( context->stage ) + { + case Stage::Geometry: + case Stage::Hull: + case Stage::Domain: + // Tessellation `patch` variables should stay as written + if( !varDecl->HasModifier<GLSLPatchModifier>() ) + { + // Unwrap array type, if prsent + if( auto arrayType = type->As<ArrayExpressionType>() ) + { + type = arrayType->BaseType.Ptr(); + } + } + break; + + default: + break; + } + + return CreateTypeLayout(type, rules->getVaryingInputRules()); + } + + if( varDecl->HasModifier<OutModifier>() || type->As<GLSLOutputParameterBlockType>()) + { + // Special case to handle "arrayed" shader outputs, as used + // for Hull Shader output + // + // Note(tfoley): there is unfortunate code duplication + // with the `in` case above. + switch( context->stage ) + { + case Stage::Hull: + // Tessellation `patch` variables should stay as written + if( !varDecl->HasModifier<GLSLPatchModifier>() ) + { + // Unwrap array type, if prsent + if( auto arrayType = type->As<ArrayExpressionType>() ) + { + type = arrayType->BaseType.Ptr(); + } + } + break; + + default: + break; + } + + return CreateTypeLayout(type, rules->getVaryingOutputRules()); + } + + // A `const` global with a `layout(constant_id = ...)` modifier + // is a declaration of a specialization constant. + if(varDecl->HasModifier<GLSLConstantIDLayoutModifier>()) + return CreateTypeLayout(type, rules->getSpecializationConstantRules()); + + // GLSL says that an "ordinary" global variable + // is just a (thread local) global and not a + // parameter + return nullptr; +} + +RefPtr<TypeLayout> +getTypeLayoutForGlobalShaderParameter_HLSL( + ParameterBindingContext* context, + VarDeclBase* varDecl) +{ + auto rules = context->layoutRules; + auto type = varDecl->getType(); + + // HLSL `static` modifier indicates "thread local" + if(varDecl->HasModifier<HLSLStaticModifier>()) + return nullptr; + + // HLSL `groupshared` modifier indicates "thread-group local" + if(varDecl->HasModifier<HLSLGroupSharedModifier>()) + return nullptr; + + // TODO(tfoley): there may be other cases that we need to handle here + + // An "ordinary" global variable is implicitly a uniform + // shader parameter. + return CreateTypeLayout(type, rules->getConstantBufferRules()); +} + +// Determine how to lay out a global variable that might be +// a shader parameter. +// Returns `nullptr` if the declaration does not represent +// a shader parameter. + +RefPtr<TypeLayout> +getTypeLayoutForGlobalShaderParameter( + ParameterBindingContext* context, + VarDeclBase* varDecl) +{ + auto rules = context->layoutRules; + switch( context->shared->sourceLanguage ) + { + case SourceLanguage::Slang: + case SourceLanguage::HLSL: + return getTypeLayoutForGlobalShaderParameter_HLSL(context, varDecl); + + case SourceLanguage::GLSL: + return getTypeLayoutForGlobalShaderParameter_GLSL(context, varDecl); + + default: + assert(false); + return nullptr; + } +} + + +// + + + +// Collect a single declaration into our set of parameters +static void collectGlobalScopeParameter( + ParameterBindingContext* context, + RefPtr<VarDeclBase> varDecl) +{ + // We use a single operation to both check whether the + // variable represents a shader parameter, and to compute + // the layout for that parameter's type. + auto typeLayout = getTypeLayoutForGlobalShaderParameter( + context, + varDecl.Ptr()); + + // If we did not find appropriate layout rules, then it + // must mean that this global variable is *not* a shader + // parameter. + if(!typeLayout) + return; + + // Now create a variable layout that we can use + RefPtr<VarLayout> varLayout = new VarLayout(); + varLayout->typeLayout = typeLayout; + varLayout->varDecl = DeclRef(varDecl.Ptr(), nullptr).As<VarDeclBaseRef>(); + + // This declaration may represent the same logical parameter + // as a declaration that came from a different translation unit. + // If that is the case, we want to re-use the same `VarLayout` + // across both parameters. + // + // First we look for an existing entry matching the name + // of this parameter: + auto parameterName = varDecl->Name.Content; + ParameterInfo* parameterInfo = nullptr; + if( context->shared->mapNameToParameterInfo.TryGetValue(parameterName, parameterInfo) ) + { + // If the parameters have the same name, but don't "match" according to some reasonable rules, + // then we need to bail out. + if( !doesParameterMatch(context, varLayout, parameterInfo) ) + { + parameterInfo = nullptr; + } + } + + // If we didn't find a matching parameter, then we need to create one here + if( !parameterInfo ) + { + parameterInfo = new ParameterInfo(); + context->shared->parameters.Add(parameterInfo); + context->shared->mapNameToParameterInfo.Add(parameterName, parameterInfo); + } + else + { + varLayout->flags |= VarLayoutFlag::IsRedeclaration; + } + + // Add this variable declaration to the list of declarations for the parameter + parameterInfo->varLayouts.Add(varLayout); +} + +static void addExplicitParameterBinding( + ParameterBindingContext* context, + RefPtr<ParameterInfo> parameterInfo, + LayoutSemanticInfo const& semanticInfo, + int count) +{ + auto kind = semanticInfo.kind; + + auto& bindingInfo = parameterInfo->bindingInfo[(int)kind]; + if( bindingInfo.count != 0 ) + { + // We already have a binding here, so we want to + // confirm that it matches the new one that is + // incoming... + if( bindingInfo.count != count + || bindingInfo.index != semanticInfo.index + || bindingInfo.space != semanticInfo.space ) + { + // TODO: diagnose! + } + + // TODO(tfoley): `register` semantics can technically be + // profile-specific (not sure if anybody uses that)... + } + else + { + bindingInfo.count = count; + bindingInfo.index = semanticInfo.index; + bindingInfo.space = semanticInfo.space; + + // If things are bound in `space0` (the default), then we need + // to lay claim to the register range used, so that automatic + // assignment doesn't go and use the same registers. + if (semanticInfo.space == 0) + { + context->shared->usedResourceRanges[(int)semanticInfo.kind].Add( + semanticInfo.index, + semanticInfo.index + count); + } + } +} + +static void addExplicitParameterBindings_HLSL( + ParameterBindingContext* context, + RefPtr<ParameterInfo> parameterInfo, + RefPtr<VarLayout> varLayout) +{ + auto typeLayout = varLayout->typeLayout; + auto varDecl = varLayout->varDecl; + + // If the declaration has explicit binding modifiers, then + // here is where we want to extract and apply them... + + // Look for HLSL `register` or `packoffset` semantics. + for (auto semantic : varDecl.GetDecl()->GetModifiersOfType<HLSLLayoutSemantic>()) + { + // Need to extract the information encoded in the semantic + LayoutSemanticInfo semanticInfo = ExtractLayoutSemanticInfo(context, semantic); + auto kind = semanticInfo.kind; + if (kind == LayoutResourceKind::None) + continue; + + // TODO: need to special-case when this is a `c` register binding... + + // Find the appropriate resource-binding information + // inside the type, to see if we even use any resources + // of the given kind. + + auto typeRes = typeLayout->FindResourceInfo(kind); + int count = 0; + if (typeRes) + { + count = (int) typeRes->count; + } + else + { + // TODO: warning here! + } + + addExplicitParameterBinding(context, parameterInfo, semanticInfo, count); + } +} + +static void addExplicitParameterBindings_GLSL( + ParameterBindingContext* context, + RefPtr<ParameterInfo> parameterInfo, + RefPtr<VarLayout> varLayout) +{ + auto typeLayout = varLayout->typeLayout; + auto varDecl = varLayout->varDecl; + + // The catch in GLSL is that the expected resource type + // is implied by the parameter declaration itself, and + // the `layout` modifier is only allowed to adjust + // the index/offset/etc. + // + + TypeLayout::ResourceInfo* resInfo = nullptr; + LayoutSemanticInfo semanticInfo; + semanticInfo.index = 0; + semanticInfo.space = 0; + if( (resInfo = typeLayout->FindResourceInfo(LayoutResourceKind::DescriptorTableSlot)) ) + { + // Try to find `binding` and `set` + if(!findLayoutArg<GLSLBindingLayoutModifier>(varDecl, &semanticInfo.index)) + return; + + findLayoutArg<GLSLSetLayoutModifier>(varDecl, &semanticInfo.space); + } + else if( (resInfo = typeLayout->FindResourceInfo(LayoutResourceKind::VertexInput)) ) + { + // Try to find `location` binding + if(!findLayoutArg<GLSLLocationLayoutModifier>(varDecl, &semanticInfo.index)) + return; + } + else if( (resInfo = typeLayout->FindResourceInfo(LayoutResourceKind::FragmentOutput)) ) + { + // Try to find `location` binding + if(!findLayoutArg<GLSLLocationLayoutModifier>(varDecl, &semanticInfo.index)) + return; + } + else if( (resInfo = typeLayout->FindResourceInfo(LayoutResourceKind::SpecializationConstant)) ) + { + // Try to find `constant_id` binding + if(!findLayoutArg<GLSLConstantIDLayoutModifier>(varDecl, &semanticInfo.index)) + return; + } + + // If we didn't find any matches, then bail + if(!resInfo) + return; + + auto kind = resInfo->kind; + auto count = resInfo->count; + semanticInfo.kind = kind; + + addExplicitParameterBinding(context, parameterInfo, semanticInfo, int(count)); +} + +// Given a single parameter, collect whatever information we have on +// how it has been explicitly bound, which may come from multiple declarations +void generateParameterBindings( + ParameterBindingContext* context, + RefPtr<ParameterInfo> parameterInfo) +{ + // There must be at least one declaration for the parameter. + assert(parameterInfo->varLayouts.Count() != 0); + + // Iterate over all declarations looking for explicit binding information. + for( auto& varLayout : parameterInfo->varLayouts ) + { + // Handle HLSL `register` and `packoffset` modifiers + addExplicitParameterBindings_HLSL(context, parameterInfo, varLayout); + + + // Handle GLSL `layout` modifiers + addExplicitParameterBindings_GLSL(context, parameterInfo, varLayout); + } +} + +// Generate the binding information for a shader parameter. +static void completeBindingsForParameter( + ParameterBindingContext* context, + RefPtr<ParameterInfo> parameterInfo) +{ + // For any resource kind used by the parameter + // we need to update its layout information + // to include a binding for that resource kind. + // + // We will use the first declaration of the parameter as + // a stand-in for all the declarations, so it is important + // that earlier code has validated that the declarations + // "match". + + assert(parameterInfo->varLayouts.Count() != 0); + auto firstVarLayout = parameterInfo->varLayouts.First(); + auto firstTypeLayout = firstVarLayout->typeLayout; + + for(auto typeRes : firstTypeLayout->resourceInfos) + { + // Did we already apply some explicit binding information + // for this resource kind? + auto kind = typeRes.kind; + auto& bindingInfo = parameterInfo->bindingInfo[(int)kind]; + if( bindingInfo.count != 0 ) + { + // If things have already been bound, our work is done. + continue; + } + + auto count = typeRes.count; + bindingInfo.count = count; + bindingInfo.index = context->shared->usedResourceRanges[(int)kind].Allocate((int) count); + + // For now we only auto-generate bindings in space zero + bindingInfo.space = 0; + } + + // At this point we should have explicit binding locations chosen for + // all the relevant resource kinds, so we can apply these to the + // declarations: + + for(auto& varLayout : parameterInfo->varLayouts) + { + for(auto k = 0; k < kLayoutResourceKindCount; ++k) + { + auto kind = LayoutResourceKind(k); + auto& bindingInfo = parameterInfo->bindingInfo[k]; + + // skip resources we aren't consuming + if(bindingInfo.count == 0) + continue; + + // Add a record to the variable layout + auto varRes = varLayout->AddResourceInfo(kind); + varRes->space = (int) bindingInfo.space; + varRes->index = (int) bindingInfo.index; + } + } +} + +static void collectGlobalScopeParameters( + ParameterBindingContext* context, + ProgramSyntaxNode* program) +{ + // First enumerate parameters at global scope + for( auto decl : program->Members ) + { + // A shader parameter is always a variable, + // so skip declarations that aren't variables. + auto varDecl = decl.As<VarDeclBase>(); + if (!varDecl) + continue; + + collectGlobalScopeParameter(context, varDecl); + } + + // Next, we need to enumerate the parameters of + // each entry point (which requires knowing what the + // entry points *are*) + + // TODO(tfoley): Entry point functions should be identified + // by looking for a generated modifier that is attached + // to global-scope function declarations. +} + +struct SimpleSemanticInfo +{ + String name; + int index; +}; + +SimpleSemanticInfo decomposeSimpleSemantic( + HLSLSimpleSemantic* semantic) +{ + auto composedName = semantic->name.Content; + + // look for a trailing sequence of decimal digits + // at the end of the composed name + int length = composedName.Length(); + int indexLoc = length; + while( indexLoc > 0 ) + { + auto c = composedName[indexLoc-1]; + if( c >= '0' && c <= '9' ) + { + indexLoc--; + continue; + } + else + { + break; + } + } + + SimpleSemanticInfo info; + + // + if( indexLoc == length ) + { + // No index suffix + info.name = composedName; + info.index = 0; + } + else + { + // The name is everything before the digits + info.name = composedName.SubString(0, indexLoc); + info.index = strtol(composedName.SubString(indexLoc, length - indexLoc).begin(), nullptr, 10); + } + return info; +} + +enum class EntryPointParameterDirection +{ + Input, + Output, +}; + +struct EntryPointParameterState +{ + String* optSemanticName; + int* ioSemanticIndex; + EntryPointParameterDirection direction; + int semanticSlotCount; +}; + +static void processSimpleEntryPointInput( + ParameterBindingContext* context, + RefPtr<ExpressionType> type, + EntryPointParameterState const& state) +{ + auto optSemanticName = state.optSemanticName; + auto semanticIndex = *state.ioSemanticIndex; + auto semanticSlotCount = state.semanticSlotCount; +} + +static void processSimpleEntryPointOutput( + ParameterBindingContext* context, + RefPtr<ExpressionType> type, + EntryPointParameterState const& state) +{ + auto optSemanticName = state.optSemanticName; + auto semanticIndex = *state.ioSemanticIndex; + auto semanticSlotCount = state.semanticSlotCount; + + if(!optSemanticName) + return; + + auto semanticName = *optSemanticName; + + // Note: I'm just doing something expedient here and detecting `SV_Target` + // outputs and claiming the appropriate register range right away. + // + // TODO: we should really be building up some representation of all of this, + // once we've gone to the trouble of looking it all up... + if( semanticName.ToLower() == "sv_target" ) + { + context->shared->usedResourceRanges[int(LayoutResourceKind::UnorderedAccess)].Add(semanticIndex, semanticIndex + semanticSlotCount); + } +} + +static void processSimpleEntryPointParameter( + ParameterBindingContext* context, + RefPtr<ExpressionType> type, + EntryPointParameterState const& inState, + int semanticSlotCount = 1) +{ + EntryPointParameterState state = inState; + state.semanticSlotCount = semanticSlotCount; + + switch( state.direction ) + { + case EntryPointParameterDirection::Input: + processSimpleEntryPointInput(context, type, state); + break; + + case EntryPointParameterDirection::Output: + processSimpleEntryPointOutput(context, type, state); + break; + + SLANG_EXHAUSTIVE_SWITCH() + } + + *state.ioSemanticIndex += state.semanticSlotCount; +} + +static void processEntryPointParameter( + ParameterBindingContext* context, + RefPtr<ExpressionType> type, + EntryPointParameterState const& state); + +static void processEntryPointParameterWithPossibleSemantic( + ParameterBindingContext* context, + Decl* declForSemantic, + RefPtr<ExpressionType> type, + EntryPointParameterState const& state) +{ + // If there is no explicit semantic already in effect, *and* we find an explicit + // semantic on the associated declaration, then we'll use it. + if( !state.optSemanticName ) + { + if( auto semantic = declForSemantic->FindModifier<HLSLSimpleSemantic>() ) + { + auto semanticInfo = decomposeSimpleSemantic(semantic); + int semanticIndex = semanticInfo.index; + + EntryPointParameterState subState = state; + subState.optSemanticName = &semanticInfo.name; + subState.ioSemanticIndex = &semanticIndex; + + processEntryPointParameter(context, type, subState); + } + } + + // Default case: either there was an explicit semantic in effect already, + // *or* we couldn't find an explicit semantic to apply on the given + // declaration, so we will just recursive with whatever we have at + // the moment. + processEntryPointParameter(context, type, state); +} + + +static void processEntryPointParameter( + ParameterBindingContext* context, + RefPtr<ExpressionType> type, + EntryPointParameterState const& state) +{ + // Scalar and vector types are treated as outputs directly + if(auto basicType = type->As<BasicExpressionType>()) + { + processSimpleEntryPointParameter(context, basicType, state); + } + else if(auto basicType = type->As<VectorExpressionType>()) + { + processSimpleEntryPointParameter(context, basicType, state); + } + // A matrix is processed as if it was an array of rows + else if( auto matrixType = type->As<MatrixExpressionType>() ) + { + auto rowCount = GetIntVal(matrixType->getRowCount()); + processSimpleEntryPointParameter(context, basicType, state, rowCount); + } + else if( auto arrayType = type->As<ArrayExpressionType>() ) + { + auto elementCount = GetIntVal(arrayType->ArrayLength); + + for( int ii = 0; ii < elementCount; ++ii ) + { + processEntryPointParameter(context, arrayType->BaseType, state); + } + } + // Ignore a bunch of types that don't make sense here... + else if(auto textureType = type->As<TextureType>()) {} + else if(auto samplerStateType = type->As<SamplerStateType>()) {} + else if(auto constantBufferType = type->As<ConstantBufferType>()) {} + // Catch declaration-reference types late in the sequence, since + // otherwise they will include all of the above cases... + else if( auto declRefType = type->As<DeclRefType>() ) + { + auto declRef = declRefType->declRef; + + if (auto structDeclRef = declRef.As<StructDeclRef>()) + { + // Need to recursively walk the fields of the structure now... + for( auto field : structDeclRef.GetFields() ) + { + processEntryPointParameterWithPossibleSemantic( + context, + field.GetDecl(), + field.GetType(), + state); + } + } + else + { + assert(!"unimplemented"); + } + } + else + { + assert(!"unimplemented"); + } +} + +static void collectEntryPointParameters( + ParameterBindingContext* context, + EntryPointOption const& entryPoint, + ProgramSyntaxNode* translationUnitSyntax) +{ + // First, look for the entry point with the specified name + + // Make sure we've got a query-able member dictionary + buildMemberDictionary(translationUnitSyntax); + + Decl* entryPointDecl; + if( !translationUnitSyntax->memberDictionary.TryGetValue(entryPoint.name, entryPointDecl) ) + { + // No such entry point! + return; + } + if( entryPointDecl->nextInContainerWithSameName ) + { + // Not the only decl of that name! + return; + } + + FunctionSyntaxNode* entryPointFuncDecl = dynamic_cast<FunctionSyntaxNode*>(entryPointDecl); + if( !entryPointFuncDecl ) + { + // Not a function! + return; + } + + // Create the layout object here + auto entryPointLayout = new EntryPointLayout(); + entryPointLayout->profile = entryPoint.profile; + entryPointLayout->entryPoint = entryPointFuncDecl; + + + context->shared->programLayout->entryPoints.Add(entryPointLayout); + + // Okay, we seemingly have an entry-point function, and now we need to collect info on its parameters too + // + // TODO: Long-term we probably want complete information on all inputs/outputs of an entry point, + // but for now we are really just trying to scrape information on fragment outputs, so lets do that: + // + // TODO: check whether we should enumerate the parameters before the return type, or vice versa + + int defaultSemanticIndex = 0; + + EntryPointParameterState state; + state.ioSemanticIndex = &defaultSemanticIndex; + state.optSemanticName = nullptr; + state.semanticSlotCount = 0; + + for( auto m : entryPointFuncDecl->Members ) + { + auto paramDecl = m.As<VarDeclBase>(); + if(!paramDecl) + continue; + + // We have an entry-point parameter, and need to figure out what to do with it. + + // If it appears to be an input, process it as such. + if( paramDecl->HasModifier<InModifier>() || paramDecl->HasModifier<InOutModifier>() || !paramDecl->HasModifier<OutModifier>() ) + { + state.direction = EntryPointParameterDirection::Input; + + processEntryPointParameterWithPossibleSemantic( + context, + paramDecl.Ptr(), + paramDecl->Type.type, + state); + } + + // If it appears to be an output, process it as such. + if(paramDecl->HasModifier<OutModifier>() || paramDecl->HasModifier<InOutModifier>()) + { + state.direction = EntryPointParameterDirection::Output; + + processEntryPointParameterWithPossibleSemantic( + context, + paramDecl.Ptr(), + paramDecl->Type.type, + state); + } + } + + // If we can find an output type for the entry point, then process it as + // an output parameter. + if( auto resultType = entryPointFuncDecl->ReturnType.type ) + { + state.direction = EntryPointParameterDirection::Output; + + processEntryPointParameterWithPossibleSemantic( + context, + entryPointFuncDecl, + resultType, + state); + } +} + +// When doing parameter binding for global-scope stuff in GLSL, +// we may need to know what stage we are compiling for, so that +// we can handle special cases appropriately (e.g., "arrayed" +// inputs and outputs). +static Stage +inferStageForTranslationUnit( + CompileUnit const& translationUnit) +{ + // In the specific case where we are compiling GLSL input, + // and have only a single entry point, use the stage + // of the entry point. + // + // TODO: can we generalize this at all? + if( translationUnit.options.sourceLanguage == SourceLanguage::GLSL ) + { + if( translationUnit.options.entryPoints.Count() == 1 ) + { + return translationUnit.options.entryPoints[0].profile.GetStage(); + } + } + + return Stage::Unknown; +} + +static void collectParameters( + ParameterBindingContext* inContext, + CollectionOfTranslationUnits* program) +{ + ParameterBindingContext contextData = *inContext; + auto context = &contextData; + + for( auto& translationUnit : program->translationUnits ) + { + context->stage = inferStageForTranslationUnit(translationUnit); + + // First look at global-scope parameters + collectGlobalScopeParameters(context, translationUnit.SyntaxNode.Ptr()); + + // Next consider parameters for entry points + for( auto& entryPoint : translationUnit.options.entryPoints ) + { + context->stage = entryPoint.profile.GetStage(); + collectEntryPointParameters(context, entryPoint, translationUnit.SyntaxNode.Ptr()); + } + } +} + +void GenerateParameterBindings( + CollectionOfTranslationUnits* program) +{ + // TODO: infer a language or set of language rules to use based on the + // source files and entry points given + auto language = SourceLanguage::Unknown; + for( auto& translationUnit : program->translationUnits ) + { + auto translationUnitLanguage = translationUnit.options.sourceLanguage; + if( language == SourceLanguage::Unknown ) + { + language = translationUnitLanguage; + } + else if( language == translationUnitLanguage ) + { + // same language: nothing to do... + } + else + { + // mismatch! + // TODO(tfoley): emit a diagnostic + } + } + + // TODO(tfoley): We should really be picking layout rules + // based on the *target* language, and not the source... + auto rules = GetLayoutRulesFamilyImpl(language); + assert(rules); + + RefPtr<ProgramLayout> programLayout = new ProgramLayout; + + // Create a context to hold shared state during the process + // of generating parameter bindings + SharedParameterBindingContext sharedContext; + sharedContext.defaultLayoutRules = rules; + sharedContext.programLayout = programLayout; + sharedContext.sourceLanguage = language; + + // Create a sub-context to collect parameters that get + // declared into the global scope + ParameterBindingContext context; + context.shared = &sharedContext; + context.layoutRules = sharedContext.defaultLayoutRules; + + // Walk through AST to discover all the parameters + collectParameters(&context, program); + + // Now walk through the parameters to generate initial binding information + for( auto& parameter : sharedContext.parameters ) + { + generateParameterBindings(&context, parameter); + } + + bool anyGlobalUniforms = false; + for( auto& parameterInfo : sharedContext.parameters ) + { + assert(parameterInfo->varLayouts.Count() != 0); + auto firstVarLayout = parameterInfo->varLayouts.First(); + + // Does the field have any uniform data? + if( firstVarLayout->typeLayout->FindResourceInfo(LayoutResourceKind::Uniform) ) + { + anyGlobalUniforms = true; + break; + } + } + + // If there are any global-scope uniforms, then we need to + // allocate a constant-buffer binding for them here. + ParameterBindingInfo globalConstantBufferBinding; + if( anyGlobalUniforms ) + { + globalConstantBufferBinding.index = + context.shared->usedResourceRanges[ + (int)LayoutResourceKind::ConstantBuffer].Allocate(1); + + // For now we only auto-generate bindings in space zero + globalConstantBufferBinding.space = 0; + } + + + // Now walk through again to actually give everything + // ranges of registers... + for( auto& parameter : sharedContext.parameters ) + { + completeBindingsForParameter(&context, parameter); + } + + // TODO: need to deal with parameters declared inside entry-point + // parameter lists at some point... + + + // Next we need to create a type layout to reflect the information + // we have collected. + + // We will lay out any bare uniforms at the global scope into + // a single constant buffer. This is appropriate for HLSL global-scope + // uniforms, and Vulkan GLSL doesn't allow uniforms at global scope, + // so it should work out. + // + // For legacy GLSL targets, we'd probably need a distinct resource + // kind and set of rules here, since legacy uniforms are not the + // same as the contents of a constant buffer. + auto globalScopeRules = context.layoutRules->getConstantBufferRules(); + + RefPtr<StructTypeLayout> globalScopeStructLayout = new StructTypeLayout(); + globalScopeStructLayout->rules = globalScopeRules; + + UniformLayoutInfo structLayoutInfo = globalScopeRules->BeginStructLayout(); + for( auto& parameterInfo : sharedContext.parameters ) + { + assert(parameterInfo->varLayouts.Count() != 0); + auto firstVarLayout = parameterInfo->varLayouts.First(); + + // Does the field have any uniform data? + auto layoutInfo = firstVarLayout->typeLayout->FindResourceInfo(LayoutResourceKind::Uniform); + size_t uniformSize = layoutInfo ? layoutInfo->count : 0; + if( uniformSize != 0 ) + { + // Make sure uniform fields get laid out properly... + + UniformLayoutInfo fieldInfo( + uniformSize, + firstVarLayout->typeLayout->uniformAlignment); + + size_t uniformOffset = globalScopeRules->AddStructField( + &structLayoutInfo, + fieldInfo); + + for( auto& varLayout : parameterInfo->varLayouts ) + { + varLayout->findOrAddResourceInfo(LayoutResourceKind::Uniform)->index = uniformOffset; + } + } + + globalScopeStructLayout->fields.Add(firstVarLayout); + + for( auto& varLayout : parameterInfo->varLayouts ) + { + globalScopeStructLayout->mapVarToLayout.Add(varLayout->varDecl.GetDecl(), varLayout); + } + } + globalScopeRules->EndStructLayout(&structLayoutInfo); + + RefPtr<TypeLayout> globalScopeLayout = globalScopeStructLayout; + + // If there are global-scope uniforms, then we need to wrap + // up a global constant buffer type layout to hold them + if( anyGlobalUniforms ) + { + auto globalConstantBufferLayout = createParameterBlockTypeLayout( + nullptr, + globalScopeStructLayout, + globalScopeRules); + + globalScopeLayout = globalConstantBufferLayout; + } + + // We now have a bunch of layout information, which we should + // record into a suitable object that represents the program + programLayout->globalScopeLayout = globalScopeLayout; + program->layout = programLayout; +} + +}} diff --git a/source/slang/parameter-binding.h b/source/slang/parameter-binding.h new file mode 100644 index 000000000..8165f1b2e --- /dev/null +++ b/source/slang/parameter-binding.h @@ -0,0 +1,32 @@ +#ifndef SLANG_PARAMETER_BINDING_H +#define SLANG_PARAMETER_BINDING_H + +#include "../core/basic.h" +#include "syntax.h" + +#include "../../Slang.h" + +namespace Slang { +namespace Compiler { + +class CollectionOfTranslationUnits; + +// The parameter-binding interface is responsible for assigning +// binding locations/registers to every parameter of a shader +// program. This can include both parameters declared on a +// particular entry point, as well as parameters declared at +// global scope. +// + + +// Generate binding information for the given program, +// represented as a collection of different translation units, +// and attach that information to the syntax nodes +// of the program. + +void GenerateParameterBindings( + CollectionOfTranslationUnits* program); + +}} + +#endif // SLANG_REFLECTION_H diff --git a/source/slang/parser.cpp b/source/slang/parser.cpp new file mode 100644 index 000000000..773e3b74b --- /dev/null +++ b/source/slang/parser.cpp @@ -0,0 +1,3106 @@ +#include "Parser.h" + +#include <assert.h> + +#include "lookup.h" + +namespace Slang +{ + namespace Compiler + { + enum Precedence : int + { + Invalid = -1, + Comma, + Assignment, + TernaryConditional, + LogicalOr, + LogicalAnd, + BitOr, + BitXor, + BitAnd, + EqualityComparison, + RelationalComparison, + BitShift, + Additive, + Multiplicative, + Prefix, + Postfix, + }; + + // TODO: implement two pass parsing for file reference and struct type recognition + + class Parser + { + public: + CompileOptions& options; + int anonymousCounter = 0; + + RefPtr<Scope> outerScope; + RefPtr<Scope> currentScope; + + TokenReader tokenReader; + DiagnosticSink * sink; + String fileName; + int genericDepth = 0; + + // Is the parser in a "recovering" state? + // During recovery we don't emit additional errors, until we find + // a token that we expected, when we exit recovery. + bool isRecovering = false; + + void FillPosition(SyntaxNode * node) + { + node->Position = tokenReader.PeekLoc(); + } + void PushScope(ContainerDecl* containerDecl) + { + RefPtr<Scope> newScope = new Scope(); + newScope->containerDecl = containerDecl; + newScope->parent = currentScope; + + currentScope = newScope; + } + void PopScope() + { + currentScope = currentScope->parent; + } + Parser( + CompileOptions& options, + TokenSpan const& _tokens, + DiagnosticSink * sink, + String _fileName, + RefPtr<Scope> const& outerScope) + : options(options) + , tokenReader(_tokens) + , sink(sink) + , fileName(_fileName) + , outerScope(outerScope) + {} + RefPtr<ProgramSyntaxNode> Parse(); + + Token ReadToken(); + Token ReadToken(TokenType type); + Token ReadToken(const char * string); + bool LookAheadToken(TokenType type, int offset = 0); + bool LookAheadToken(const char * string, int offset = 0); + void parseSourceFile(ProgramSyntaxNode* program); + RefPtr<ProgramSyntaxNode> ParseProgram(); + RefPtr<StructSyntaxNode> ParseStruct(); + RefPtr<ClassSyntaxNode> ParseClass(); + RefPtr<StatementSyntaxNode> ParseStatement(); + RefPtr<StatementSyntaxNode> ParseBlockStatement(); + RefPtr<VarDeclrStatementSyntaxNode> ParseVarDeclrStatement(Modifiers modifiers); + RefPtr<IfStatementSyntaxNode> ParseIfStatement(); + RefPtr<ForStatementSyntaxNode> ParseForStatement(); + RefPtr<WhileStatementSyntaxNode> ParseWhileStatement(); + RefPtr<DoWhileStatementSyntaxNode> ParseDoWhileStatement(); + RefPtr<BreakStatementSyntaxNode> ParseBreakStatement(); + RefPtr<ContinueStatementSyntaxNode> ParseContinueStatement(); + RefPtr<ReturnStatementSyntaxNode> ParseReturnStatement(); + RefPtr<ExpressionStatementSyntaxNode> ParseExpressionStatement(); + RefPtr<ExpressionSyntaxNode> ParseExpression(Precedence level = Precedence::Comma); + + // Parse an expression that might be used in an initializer or argument context, so we should avoid operator-comma + inline RefPtr<ExpressionSyntaxNode> ParseInitExpr() { return ParseExpression(Precedence::Assignment); } + inline RefPtr<ExpressionSyntaxNode> ParseArgExpr() { return ParseExpression(Precedence::Assignment); } + + RefPtr<ExpressionSyntaxNode> ParseLeafExpression(); + RefPtr<ParameterSyntaxNode> ParseParameter(); + RefPtr<ExpressionSyntaxNode> ParseType(); + TypeExp ParseTypeExp(); + + Parser & operator = (const Parser &) = delete; + }; + + // Forward Declarations + + static void ParseDeclBody( + Parser* parser, + ContainerDecl* containerDecl, + TokenType closingToken); + + static RefPtr<Modifier> ParseOptSemantics( + Parser* parser); + + static void ParseOptSemantics( + Parser* parser, + Decl* decl); + + static RefPtr<DeclBase> ParseDecl( + Parser* parser, + ContainerDecl* containerDecl); + + static RefPtr<Decl> ParseSingleDecl( + Parser* parser, + ContainerDecl* containerDecl); + + // + + static void Unexpected( + Parser* parser) + { + // Don't emit "unexpected token" errors if we are in recovering mode + if (!parser->isRecovering) + { + parser->sink->diagnose(parser->tokenReader.PeekLoc(), Diagnostics::unexpectedToken, + parser->tokenReader.PeekTokenType()); + + // Switch into recovery mode, to suppress additional errors + parser->isRecovering = true; + } + } + + static void Unexpected( + Parser* parser, + char const* expected) + { + // Don't emit "unexpected token" errors if we are in recovering mode + if (!parser->isRecovering) + { + parser->sink->diagnose(parser->tokenReader.PeekLoc(), Diagnostics::unexpectedTokenExpectedTokenName, + parser->tokenReader.PeekTokenType(), + expected); + + // Switch into recovery mode, to suppress additional errors + parser->isRecovering = true; + } + } + + static void Unexpected( + Parser* parser, + TokenType expected) + { + // Don't emit "unexpected token" errors if we are in recovering mode + if (!parser->isRecovering) + { + parser->sink->diagnose(parser->tokenReader.PeekLoc(), Diagnostics::unexpectedTokenExpectedTokenType, + parser->tokenReader.PeekTokenType(), + expected); + + // Switch into recovery mode, to suppress additional errors + parser->isRecovering = true; + } + } + + static TokenType SkipToMatchingToken(TokenReader* reader, TokenType tokenType); + + // Skip a singel balanced token, which is either a single token in + // the common case, or a matched pair of tokens for `()`, `[]`, and `{}` + static TokenType SkipBalancedToken( + TokenReader* reader) + { + TokenType tokenType = reader->AdvanceToken().Type; + switch (tokenType) + { + default: + break; + + case TokenType::LParent: tokenType = SkipToMatchingToken(reader, TokenType::RParent); break; + case TokenType::LBrace: tokenType = SkipToMatchingToken(reader, TokenType::RBrace); break; + case TokenType::LBracket: tokenType = SkipToMatchingToken(reader, TokenType::RBracket); break; + } + return tokenType; + } + + // Skip balanced + static TokenType SkipToMatchingToken( + TokenReader* reader, + TokenType tokenType) + { + for (;;) + { + if (reader->IsAtEnd()) return TokenType::EndOfFile; + if (reader->PeekTokenType() == tokenType) + { + reader->AdvanceToken(); + return tokenType; + } + SkipBalancedToken(reader); + } + } + + // Is the given token type one that is used to "close" a + // balanced construct. + static bool IsClosingToken(TokenType tokenType) + { + switch (tokenType) + { + case TokenType::EndOfFile: + case TokenType::RBracket: + case TokenType::RParent: + case TokenType::RBrace: + return true; + + default: + return false; + } + } + + + // Expect an identifier token with the given content, and consume it. + Token Parser::ReadToken(const char* expected) + { + if (tokenReader.PeekTokenType() == TokenType::Identifier + && tokenReader.PeekToken().Content == expected) + { + isRecovering = false; + return tokenReader.AdvanceToken(); + } + + if (!isRecovering) + { + Unexpected(this, expected); + return tokenReader.PeekToken(); + } + else + { + // Try to find a place to recover + for (;;) + { + // The token we expected? + // Then exit recovery mode and pretend like all is well. + if (tokenReader.PeekTokenType() == TokenType::Identifier + && tokenReader.PeekToken().Content == expected) + { + isRecovering = false; + return tokenReader.AdvanceToken(); + } + + + // Don't skip past any "closing" tokens. + if (IsClosingToken(tokenReader.PeekTokenType())) + { + return tokenReader.PeekToken(); + } + + // Skip balanced tokens and try again. + SkipBalancedToken(&tokenReader); + } + } + } + + Token Parser::ReadToken() + { + return tokenReader.AdvanceToken(); + } + + static bool TryRecover( + Parser* parser, + TokenType const* recoverBefore, + int recoverBeforeCount, + TokenType const* recoverAfter, + int recoverAfterCount) + { + if (!parser->isRecovering) + return true; + + // Determine if we are looking for a closing token at all... + bool lookingForClose = false; + for (int ii = 0; ii < recoverBeforeCount; ++ii) + { + if (IsClosingToken(recoverBefore[ii])) + lookingForClose = true; + } + for (int ii = 0; ii < recoverAfterCount; ++ii) + { + if (IsClosingToken(recoverAfter[ii])) + lookingForClose = true; + } + + TokenReader* tokenReader = &parser->tokenReader; + for (;;) + { + TokenType peek = tokenReader->PeekTokenType(); + + // Is the next token in our recover-before set? + // If so, then we have recovered successfully! + for (int ii = 0; ii < recoverBeforeCount; ++ii) + { + if (peek == recoverBefore[ii]) + { + parser->isRecovering = false; + return true; + } + } + + // If we are looking at a token in our recover-after set, + // then consume it and recover + for (int ii = 0; ii < recoverAfterCount; ++ii) + { + if (peek == recoverAfter[ii]) + { + tokenReader->AdvanceToken(); + parser->isRecovering = false; + return true; + } + } + + // Don't try to skip past end of file + if (peek == TokenType::EndOfFile) + return false; + + switch (peek) + { + // Don't skip past simple "closing" tokens, *unless* + // we are looking for a closing token + case TokenType::RParent: + case TokenType::RBracket: + if (!lookingForClose) + return false; + break; + + // never skip a `}`, to avoid spurious errors + case TokenType::RBrace: + return false; + } + + // Skip balanced tokens and try again. + TokenType skipped = SkipBalancedToken(tokenReader); + + // If we happened to find a matched pair of tokens, and + // the end of it was a token we were looking for, + // then recover here + for (int ii = 0; ii < recoverAfterCount; ++ii) + { + if (skipped == recoverAfter[ii]) + { + parser->isRecovering = false; + return true; + } + } + } + } + + static bool TryRecoverBefore( + Parser* parser, + TokenType before0) + { + TokenType recoverBefore[] = { before0 }; + return TryRecover(parser, recoverBefore, 1, nullptr, 0); + } + + // Default recovery strategy, to use inside `{}`-delimeted blocks. + static bool TryRecover( + Parser* parser) + { + TokenType recoverBefore[] = { TokenType::RBrace }; + TokenType recoverAfter[] = { TokenType::Semicolon }; + return TryRecover(parser, recoverBefore, 1, recoverAfter, 1); + } + + Token Parser::ReadToken(TokenType expected) + { + if (tokenReader.PeekTokenType() == expected) + { + isRecovering = false; + return tokenReader.AdvanceToken(); + } + + if (!isRecovering) + { + Unexpected(this, expected); + return tokenReader.PeekToken(); + } + else + { + // Try to find a place to recover + if (TryRecoverBefore(this, expected)) + { + isRecovering = false; + return tokenReader.AdvanceToken(); + } + + return tokenReader.PeekToken(); + } + } + + bool Parser::LookAheadToken(const char * string, int offset) + { + TokenReader r = tokenReader; + for (int ii = 0; ii < offset; ++ii) + r.AdvanceToken(); + + return r.PeekTokenType() == TokenType::Identifier + && r.PeekToken().Content == string; + } + + bool Parser::LookAheadToken(TokenType type, int offset) + { + TokenReader r = tokenReader; + for (int ii = 0; ii < offset; ++ii) + r.AdvanceToken(); + + return r.PeekTokenType() == type; + } + + // Consume a token and return true it if matches, otherwise false + bool AdvanceIf(Parser* parser, TokenType tokenType) + { + if (parser->LookAheadToken(tokenType)) + { + parser->ReadToken(); + return true; + } + return false; + } + + // Consume a token and return true it if matches, otherwise false + bool AdvanceIf(Parser* parser, char const* text) + { + if (parser->LookAheadToken(text)) + { + parser->ReadToken(); + return true; + } + return false; + } + + // Consume a token and return true if it matches, otherwise check + // for end-of-file and expect that token (potentially producing + // an error) and return true to maintain forward progress. + // Otherwise return false. + bool AdvanceIfMatch(Parser* parser, TokenType tokenType) + { + // If we've run into a syntax error, but haven't recovered inside + // the block, then try to recover here. + if (parser->isRecovering) + { + TryRecoverBefore(parser, tokenType); + } + if (AdvanceIf(parser, tokenType)) + return true; + if (parser->tokenReader.PeekTokenType() == TokenType::EndOfFile) + { + parser->ReadToken(tokenType); + return true; + } + return false; + } + + RefPtr<ProgramSyntaxNode> Parser::Parse() + { + return ParseProgram(); + } + + RefPtr<TypeDefDecl> ParseTypeDef(Parser* parser) + { + // Consume the `typedef` keyword + parser->ReadToken("typedef"); + + // TODO(tfoley): parse an actual declarator + auto type = parser->ParseTypeExp(); + + auto nameToken = parser->ReadToken(TokenType::Identifier); + + RefPtr<TypeDefDecl> typeDefDecl = new TypeDefDecl(); + typeDefDecl->Name = nameToken; + typeDefDecl->Type = type; + + return typeDefDecl; + } + + // Add a modifier to a list of modifiers being built + static void AddModifier(RefPtr<Modifier>** ioModifierLink, RefPtr<Modifier> modifier) + { + RefPtr<Modifier>*& modifierLink = *ioModifierLink; + + while(*modifierLink) + modifierLink = &(*modifierLink)->next; + + *modifierLink = modifier; + modifierLink = &modifier->next; + } + + void addModifier( + RefPtr<ModifiableSyntaxNode> syntax, + RefPtr<Modifier> modifier) + { + auto modifierLink = &syntax->modifiers.first; + AddModifier(&modifierLink, modifier); + } + + // Parse HLSL-style `[name(arg, ...)]` style "attribute" modifiers + static void ParseSquareBracketAttributes(Parser* parser, RefPtr<Modifier>** ioModifierLink) + { + parser->ReadToken(TokenType::LBracket); + for(;;) + { + auto nameToken = parser->ReadToken(TokenType::Identifier); + RefPtr<HLSLUncheckedAttribute> modifier = new HLSLUncheckedAttribute(); + modifier->nameToken = nameToken; + + if (AdvanceIf(parser, TokenType::LParent)) + { + // HLSL-style `[name(arg0, ...)]` attribute + + while (!AdvanceIfMatch(parser, TokenType::RParent)) + { + auto arg = parser->ParseArgExpr(); + if (arg) + { + modifier->args.Add(arg); + } + + if (AdvanceIfMatch(parser, TokenType::RParent)) + break; + + parser->ReadToken(TokenType::Comma); + } + } + AddModifier(ioModifierLink, modifier); + + + if (AdvanceIfMatch(parser, TokenType::RBracket)) + break; + + parser->ReadToken(TokenType::Comma); + } + } + + static Modifiers ParseModifiers(Parser* parser) + { + Modifiers modifiers; + RefPtr<Modifier>* modifierLink = &modifiers.first; + for (;;) + { + CodePosition loc = parser->tokenReader.PeekLoc(); + + if (0) {} + + #define CASE(KEYWORD, TYPE) \ + else if(AdvanceIf(parser, #KEYWORD)) do { \ + RefPtr<TYPE> modifier = new TYPE(); \ + modifier->Position = loc; \ + AddModifier(&modifierLink, modifier); \ + } while(0) + + CASE(in, InModifier); + CASE(input, InputModifier); + CASE(out, OutModifier); + CASE(inout, InOutModifier); + CASE(const, ConstModifier); + CASE(instance, InstanceModifier); + CASE(__builtin, BuiltinModifier); + + CASE(inline, InlineModifier); + CASE(public, PublicModifier); + CASE(require, RequireModifier); + CASE(param, ParamModifier); + CASE(extern, ExternModifier); + + CASE(row_major, HLSLRowMajorLayoutModifier); + CASE(column_major, HLSLColumnMajorLayoutModifier); + + CASE(nointerpolation, HLSLNoInterpolationModifier); + CASE(linear, HLSLLinearModifier); + CASE(sample, HLSLSampleModifier); + CASE(centroid, HLSLCentroidModifier); + CASE(precise, HLSLPreciseModifier); + CASE(shared, HLSLEffectSharedModifier); + CASE(groupshared, HLSLGroupSharedModifier); + CASE(static, HLSLStaticModifier); + CASE(uniform, HLSLUniformModifier); + CASE(volatile, HLSLVolatileModifier); + + // Modifiers for geometry shader input + CASE(point, HLSLPointModifier); + CASE(line, HLSLLineModifier); + CASE(triangle, HLSLTriangleModifier); + CASE(lineadj, HLSLLineAdjModifier); + CASE(triangleadj, HLSLTriangleAdjModifier); + + // Modifiers for unary operator declarations + CASE(__prefix, PrefixModifier); + CASE(__postfix, PostfixModifier); + + #undef CASE + + else if (AdvanceIf(parser, "__intrinsic")) + { + auto modifier = new IntrinsicModifier(); + modifier->Position = loc; + + if (AdvanceIf(parser, TokenType::LParent)) + { + if (parser->LookAheadToken(TokenType::IntLiterial)) + { + modifier->op = (IntrinsicOp)StringToInt(parser->ReadToken().Content); + } + else + { + modifier->opToken = parser->ReadToken(TokenType::Identifier); + + modifier->op = findIntrinsicOp(modifier->opToken.Content.Buffer()); + + if (modifier->op == IntrinsicOp::Unknown) + { + parser->sink->diagnose(loc, Diagnostics::unimplemented, "unknown intrinsic op"); + } + } + + parser->ReadToken(TokenType::RParent); + } + + AddModifier(&modifierLink, modifier); + } + + + else if (AdvanceIf(parser, "layout")) + { + parser->ReadToken(TokenType::LParent); + while (!AdvanceIfMatch(parser, TokenType::RParent)) + { + auto nameToken = parser->ReadToken(TokenType::Identifier); + + RefPtr<GLSLLayoutModifier> modifier; + + // TODO: better handling of this choise (e.g., lookup in scope) + if(0) {} + #define CASE(KEYWORD, CLASS) \ + else if(nameToken.Content == #KEYWORD) modifier = new CLASS() + + CASE(constant_id, GLSLConstantIDLayoutModifier); + CASE(binding, GLSLBindingLayoutModifier); + CASE(set, GLSLSetLayoutModifier); + CASE(location, GLSLLocationLayoutModifier); + + #undef CASE + else + { + modifier = new GLSLUnparsedLayoutModifier(); + } + + modifier->nameToken = nameToken; + + if(AdvanceIf(parser, TokenType::OpAssign)) + { + modifier->valToken = parser->ReadToken(TokenType::IntLiterial); + } + + AddModifier(&modifierLink, modifier); + + if (AdvanceIf(parser, TokenType::RParent)) + break; + parser->ReadToken(TokenType::Comma); + } + } + else if (parser->tokenReader.PeekTokenType() == TokenType::LBracket) + { + ParseSquareBracketAttributes(parser, &modifierLink); + } + else if (AdvanceIf(parser,"__builtin_type")) + { + RefPtr<BuiltinTypeModifier> modifier = new BuiltinTypeModifier(); + parser->ReadToken(TokenType::LParent); + modifier->tag = BaseType(StringToInt(parser->ReadToken(TokenType::IntLiterial).Content)); + parser->ReadToken(TokenType::RParent); + + AddModifier(&modifierLink, modifier); + } + else if (AdvanceIf(parser,"__magic_type")) + { + RefPtr<MagicTypeModifier> modifier = new MagicTypeModifier(); + parser->ReadToken(TokenType::LParent); + modifier->name = parser->ReadToken(TokenType::Identifier).Content; + if (AdvanceIf(parser, TokenType::Comma)) + { + modifier->tag = uint32_t(StringToInt(parser->ReadToken(TokenType::IntLiterial).Content)); + } + parser->ReadToken(TokenType::RParent); + + AddModifier(&modifierLink, modifier); + } + else + { + // Fallback case if none of the above explicit cases matched. + + // If we are looking at an identifier, then it may map to a + // modifier declaration visible in the current scope + if( parser->LookAheadToken(TokenType::Identifier) ) + { + LookupResult lookupResult = LookUp( + parser->tokenReader.PeekToken().Content, + parser->currentScope); + + if( lookupResult.isValid() && !lookupResult.isOverloaded() ) + { + auto& item = lookupResult.item; + auto decl = item.declRef.GetDecl(); + + if( auto modifierDecl = dynamic_cast<ModifierDecl*>(decl) ) + { + // We found a declaration for some modifier syntax, + // so lets create an instance of the type it names + // here. + + auto syntax = createInstanceOfSyntaxClassByName(modifierDecl->classNameToken.Content); + auto modifier = dynamic_cast<Modifier*>(syntax); + + if( modifier ) + { + modifier->Position = parser->tokenReader.PeekLoc(); + modifier->nameToken = parser->ReadToken(TokenType::Identifier); + + AddModifier(&modifierLink, modifier); + continue; + } + else + { + parser->ReadToken(TokenType::Identifier); + assert(!"unexpected"); + } + } + } + } + + // Done with modifier list + return modifiers; + } + } + } + + static RefPtr<Decl> ParseUsing( + Parser* parser) + { + parser->ReadToken("using"); + if (parser->tokenReader.PeekTokenType() == TokenType::StringLiterial) + { + auto usingDecl = new UsingFileDecl(); + usingDecl->fileName = parser->ReadToken(TokenType::StringLiterial); + parser->ReadToken(TokenType::Semicolon); + return usingDecl; + } + else + { + unexpected(); + } + } + + static Token ParseDeclName( + Parser* parser) + { + Token name; + if (AdvanceIf(parser, "operator")) + { + name = parser->ReadToken(); + switch (name.Type) + { + case TokenType::OpAdd: case TokenType::OpSub: case TokenType::OpMul: case TokenType::OpDiv: + case TokenType::OpMod: case TokenType::OpNot: case TokenType::OpBitNot: case TokenType::OpLsh: case TokenType::OpRsh: + case TokenType::OpEql: case TokenType::OpNeq: case TokenType::OpGreater: case TokenType::OpLess: case TokenType::OpGeq: + case TokenType::OpLeq: case TokenType::OpAnd: case TokenType::OpOr: case TokenType::OpBitXor: case TokenType::OpBitAnd: + case TokenType::OpBitOr: case TokenType::OpInc: case TokenType::OpDec: + case TokenType::OpAddAssign: + case TokenType::OpSubAssign: + case TokenType::OpMulAssign: + case TokenType::OpDivAssign: + case TokenType::OpModAssign: + case TokenType::OpShlAssign: + case TokenType::OpShrAssign: + case TokenType::OpOrAssign: + case TokenType::OpAndAssign: + case TokenType::OpXorAssign: + + // Note(tfoley): A bit of a hack: + case TokenType::Comma: + case TokenType::OpAssign: + break; + + // Note(tfoley): Even more of a hack! + case TokenType::QuestionMark: + if (AdvanceIf(parser, TokenType::Colon)) + { + name.Content = name.Content + ":"; + break; + } + + default: + parser->sink->diagnose(name.Position, Diagnostics::invalidOperator, name.Content); + break; + } + } + else + { + name = parser->ReadToken(TokenType::Identifier); + } + return name; + } + + // A "declarator" as used in C-style languages + struct Declarator : RefObject + { + // Different cases of declarator appear as "flavors" here + enum class Flavor + { + Name, + Pointer, + Array, + }; + Flavor flavor; + }; + + // The most common case of declarator uses a simple name + struct NameDeclarator : Declarator + { + Token nameToken; + }; + + // A declarator that declares a pointer type + struct PointerDeclarator : Declarator + { + // location of the `*` token + CodePosition starLoc; + + RefPtr<Declarator> inner; + }; + + // A declarator that declares an array type + struct ArrayDeclarator : Declarator + { + RefPtr<Declarator> inner; + + // location of the `[` token + CodePosition openBracketLoc; + + // The expression that yields the element count, or NULL + RefPtr<ExpressionSyntaxNode> elementCountExpr; + }; + + // "Unwrapped" information about a declarator + struct DeclaratorInfo + { + RefPtr<ExpressionSyntaxNode> typeSpec; + Token nameToken; + RefPtr<Modifier> semantics; + RefPtr<ExpressionSyntaxNode> initializer; + }; + + // Add a member declaration to its container, and ensure that its + // parent link is set up correctly. + static void AddMember(RefPtr<ContainerDecl> container, RefPtr<Decl> member) + { + if (container) + { + member->ParentDecl = container.Ptr(); + container->Members.Add(member); + + container->memberDictionaryIsValid = false; + } + } + + static void AddMember(RefPtr<Scope> scope, RefPtr<Decl> member) + { + if (scope) + { + AddMember(scope->containerDecl, member); + } + } + + static void parseParameterList( + Parser* parser, + RefPtr<CallableDecl> decl) + { + parser->ReadToken(TokenType::LParent); + + // Allow a declaration to use the keyword `void` for a parameter list, + // since that was required in ancient C, and continues to be supported + // in a bunc hof its derivatives even if it is a Bad Design Choice + // + // TODO: conditionalize this so we don't keep this around for "pure" + // Slang code + if( parser->LookAheadToken("void") && parser->LookAheadToken(TokenType::RParent, 1) ) + { + parser->ReadToken("void"); + parser->ReadToken(TokenType::RParent); + return; + } + + while (!AdvanceIfMatch(parser, TokenType::RParent)) + { + AddMember(decl, parser->ParseParameter()); + if (AdvanceIf(parser, TokenType::RParent)) + break; + parser->ReadToken(TokenType::Comma); + } + } + + static void ParseFuncDeclHeader( + Parser* parser, + DeclaratorInfo const& declaratorInfo, + RefPtr<FunctionSyntaxNode> decl) + { + parser->PushScope(decl.Ptr()); + + parser->FillPosition(decl.Ptr()); + decl->Position = declaratorInfo.nameToken.Position; + + decl->Name = declaratorInfo.nameToken; + decl->ReturnType = TypeExp(declaratorInfo.typeSpec); + parseParameterList(parser, decl); + ParseOptSemantics(parser, decl.Ptr()); + } + + static RefPtr<Decl> ParseFuncDecl( + Parser* parser, + ContainerDecl* /*containerDecl*/, + DeclaratorInfo const& declaratorInfo) + { + RefPtr<FunctionSyntaxNode> decl = new FunctionSyntaxNode(); + ParseFuncDeclHeader(parser, declaratorInfo, decl); + + if (AdvanceIf(parser, TokenType::Semicolon)) + { + // empty body + } + else + { + decl->Body = parser->ParseBlockStatement(); + } + + parser->PopScope(); + return decl; + } + + static RefPtr<VarDeclBase> CreateVarDeclForContext( + ContainerDecl* containerDecl ) + { + if (dynamic_cast<StructSyntaxNode*>(containerDecl) || dynamic_cast<ClassSyntaxNode*>(containerDecl)) + { + return new StructField(); + } + else if (dynamic_cast<CallableDecl*>(containerDecl)) + { + return new ParameterSyntaxNode(); + } + else + { + return new Variable(); + } + } + + // Add modifiers to the end of the modifier list for a declaration + void AddModifiers(Decl* decl, RefPtr<Modifier> modifiers) + { + if (!modifiers) + return; + + RefPtr<Modifier>* link = &decl->modifiers.first; + while (*link) + { + link = &(*link)->next; + } + *link = modifiers; + } + + + static String GenerateName(Parser* parser, String const& base) + { + // TODO: somehow mangle the name to avoid clashes + return base; + } + + static String GenerateName(Parser* parser) + { + return GenerateName(parser, "_anonymous_" + String(parser->anonymousCounter++)); + } + + + // Set up a variable declaration based on what we saw in its declarator... + static void CompleteVarDecl( + Parser* parser, + RefPtr<VarDeclBase> decl, + DeclaratorInfo const& declaratorInfo) + { + parser->FillPosition(decl.Ptr()); + + if( declaratorInfo.nameToken.Type == TokenType::Unknown ) + { + // HACK(tfoley): we always give a name, even if the declarator didn't include one... :( + decl->Name.Content = GenerateName(parser); + } + else + { + decl->Position = declaratorInfo.nameToken.Position; + decl->Name = declaratorInfo.nameToken; + } + decl->Type = TypeExp(declaratorInfo.typeSpec); + + AddModifiers(decl.Ptr(), declaratorInfo.semantics); + + decl->Expr = declaratorInfo.initializer; + } + + static RefPtr<Declarator> ParseDeclarator(Parser* parser); + + static RefPtr<Declarator> ParseDirectAbstractDeclarator( + Parser* parser) + { + RefPtr<Declarator> declarator; + switch( parser->tokenReader.PeekTokenType() ) + { + case TokenType::Identifier: + { + auto nameDeclarator = new NameDeclarator(); + nameDeclarator->flavor = Declarator::Flavor::Name; + nameDeclarator->nameToken = ParseDeclName(parser); + declarator = nameDeclarator; + } + break; + + case TokenType::LParent: + { + // Note(tfoley): This is a point where disambiguation is required. + // We could be looking at an abstract declarator for a function-type + // parameter: + // + // void F( int(int) ); + // + // Or we could be looking at the use of parenthesese in an ordinary + // declarator: + // + // void (*f)(int); + // + // The difference really doesn't matter right now, but we err in + // the direction of assuming the second case. + parser->ReadToken(TokenType::LParent); + declarator = ParseDeclarator(parser); + parser->ReadToken(TokenType::RParent); + } + break; + + default: + // an empty declarator is allowed + return nullptr; + } + + // postifx additions + for( ;;) + { + switch( parser->tokenReader.PeekTokenType() ) + { + case TokenType::LBracket: + { + auto arrayDeclarator = new ArrayDeclarator(); + arrayDeclarator->openBracketLoc = parser->tokenReader.PeekLoc(); + arrayDeclarator->flavor = Declarator::Flavor::Array; + arrayDeclarator->inner = declarator; + + parser->ReadToken(TokenType::LBracket); + if( parser->tokenReader.PeekTokenType() != TokenType::RBracket ) + { + arrayDeclarator->elementCountExpr = parser->ParseExpression(); + } + parser->ReadToken(TokenType::RBracket); + + declarator = arrayDeclarator; + continue; + } + + case TokenType::LParent: + break; + + default: + break; + } + + break; + } + + return declarator; + } + + // Parse a declarator (or at least as much of one as we support) + static RefPtr<Declarator> ParseDeclarator( + Parser* parser) + { + if( parser->tokenReader.PeekTokenType() == TokenType::OpMul ) + { + auto ptrDeclarator = new PointerDeclarator(); + ptrDeclarator->starLoc = parser->tokenReader.PeekLoc(); + ptrDeclarator->flavor = Declarator::Flavor::Pointer; + + parser->ReadToken(TokenType::OpMul); + + // TODO(tfoley): allow qualifiers like `const` here? + + ptrDeclarator->inner = ParseDeclarator(parser); + return ptrDeclarator; + } + else + { + return ParseDirectAbstractDeclarator(parser); + } + } + + // A declarator plus optional semantics and initializer + struct InitDeclarator + { + RefPtr<Declarator> declarator; + RefPtr<Modifier> semantics; + RefPtr<ExpressionSyntaxNode> initializer; + }; + + // Parse a declarator plus optional semantics + static InitDeclarator ParseSemanticDeclarator( + Parser* parser) + { + InitDeclarator result; + result.declarator = ParseDeclarator(parser); + result.semantics = ParseOptSemantics(parser); + return result; + } + + // Parse a declarator plus optional semantics and initializer + static InitDeclarator ParseInitDeclarator( + Parser* parser) + { + InitDeclarator result = ParseSemanticDeclarator(parser); + if (AdvanceIf(parser, TokenType::OpAssign)) + { + result.initializer = parser->ParseInitExpr(); + } + return result; + } + + static void UnwrapDeclarator( + RefPtr<Declarator> declarator, + DeclaratorInfo* ioInfo) + { + while( declarator ) + { + switch(declarator->flavor) + { + case Declarator::Flavor::Name: + { + auto nameDeclarator = (NameDeclarator*) declarator.Ptr(); + ioInfo->nameToken = nameDeclarator->nameToken; + return; + } + break; + + case Declarator::Flavor::Pointer: + { + auto ptrDeclarator = (PointerDeclarator*) declarator.Ptr(); + + // TODO(tfoley): we don't support pointers for now + // ioInfo->typeSpec = new PointerTypeExpr(ioInfo->typeSpec); + + declarator = ptrDeclarator->inner; + } + break; + + case Declarator::Flavor::Array: + { + // TODO(tfoley): we don't support pointers for now + auto arrayDeclarator = (ArrayDeclarator*) declarator.Ptr(); + + auto arrayTypeExpr = new IndexExpressionSyntaxNode(); + arrayTypeExpr->Position = arrayDeclarator->openBracketLoc; + arrayTypeExpr->BaseExpression = ioInfo->typeSpec; + arrayTypeExpr->IndexExpression = arrayDeclarator->elementCountExpr; + ioInfo->typeSpec = arrayTypeExpr; + + declarator = arrayDeclarator->inner; + } + break; + + default: + SLANG_UNREACHABLE("all cases handled"); + break; + } + } + } + + static void UnwrapDeclarator( + InitDeclarator const& initDeclarator, + DeclaratorInfo* ioInfo) + { + UnwrapDeclarator(initDeclarator.declarator, ioInfo); + ioInfo->semantics = initDeclarator.semantics; + ioInfo->initializer = initDeclarator.initializer; + } + + // Either a single declaration, or a group of them + struct DeclGroupBuilder + { + CodePosition startPosition; + RefPtr<Decl> decl; + RefPtr<DeclGroup> group; + + // Add a new declaration to the potential group + void addDecl( + RefPtr<Decl> newDecl) + { + assert(newDecl); + + if( decl ) + { + group = new DeclGroup(); + group->Position = startPosition; + group->decls.Add(decl); + decl = nullptr; + } + + if( group ) + { + group->decls.Add(newDecl); + } + else + { + decl = newDecl; + } + } + + RefPtr<DeclBase> getResult() + { + if(group) return group; + return decl; + } + }; + + // Pares an argument to an application of a generic + RefPtr<ExpressionSyntaxNode> ParseGenericArg(Parser* parser) + { + return parser->ParseArgExpr(); + } + + // Create a type expression that will refer to the given declaration + static RefPtr<ExpressionSyntaxNode> + createDeclRefType(Parser* parser, RefPtr<Decl> decl) + { + // For now we just construct an expression that + // will look up the given declaration by name. + // + // TODO: do this better, e.g. by filling in the `declRef` field directly + + auto expr = new VarExpressionSyntaxNode(); + expr->scope = parser->currentScope.Ptr(); + expr->Position = decl->getNameToken().Position; + expr->Variable = decl->getName(); + return expr; + } + + // Representation for a parsed type specifier, which might + // include a declaration (e.g., of a `struct` type) + struct TypeSpec + { + // If the type-spec declared something, then put it here + RefPtr<Decl> decl; + + // Put the resulting expression (which should evaluate to a type) here + RefPtr<ExpressionSyntaxNode> expr; + }; + + static TypeSpec + parseTypeSpec(Parser* parser) + { + TypeSpec typeSpec; + + // We may see a `struct` type specified here, and need to act accordingly + // + // TODO(tfoley): Handle the case where the user is just using `struct` + // as a way to name an existing struct "tag" (e.g., `struct Foo foo;`) + // + if( parser->LookAheadToken("struct") ) + { + auto decl = parser->ParseStruct(); + typeSpec.decl = decl; + typeSpec.expr = createDeclRefType(parser, decl); + return typeSpec; + } + else if( parser->LookAheadToken("class") ) + { + auto decl = parser->ParseClass(); + typeSpec.decl = decl; + typeSpec.expr = createDeclRefType(parser, decl); + return typeSpec; + } + + Token typeName = parser->ReadToken(TokenType::Identifier); + + auto basicType = new VarExpressionSyntaxNode(); + basicType->scope = parser->currentScope.Ptr(); + basicType->Position = typeName.Position; + basicType->Variable = typeName.Content; + + RefPtr<ExpressionSyntaxNode> typeExpr = basicType; + + if (parser->LookAheadToken(TokenType::OpLess)) + { + RefPtr<GenericAppExpr> gtype = new GenericAppExpr(); + parser->FillPosition(gtype.Ptr()); // set up scope for lookup + gtype->Position = typeName.Position; + gtype->FunctionExpr = typeExpr; + parser->ReadToken(TokenType::OpLess); + parser->genericDepth++; + // For now assume all generics have at least one argument + gtype->Arguments.Add(ParseGenericArg(parser)); + while (AdvanceIf(parser, TokenType::Comma)) + { + gtype->Arguments.Add(ParseGenericArg(parser)); + } + parser->genericDepth--; + parser->ReadToken(TokenType::OpGreater); + typeExpr = gtype; + } + + typeSpec.expr = typeExpr; + return typeSpec; + } + + + static RefPtr<DeclBase> ParseDeclaratorDecl( + Parser* parser, + ContainerDecl* containerDecl) + { + CodePosition startPosition = parser->tokenReader.PeekLoc(); + + auto typeSpec = parseTypeSpec(parser); + + // We may need to build up multiple declarations in a group, + // but the common case will be when we have just a single + // declaration + DeclGroupBuilder declGroupBuilder; + declGroupBuilder.startPosition = startPosition; + + // The type specifier may include a declaration. E.g., + // it might declare a `struct` type. + if(typeSpec.decl) + declGroupBuilder.addDecl(typeSpec.decl); + + if( AdvanceIf(parser, TokenType::Semicolon) ) + { + // No actual variable is being declared here, but + // that might not be an error. + + auto result = declGroupBuilder.getResult(); + if( !result ) + { + parser->sink->diagnose(startPosition, Diagnostics::declarationDidntDeclareAnything); + } + return result; + } + + + InitDeclarator initDeclarator = ParseInitDeclarator(parser); + + DeclaratorInfo declaratorInfo; + declaratorInfo.typeSpec = typeSpec.expr; + + + // Rather than parse function declarators properly for now, + // we'll just do a quick disambiguation here. This won't + // matter unless we actually decide to support function-type parameters, + // using C syntax. + // + if( parser->tokenReader.PeekTokenType() == TokenType::LParent + + // Only parse as a function if we didn't already see mutually-exclusive + // constructs when parsing the declarator. + && !initDeclarator.initializer + && !initDeclarator.semantics) + { + // Looks like a function, so parse it like one. + UnwrapDeclarator(initDeclarator, &declaratorInfo); + return ParseFuncDecl(parser, containerDecl, declaratorInfo); + } + + // Otherwise we are looking at a variable declaration, which could be one in a sequence... + + if( AdvanceIf(parser, TokenType::Semicolon) ) + { + // easy case: we only had a single declaration! + UnwrapDeclarator(initDeclarator, &declaratorInfo); + RefPtr<VarDeclBase> firstDecl = CreateVarDeclForContext(containerDecl); + CompleteVarDecl(parser, firstDecl, declaratorInfo); + + declGroupBuilder.addDecl(firstDecl); + return declGroupBuilder.getResult(); + + return firstDecl; + } + + // Otherwise we have multiple declarations in a sequence, and these + // declarations need to somehow share both the type spec and modifiers. + // + // If there are any errors in the type specifier, we only want to hear + // about it once, so we need to share structure rather than just + // clone syntax. + + auto sharedTypeSpec = new SharedTypeExpr(); + sharedTypeSpec->Position = typeSpec.expr->Position; + sharedTypeSpec->base = TypeExp(typeSpec.expr); + + for(;;) + { + declaratorInfo.typeSpec = sharedTypeSpec; + UnwrapDeclarator(initDeclarator, &declaratorInfo); + + RefPtr<VarDeclBase> varDecl = CreateVarDeclForContext(containerDecl); + CompleteVarDecl(parser, varDecl, declaratorInfo); + + declGroupBuilder.addDecl(varDecl); + + // end of the sequence? + if(AdvanceIf(parser, TokenType::Semicolon)) + return declGroupBuilder.getResult(); + + // ad-hoc recovery, to avoid infinite loops + if( parser->isRecovering ) + { + parser->ReadToken(TokenType::Semicolon); + return declGroupBuilder.getResult(); + } + + // Let's default to assuming that a missing `,` + // indicates the end of a declaration, + // where a `;` would be expected, and not + // a continuation of this declaration, where + // a `,` would be expected (this is tailoring + // the diagnostic message a bit). + // + // TODO: a more advanced heuristic here might + // look at whether the next token is on the + // same line, to predict whether `,` or `;` + // would be more likely... + + if (!AdvanceIf(parser, TokenType::Comma)) + { + parser->ReadToken(TokenType::Semicolon); + return declGroupBuilder.getResult(); + } + + // expect another variable declaration... + initDeclarator = ParseInitDeclarator(parser); + } + } + + // + // layout-semantic ::= (register | packoffset) '(' register-name component-mask? ')' + // register-name ::= identifier + // component-mask ::= '.' identifier + // + static void ParseHLSLLayoutSemantic( + Parser* parser, + HLSLLayoutSemantic* semantic) + { + semantic->name = parser->ReadToken(TokenType::Identifier); + + parser->ReadToken(TokenType::LParent); + semantic->registerName = parser->ReadToken(TokenType::Identifier); + if (AdvanceIf(parser, TokenType::Dot)) + { + semantic->componentMask = parser->ReadToken(TokenType::Identifier); + } + parser->ReadToken(TokenType::RParent); + } + + // + // semantic ::= identifier ( '(' args ')' )? + // + static RefPtr<Modifier> ParseSemantic( + Parser* parser) + { + if (parser->LookAheadToken("register")) + { + RefPtr<HLSLRegisterSemantic> semantic = new HLSLRegisterSemantic(); + ParseHLSLLayoutSemantic(parser, semantic.Ptr()); + return semantic; + } + else if (parser->LookAheadToken("packoffset")) + { + RefPtr<HLSLPackOffsetSemantic> semantic = new HLSLPackOffsetSemantic(); + ParseHLSLLayoutSemantic(parser, semantic.Ptr()); + return semantic; + } + else + { + RefPtr<HLSLSimpleSemantic> semantic = new HLSLSimpleSemantic(); + semantic->name = parser->ReadToken(TokenType::Identifier); + return semantic; + } + } + + // + // opt-semantics ::= (':' semantic)* + // + static RefPtr<Modifier> ParseOptSemantics( + Parser* parser) + { + if (!AdvanceIf(parser, TokenType::Colon)) + return nullptr; + + RefPtr<Modifier> result; + RefPtr<Modifier>* link = &result; + assert(!*link); + + for (;;) + { + RefPtr<Modifier> semantic = ParseSemantic(parser); + if (semantic) + { + *link = semantic; + link = &semantic->next; + } + + switch (parser->tokenReader.PeekTokenType()) + { + case TokenType::LBrace: + case TokenType::Semicolon: + case TokenType::Comma: + case TokenType::RParent: + case TokenType::EndOfFile: + return result; + + default: + break; + } + + parser->ReadToken(TokenType::Colon); + } + + } + + + static void ParseOptSemantics( + Parser* parser, + Decl* decl) + { + AddModifiers(decl, ParseOptSemantics(parser)); + } + + static RefPtr<Decl> ParseHLSLBufferDecl( + Parser* parser) + { + // An HLSL declaration of a constant buffer like this: + // + // cbuffer Foo : register(b0) { int a; float b; }; + // + // is treated as syntax sugar for a type declaration + // and then a global variable declaration using that type: + // + // struct $anonymous { int a; float b; }; + // ConstantBuffer<$anonymous> Foo; + // + // where `$anonymous` is a fresh name, and the variable + // declaration is made to be "transparent" so that lookup + // will see through it to the members inside. + + // We first look at the declaration keywrod to determine + // the type of buffer to declare: + String bufferWrapperTypeName; + CodePosition bufferWrapperTypeNamePos = parser->tokenReader.PeekLoc(); + if (AdvanceIf(parser, "cbuffer")) + { + bufferWrapperTypeName = "ConstantBuffer"; + } + else if (AdvanceIf(parser, "tbuffer")) + { + bufferWrapperTypeName = "TextureBuffer"; + } + else + { + Unexpected(parser); + } + + // We are going to represent each buffer as a pair of declarations. + // The first is a type declaration that holds all the members, while + // the second is a variable declaration that uses the buffer type. + RefPtr<StructSyntaxNode> bufferDataTypeDecl = new StructSyntaxNode(); + RefPtr<Variable> bufferVarDecl = new Variable(); + + // Both declarations will have a location that points to the name + parser->FillPosition(bufferDataTypeDecl.Ptr()); + parser->FillPosition(bufferVarDecl.Ptr()); + + auto reflectionNameToken = parser->ReadToken(TokenType::Identifier); + + // Attach the reflection name to the block so we can use it + auto reflectionNameModifier = new ParameterBlockReflectionName(); + reflectionNameModifier->nameToken = reflectionNameToken; + addModifier(bufferVarDecl, reflectionNameModifier); + + // Both the buffer variable and its type need to have names generated + bufferVarDecl->Name.Content = GenerateName(parser, "SLANG_constantBuffer_" + reflectionNameToken.Content); + bufferDataTypeDecl->Name.Content = GenerateName(parser, "SLANG_ConstantBuffer_" + reflectionNameToken.Content); + + addModifier(bufferDataTypeDecl, new ImplicitParameterBlockElementTypeModifier()); + addModifier(bufferVarDecl, new ImplicitParameterBlockVariableModifier()); + + // TODO(tfoley): We end up constructing unchecked syntax here that + // is expected to type check into the right form, but it might be + // cleaner to have a more explicit desugaring pass where we parse + // these constructs directly into the AST and *then* desugar them. + + // Construct a type expression to reference the buffer data type + auto bufferDataTypeExpr = new VarExpressionSyntaxNode(); + bufferDataTypeExpr->Position = bufferDataTypeDecl->Position; + bufferDataTypeExpr->Variable = bufferDataTypeDecl->Name.Content; + bufferDataTypeExpr->scope = parser->currentScope.Ptr(); + + // Construct a type exrpession to reference the type constructor + auto bufferWrapperTypeExpr = new VarExpressionSyntaxNode(); + bufferWrapperTypeExpr->Position = bufferWrapperTypeNamePos; + bufferWrapperTypeExpr->Variable = bufferWrapperTypeName; + + // Always need to look this up in the outer scope, + // so that it won't collide with, e.g., a local variable called `ConstantBuffer` + bufferWrapperTypeExpr->scope = parser->outerScope; + + // Construct a type expression that represents the type for the variable, + // which is the wrapper type applied to the data type + auto bufferVarTypeExpr = new GenericAppExpr(); + bufferVarTypeExpr->Position = bufferVarDecl->Position; + bufferVarTypeExpr->FunctionExpr = bufferWrapperTypeExpr; + bufferVarTypeExpr->Arguments.Add(bufferDataTypeExpr); + + bufferVarDecl->Type.exp = bufferVarTypeExpr; + + // Any semantics applied to the bufer declaration are taken as applying + // to the variable instead. + ParseOptSemantics(parser, bufferVarDecl.Ptr()); + + // The declarations in the body belong to the data type. + parser->ReadToken(TokenType::LBrace); + ParseDeclBody(parser, bufferDataTypeDecl.Ptr(), TokenType::RBrace); + + // All HLSL buffer declarations are "transparent" in that their + // members are implicitly made visible in the parent scope. + // We achieve this by applying the transparent modifier to the variable. + auto transparentModifier = new TransparentModifier(); + transparentModifier->next = bufferVarDecl->modifiers.first; + bufferVarDecl->modifiers.first = transparentModifier; + + // Because we are constructing two declarations, we have a thorny + // issue that were are only supposed to return one. + // For now we handle this by adding the type declaration to + // the current scope manually, and then returning the variable + // declaration. + // + // Note: this means that any modifiers that have already been parsed + // will get attached to the variable declaration, not the type. + // There might be cases where we need to shuffle things around. + + AddMember(parser->currentScope, bufferDataTypeDecl); + + return bufferVarDecl; + } + + static void removeModifier( + Modifiers& modifiers, + RefPtr<Modifier> modifier) + { + RefPtr<Modifier>* link = &modifiers.first; + while (*link) + { + if (*link == modifier) + { + *link = (*link)->next; + return; + } + + link = &(*link)->next; + } + } + + static RefPtr<Decl> parseGLSLBlockDecl( + Parser* parser, + Modifiers& modifiers) + { + // An GLSL block like this: + // + // uniform Foo { int a; float b; } foo; + // + // is treated as syntax sugar for a type declaration + // and then a global variable declaration using that type: + // + // struct $anonymous { int a; float b; }; + // Block<$anonymous> foo; + // + // where `$anonymous` is a fresh name. + // + // If a "local name" like `foo` is not given, then + // we make the declaration "transparent" so that lookup + // will see through it to the members inside. + + + CodePosition pos = parser->tokenReader.PeekLoc(); + + // The initial name before the `{` is only supposed + // to be made visible to reflection + auto reflectionNameToken = parser->ReadToken(TokenType::Identifier); + + // Look at the qualifiers present on the block to decide what kind + // of block we are looking at. Also *remove* those qualifiers so + // that they don't interfere with downstream work. + String blockWrapperTypeName; + if( auto uniformMod = modifiers.findModifier<HLSLUniformModifier>() ) + { + removeModifier(modifiers, uniformMod); + blockWrapperTypeName = "ConstantBuffer"; + } + else if( auto inMod = modifiers.findModifier<InModifier>() ) + { + removeModifier(modifiers, inMod); + blockWrapperTypeName = "__GLSLInputParameterBlock"; + } + else if( auto outMod = modifiers.findModifier<OutModifier>() ) + { + removeModifier(modifiers, outMod); + blockWrapperTypeName = "__GLSLOutputParameterBlock"; + } + else if( auto bufferMod = modifiers.findModifier<GLSLBufferModifier>() ) + { + removeModifier(modifiers, bufferMod); + blockWrapperTypeName = "__GLSLShaderStorageBuffer"; + } + else + { + // Unknown case: just map to a constant buffer and hope for the best + blockWrapperTypeName = "ConstantBuffer"; + } + + // We are going to represent each buffer as a pair of declarations. + // The first is a type declaration that holds all the members, while + // the second is a variable declaration that uses the buffer type. + RefPtr<StructSyntaxNode> blockDataTypeDecl = new StructSyntaxNode(); + RefPtr<Variable> blockVarDecl = new Variable(); + + addModifier(blockDataTypeDecl, new ImplicitParameterBlockElementTypeModifier()); + addModifier(blockVarDecl, new ImplicitParameterBlockVariableModifier()); + + // Attach the reflection name to the block so we can use it + auto reflectionNameModifier = new ParameterBlockReflectionName(); + reflectionNameModifier->nameToken = reflectionNameToken; + addModifier(blockVarDecl, reflectionNameModifier); + + // Both declarations will have a location that points to the name + parser->FillPosition(blockDataTypeDecl.Ptr()); + parser->FillPosition(blockVarDecl.Ptr()); + + // Generate a unique name for the data type + blockDataTypeDecl->Name.Content = GenerateName(parser, "SLANG_ParameterBlock_" + reflectionNameToken.Content); + + // TODO(tfoley): We end up constructing unchecked syntax here that + // is expected to type check into the right form, but it might be + // cleaner to have a more explicit desugaring pass where we parse + // these constructs directly into the AST and *then* desugar them. + + // Construct a type expression to reference the buffer data type + auto blockDataTypeExpr = new VarExpressionSyntaxNode(); + blockDataTypeExpr->Position = blockDataTypeDecl->Position; + blockDataTypeExpr->Variable = blockDataTypeDecl->Name.Content; + blockDataTypeExpr->scope = parser->currentScope.Ptr(); + + // Construct a type exrpession to reference the type constructor + auto blockWrapperTypeExpr = new VarExpressionSyntaxNode(); + blockWrapperTypeExpr->Position = pos; + blockWrapperTypeExpr->Variable = blockWrapperTypeName; + // Always need to look this up in the outer scope, + // so that it won't collide with, e.g., a local variable called `ConstantBuffer` + blockWrapperTypeExpr->scope = parser->outerScope; + + // Construct a type expression that represents the type for the variable, + // which is the wrapper type applied to the data type + auto blockVarTypeExpr = new GenericAppExpr(); + blockVarTypeExpr->Position = blockVarDecl->Position; + blockVarTypeExpr->FunctionExpr = blockWrapperTypeExpr; + blockVarTypeExpr->Arguments.Add(blockDataTypeExpr); + + blockVarDecl->Type.exp = blockVarTypeExpr; + + // The declarations in the body belong to the data type. + parser->ReadToken(TokenType::LBrace); + ParseDeclBody(parser, blockDataTypeDecl.Ptr(), TokenType::RBrace); + + if( parser->LookAheadToken(TokenType::Identifier) ) + { + // The user gave an explicit name to the block, + // so we need to use that as our variable name + blockVarDecl->Name = parser->ReadToken(TokenType::Identifier); + + // TODO: in this case we make actually have a more complex + // declarator, including `[]` brackets. + } + else + { + // synthesize a dummy name + blockVarDecl->Name.Content = GenerateName(parser, "SLANG_parameterBlock_" + reflectionNameToken.Content); + + // Otherwise we have a transparent declaration, similar + // to an HLSL `cbuffer` + auto transparentModifier = new TransparentModifier(); + transparentModifier->Position = pos; + addModifier(blockVarDecl, transparentModifier); + } + + // Expect a trailing `;` + parser->ReadToken(TokenType::Semicolon); + + // Because we are constructing two declarations, we have a thorny + // issue that were are only supposed to return one. + // For now we handle this by adding the type declaration to + // the current scope manually, and then returning the variable + // declaration. + // + // Note: this means that any modifiers that have already been parsed + // will get attached to the variable declaration, not the type. + // There might be cases where we need to shuffle things around. + + AddMember(parser->currentScope, blockDataTypeDecl); + + return blockVarDecl; + } + + + + + static RefPtr<Decl> ParseGenericParamDecl( + Parser* parser, + RefPtr<GenericDecl> genericDecl) + { + // simple syntax to introduce a value parameter + if (AdvanceIf(parser, "let")) + { + // default case is a type parameter + auto paramDecl = new GenericValueParamDecl(); + paramDecl->Name = parser->ReadToken(TokenType::Identifier); + if (AdvanceIf(parser, TokenType::Colon)) + { + paramDecl->Type = parser->ParseTypeExp(); + } + if (AdvanceIf(parser, TokenType::OpAssign)) + { + paramDecl->Expr = parser->ParseInitExpr(); + } + return paramDecl; + } + else + { + // default case is a type parameter + auto paramDecl = new GenericTypeParamDecl(); + parser->FillPosition(paramDecl); + paramDecl->Name = parser->ReadToken(TokenType::Identifier); + if (AdvanceIf(parser, TokenType::Colon)) + { + // The user is apply a constraint to this type parameter... + + auto paramConstraint = new GenericTypeConstraintDecl(); + parser->FillPosition(paramConstraint); + + auto paramType = DeclRefType::Create(DeclRef(paramDecl, nullptr)); + + auto paramTypeExpr = new SharedTypeExpr(); + paramTypeExpr->Position = paramDecl->Position; + paramTypeExpr->base.type = paramType; + paramTypeExpr->Type = new TypeType(paramType); + + paramConstraint->sub = TypeExp(paramTypeExpr); + paramConstraint->sup = parser->ParseTypeExp(); + + AddMember(genericDecl, paramConstraint); + + + } + if (AdvanceIf(parser, TokenType::OpAssign)) + { + paramDecl->initType = parser->ParseTypeExp(); + } + return paramDecl; + } + } + + static RefPtr<Decl> ParseGenericDecl( + Parser* parser) + { + RefPtr<GenericDecl> decl = new GenericDecl(); + parser->FillPosition(decl.Ptr()); + parser->PushScope(decl.Ptr()); + parser->ReadToken("__generic"); + parser->ReadToken(TokenType::OpLess); + parser->genericDepth++; + while (!parser->LookAheadToken(TokenType::OpGreater)) + { + AddMember(decl, ParseGenericParamDecl(parser, decl)); + +if( parser->LookAheadToken(TokenType::OpGreater) ) +break; + +parser->ReadToken(TokenType::Comma); + } + parser->genericDepth--; + parser->ReadToken(TokenType::OpGreater); + + decl->inner = ParseSingleDecl(parser, decl.Ptr()); + + // A generic decl hijacks the name of the declaration + // it wraps, so that lookup can find it. + decl->Name = decl->inner->Name; + + parser->PopScope(); + return decl; + } + + static RefPtr<Decl> ParseTraitConformanceDecl( + Parser* parser) + { + RefPtr<TraitConformanceDecl> decl = new TraitConformanceDecl(); + parser->FillPosition(decl.Ptr()); + parser->ReadToken("__conforms"); + + decl->base = parser->ParseTypeExp(); + + return decl; + } + + + static RefPtr<ExtensionDecl> ParseExtensionDecl(Parser* parser) + { + RefPtr<ExtensionDecl> decl = new ExtensionDecl(); + parser->FillPosition(decl.Ptr()); + parser->ReadToken("__extension"); + decl->targetType = parser->ParseTypeExp(); + parser->ReadToken(TokenType::LBrace); + ParseDeclBody(parser, decl.Ptr(), TokenType::RBrace); + return decl; + } + + static RefPtr<TraitDecl> ParseTraitDecl(Parser* parser) + { + RefPtr<TraitDecl> decl = new TraitDecl(); + parser->FillPosition(decl.Ptr()); + parser->ReadToken("__trait"); + decl->Name = parser->ReadToken(TokenType::Identifier); + + if( AdvanceIf(parser, TokenType::Colon) ) + { + do + { + auto base = parser->ParseTypeExp(); + decl->bases.Add(base); + } while( AdvanceIf(parser, TokenType::Comma) ); + } + + parser->ReadToken(TokenType::LBrace); + ParseDeclBody(parser, decl.Ptr(), TokenType::RBrace); + return decl; + } + + static RefPtr<ConstructorDecl> ParseConstructorDecl(Parser* parser) + { + RefPtr<ConstructorDecl> decl = new ConstructorDecl(); + parser->FillPosition(decl.Ptr()); + parser->ReadToken("__init"); + + parseParameterList(parser, decl); + + if( AdvanceIf(parser, TokenType::Semicolon) ) + { + // empty body + } + else + { + decl->Body = parser->ParseBlockStatement(); + } + return decl; + } + + static RefPtr<AccessorDecl> parseAccessorDecl(Parser* parser) + { + RefPtr<AccessorDecl> decl; + if( AdvanceIf(parser, "get") ) + { + decl = new GetterDecl(); + } + else if( AdvanceIf(parser, "set") ) + { + decl = new SetterDecl(); + } + else + { + Unexpected(parser); + return nullptr; + } + + if( parser->tokenReader.PeekTokenType() == TokenType::LBrace ) + { + decl->Body = parser->ParseBlockStatement(); + } + else + { + parser->ReadToken(TokenType::Semicolon); + } + + return decl; + } + + static RefPtr<SubscriptDecl> ParseSubscriptDecl(Parser* parser) + { + RefPtr<SubscriptDecl> decl = new SubscriptDecl(); + parser->FillPosition(decl.Ptr()); + parser->ReadToken("__subscript"); + + // TODO: the use of this name here is a bit magical... + decl->Name.Content = "operator[]"; + + parseParameterList(parser, decl); + + if( AdvanceIf(parser, TokenType::RightArrow) ) + { + decl->ReturnType = parser->ParseTypeExp(); + } + + if( AdvanceIf(parser, TokenType::LBrace) ) + { + // We want to parse nested "accessor" declarations + while( !AdvanceIfMatch(parser, TokenType::RBrace) ) + { + auto accessor = parseAccessorDecl(parser); + AddMember(decl, accessor); + } + } + else + { + parser->ReadToken(TokenType::Semicolon); + + // empty body should be treated like `{ get; }` + } + + return decl; + } + + // Parse a declaration of a new modifier keyword + static RefPtr<ModifierDecl> parseModifierDecl(Parser* parser) + { + RefPtr<ModifierDecl> decl = new ModifierDecl(); + + // read the `__modifier` keyword + parser->ReadToken(TokenType::Identifier); + + parser->ReadToken(TokenType::LParent); + decl->classNameToken = parser->ReadToken(TokenType::Identifier); + parser->ReadToken(TokenType::RParent); + + parser->FillPosition(decl.Ptr()); + decl->Name = parser->ReadToken(TokenType::Identifier); + + parser->ReadToken(TokenType::Semicolon); + return decl; + } + + // Finish up work on a declaration that was parsed + static void CompleteDecl( + Parser* /*parser*/, + RefPtr<Decl> decl, + ContainerDecl* containerDecl, + Modifiers modifiers) + { + // Add any modifiers we parsed before the declaration to the list + // of modifiers on the declaration itself. + AddModifiers(decl.Ptr(), modifiers.first); + + // Make sure the decl is properly nested inside its lexical parent + if (containerDecl) + { + AddMember(containerDecl, decl); + } + } + + static RefPtr<DeclBase> ParseDeclWithModifiers( + Parser* parser, + ContainerDecl* containerDecl, + Modifiers modifiers ) + { + RefPtr<DeclBase> decl; + + auto loc = parser->tokenReader.PeekLoc(); + + // TODO: actual dispatch! + if (parser->LookAheadToken("struct")) + decl = ParseDeclaratorDecl(parser, containerDecl); + else if (parser->LookAheadToken("class")) + decl = ParseDeclaratorDecl(parser, containerDecl); + else if (parser->LookAheadToken("typedef")) + decl = ParseTypeDef(parser); + else if (parser->LookAheadToken("using")) + decl = ParseUsing(parser); + else if (parser->LookAheadToken("cbuffer") || parser->LookAheadToken("tbuffer")) + decl = ParseHLSLBufferDecl(parser); + else if (parser->LookAheadToken("__generic")) + decl = ParseGenericDecl(parser); + else if (parser->LookAheadToken("__conforms")) + decl = ParseTraitConformanceDecl(parser); + else if (parser->LookAheadToken("__extension")) + decl = ParseExtensionDecl(parser); + else if (parser->LookAheadToken("__init")) + decl = ParseConstructorDecl(parser); + else if (parser->LookAheadToken("__subscript")) + decl = ParseSubscriptDecl(parser); + else if (parser->LookAheadToken("__trait")) + decl = ParseTraitDecl(parser); + else if(parser->LookAheadToken("__modifier")) + decl = parseModifierDecl(parser); + else if (AdvanceIf(parser, TokenType::Semicolon)) + { + decl = new EmptyDecl(); + decl->Position = loc; + } + // GLSL requires that we be able to parse "block" declarations, + // which look superficially similar to declarator declarations + else if( parser->LookAheadToken(TokenType::Identifier) + && parser->LookAheadToken(TokenType::LBrace, 1) ) + { + decl = parseGLSLBlockDecl(parser, modifiers); + } + else + { + // Default case: just parse a declarator-based declaration + decl = ParseDeclaratorDecl(parser, containerDecl); + } + + if (decl) + { + if( auto dd = decl.As<Decl>() ) + { + CompleteDecl(parser, dd, containerDecl, modifiers); + } + else if(auto declGroup = decl.As<DeclGroup>()) + { + // We are going to add the same modifiers to *all* of these declarations, + // so we want to give later passes a way to detect which modifiers + // were shared, vs. which ones are specific to a single declaration. + + auto sharedModifiers = new SharedModifiers(); + sharedModifiers->next = modifiers.first; + modifiers.first = sharedModifiers; + + for( auto subDecl : declGroup->decls ) + { + CompleteDecl(parser, subDecl, containerDecl, modifiers); + } + } + } + return decl; + } + + static RefPtr<DeclBase> ParseDecl( + Parser* parser, + ContainerDecl* containerDecl) + { + Modifiers modifiers = ParseModifiers(parser); + return ParseDeclWithModifiers(parser, containerDecl, modifiers); + } + + static RefPtr<Decl> ParseSingleDecl( + Parser* parser, + ContainerDecl* containerDecl) + { + auto declBase = ParseDecl(parser, containerDecl); + if(!declBase) + return nullptr; + if( auto decl = declBase.As<Decl>() ) + { + return decl; + } + else if( auto declGroup = declBase.As<DeclGroup>() ) + { + if( declGroup->decls.Count() == 1 ) + { + return declGroup->decls[0]; + } + } + + parser->sink->diagnose(declBase->Position, Diagnostics::unimplemented, "didn't expect multiple declarations here"); + return nullptr; + } + + + // Parse a body consisting of declarations + static void ParseDeclBody( + Parser* parser, + ContainerDecl* containerDecl, + TokenType closingToken) + { + while(!AdvanceIfMatch(parser, closingToken)) + { + ParseDecl(parser, containerDecl); + TryRecover(parser); + } + } + + void Parser::parseSourceFile(ProgramSyntaxNode* program) + { + if (outerScope) + { + currentScope = outerScope; + } + + PushScope(program); + program->Position = CodePosition(0, 0, 0, fileName); + ParseDeclBody(this, program, TokenType::EndOfFile); + PopScope(); + + assert(currentScope == outerScope); + currentScope = nullptr; + } + + RefPtr<ProgramSyntaxNode> Parser::ParseProgram() + { + RefPtr<ProgramSyntaxNode> program = new ProgramSyntaxNode(); + + parseSourceFile(program.Ptr()); + + return program; + } + + RefPtr<StructSyntaxNode> Parser::ParseStruct() + { + RefPtr<StructSyntaxNode> rs = new StructSyntaxNode(); + FillPosition(rs.Ptr()); + ReadToken("struct"); + rs->Name = ReadToken(TokenType::Identifier); + ReadToken(TokenType::LBrace); + ParseDeclBody(this, rs.Ptr(), TokenType::RBrace); + + return rs; + } + + RefPtr<ClassSyntaxNode> Parser::ParseClass() + { + RefPtr<ClassSyntaxNode> rs = new ClassSyntaxNode(); + FillPosition(rs.Ptr()); + ReadToken("class"); + rs->Name = ReadToken(TokenType::Identifier); + ReadToken(TokenType::LBrace); + ParseDeclBody(this, rs.Ptr(), TokenType::RBrace); + return rs; + } + + static RefPtr<StatementSyntaxNode> ParseSwitchStmt(Parser* parser) + { + RefPtr<SwitchStmt> stmt = new SwitchStmt(); + parser->FillPosition(stmt.Ptr()); + parser->ReadToken("switch"); + parser->ReadToken(TokenType::LParent); + stmt->condition = parser->ParseExpression(); + parser->ReadToken(TokenType::RParent); + stmt->body = parser->ParseBlockStatement(); + return stmt; + } + + static RefPtr<StatementSyntaxNode> ParseCaseStmt(Parser* parser) + { + RefPtr<CaseStmt> stmt = new CaseStmt(); + parser->FillPosition(stmt.Ptr()); + parser->ReadToken("case"); + stmt->expr = parser->ParseExpression(); + parser->ReadToken(TokenType::Colon); + return stmt; + } + + static RefPtr<StatementSyntaxNode> ParseDefaultStmt(Parser* parser) + { + RefPtr<DefaultStmt> stmt = new DefaultStmt(); + parser->FillPosition(stmt.Ptr()); + parser->ReadToken("default"); + parser->ReadToken(TokenType::Colon); + return stmt; + } + + static bool peekTypeName(Parser* parser) + { + if(!parser->LookAheadToken(TokenType::Identifier)) + return false; + + auto name = parser->tokenReader.PeekToken().Content; + + auto lookupResult = LookUp(name, parser->currentScope); + if(!lookupResult.isValid() || lookupResult.isOverloaded()) + return false; + + auto decl = lookupResult.item.declRef.GetDecl(); + if( auto typeDecl = dynamic_cast<AggTypeDecl*>(decl) ) + { + return true; + } + else if( auto typeVarDecl = dynamic_cast<SimpleTypeDecl*>(decl) ) + { + return true; + } + else + { + return false; + } + } + + RefPtr<StatementSyntaxNode> Parser::ParseStatement() + { + auto modifiers = ParseModifiers(this); + + RefPtr<StatementSyntaxNode> statement; + if (LookAheadToken(TokenType::LBrace)) + statement = ParseBlockStatement(); + else if (peekTypeName(this)) + statement = ParseVarDeclrStatement(modifiers); + else if (LookAheadToken("if")) + statement = ParseIfStatement(); + else if (LookAheadToken("for")) + statement = ParseForStatement(); + else if (LookAheadToken("while")) + statement = ParseWhileStatement(); + else if (LookAheadToken("do")) + statement = ParseDoWhileStatement(); + else if (LookAheadToken("break")) + statement = ParseBreakStatement(); + else if (LookAheadToken("continue")) + statement = ParseContinueStatement(); + else if (LookAheadToken("return")) + statement = ParseReturnStatement(); + else if (LookAheadToken("discard")) + { + statement = new DiscardStatementSyntaxNode(); + FillPosition(statement.Ptr()); + ReadToken("discard"); + ReadToken(TokenType::Semicolon); + } + else if (LookAheadToken("switch")) + statement = ParseSwitchStmt(this); + else if (LookAheadToken("case")) + statement = ParseCaseStmt(this); + else if (LookAheadToken("default")) + statement = ParseDefaultStmt(this); + else if (LookAheadToken(TokenType::Identifier)) + { + // We might be looking at a local declaration, or an + // expression statement, and we need to figure out which. + // + // We'll solve this with backtracking for now. + + Token* startPos = tokenReader.mCursor; + + // Try to parse a type (knowing that the type grammar is + // a subset of the expression grammar, and so this should + // always succeed). + RefPtr<ExpressionSyntaxNode> type = ParseType(); + // We don't actually care about the type, though, so + // don't retain it + type = nullptr; + + // If the next token after we parsed a type looks like + // we are going to declare a variable, then lets guess + // that this is a declaration. + // + // TODO(tfoley): this wouldn't be robust for more + // general kinds of declarators (notably pointer declarators), + // so we'll need to be careful about this. + if (LookAheadToken(TokenType::Identifier)) + { + // Reset the cursor and try to parse a declaration now. + // Note: the declaration will consume any modifiers + // that had been in place on the statement. + tokenReader.mCursor = startPos; + statement = ParseVarDeclrStatement(modifiers); + return statement; + } + + // Fallback: reset and parse an expression + tokenReader.mCursor = startPos; + statement = ParseExpressionStatement(); + } + else if (LookAheadToken(TokenType::Semicolon)) + { + statement = new EmptyStatementSyntaxNode(); + FillPosition(statement.Ptr()); + ReadToken(TokenType::Semicolon); + } + else + { + // Default case should always fall back to parsing an expression, + // and then let that detect any errors + statement = ParseExpressionStatement(); + } + + if (statement) + { + // Install any modifiers onto the statement. + // Note: this path is bypassed in the case of a + // declaration statement, so we don't end up + // doubling up the modifiers. + statement->modifiers = modifiers; + } + + return statement; + } + + RefPtr<StatementSyntaxNode> Parser::ParseBlockStatement() + { + if( options.flags & SLANG_COMPILE_FLAG_NO_CHECKING ) + { + // We have been asked to parse the input, but not attempt to understand it. + + // TODO: record start/end locations... + + List<Token> tokens; + + ReadToken(TokenType::LBrace); + + int depth = 1; + for( ;;) + { + switch( tokenReader.PeekTokenType() ) + { + case TokenType::EndOfFile: + goto done; + + case TokenType::RBrace: + depth--; + if(depth == 0) + goto done; + break; + + case TokenType::LBrace: + depth++; + break; + + default: + break; + } + + auto token = tokenReader.AdvanceToken(); + tokens.Add(token); + } + done: + ReadToken(TokenType::RBrace); + + RefPtr<UnparsedStmt> unparsedStmt = new UnparsedStmt(); + unparsedStmt->tokens = tokens; + return unparsedStmt; + } + + + RefPtr<ScopeDecl> scopeDecl = new ScopeDecl(); + RefPtr<BlockStatementSyntaxNode> blockStatement = new BlockStatementSyntaxNode(); + blockStatement->scopeDecl = scopeDecl; + PushScope(scopeDecl.Ptr()); + ReadToken(TokenType::LBrace); + if(!tokenReader.IsAtEnd()) + { + FillPosition(blockStatement.Ptr()); + } + while (!AdvanceIfMatch(this, TokenType::RBrace)) + { + auto stmt = ParseStatement(); + if(stmt) + { + blockStatement->Statements.Add(stmt); + } + TryRecover(this); + } + PopScope(); + return blockStatement; + } + + RefPtr<VarDeclrStatementSyntaxNode> Parser::ParseVarDeclrStatement( + Modifiers modifiers) + { + RefPtr<VarDeclrStatementSyntaxNode>varDeclrStatement = new VarDeclrStatementSyntaxNode(); + + FillPosition(varDeclrStatement.Ptr()); + auto decl = ParseDeclWithModifiers(this, currentScope->containerDecl, modifiers); + varDeclrStatement->decl = decl; + return varDeclrStatement; + } + + RefPtr<IfStatementSyntaxNode> Parser::ParseIfStatement() + { + RefPtr<IfStatementSyntaxNode> ifStatement = new IfStatementSyntaxNode(); + FillPosition(ifStatement.Ptr()); + ReadToken("if"); + ReadToken(TokenType::LParent); + ifStatement->Predicate = ParseExpression(); + ReadToken(TokenType::RParent); + ifStatement->PositiveStatement = ParseStatement(); + if (LookAheadToken("else")) + { + ReadToken("else"); + ifStatement->NegativeStatement = ParseStatement(); + } + return ifStatement; + } + + RefPtr<ForStatementSyntaxNode> Parser::ParseForStatement() + { + RefPtr<ScopeDecl> scopeDecl = new ScopeDecl(); + RefPtr<ForStatementSyntaxNode> stmt = new ForStatementSyntaxNode(); + stmt->scopeDecl = scopeDecl; + + // Note(tfoley): HLSL implements `for` with incorrect scoping. + // We need an option to turn on this behavior in a kind of "legacy" mode +// PushScope(scopeDecl.Ptr()); + FillPosition(stmt.Ptr()); + ReadToken("for"); + ReadToken(TokenType::LParent); + if (peekTypeName(this)) + { + stmt->InitialStatement = ParseVarDeclrStatement(Modifiers()); + } + else + { + if (!LookAheadToken(TokenType::Semicolon)) + { + stmt->InitialStatement = ParseExpressionStatement(); + } + else + { + ReadToken(TokenType::Semicolon); + } + } + if (!LookAheadToken(TokenType::Semicolon)) + stmt->PredicateExpression = ParseExpression(); + ReadToken(TokenType::Semicolon); + if (!LookAheadToken(TokenType::RParent)) + stmt->SideEffectExpression = ParseExpression(); + ReadToken(TokenType::RParent); + stmt->Statement = ParseStatement(); +// PopScope(); + return stmt; + } + + RefPtr<WhileStatementSyntaxNode> Parser::ParseWhileStatement() + { + RefPtr<WhileStatementSyntaxNode> whileStatement = new WhileStatementSyntaxNode(); + FillPosition(whileStatement.Ptr()); + ReadToken("while"); + ReadToken(TokenType::LParent); + whileStatement->Predicate = ParseExpression(); + ReadToken(TokenType::RParent); + whileStatement->Statement = ParseStatement(); + return whileStatement; + } + + RefPtr<DoWhileStatementSyntaxNode> Parser::ParseDoWhileStatement() + { + RefPtr<DoWhileStatementSyntaxNode> doWhileStatement = new DoWhileStatementSyntaxNode(); + FillPosition(doWhileStatement.Ptr()); + ReadToken("do"); + doWhileStatement->Statement = ParseStatement(); + ReadToken("while"); + ReadToken(TokenType::LParent); + doWhileStatement->Predicate = ParseExpression(); + ReadToken(TokenType::RParent); + ReadToken(TokenType::Semicolon); + return doWhileStatement; + } + + RefPtr<BreakStatementSyntaxNode> Parser::ParseBreakStatement() + { + RefPtr<BreakStatementSyntaxNode> breakStatement = new BreakStatementSyntaxNode(); + FillPosition(breakStatement.Ptr()); + ReadToken("break"); + ReadToken(TokenType::Semicolon); + return breakStatement; + } + + RefPtr<ContinueStatementSyntaxNode> Parser::ParseContinueStatement() + { + RefPtr<ContinueStatementSyntaxNode> continueStatement = new ContinueStatementSyntaxNode(); + FillPosition(continueStatement.Ptr()); + ReadToken("continue"); + ReadToken(TokenType::Semicolon); + return continueStatement; + } + + RefPtr<ReturnStatementSyntaxNode> Parser::ParseReturnStatement() + { + RefPtr<ReturnStatementSyntaxNode> returnStatement = new ReturnStatementSyntaxNode(); + FillPosition(returnStatement.Ptr()); + ReadToken("return"); + if (!LookAheadToken(TokenType::Semicolon)) + returnStatement->Expression = ParseExpression(); + ReadToken(TokenType::Semicolon); + return returnStatement; + } + + RefPtr<ExpressionStatementSyntaxNode> Parser::ParseExpressionStatement() + { + RefPtr<ExpressionStatementSyntaxNode> statement = new ExpressionStatementSyntaxNode(); + + FillPosition(statement.Ptr()); + statement->Expression = ParseExpression(); + + ReadToken(TokenType::Semicolon); + return statement; + } + + RefPtr<ParameterSyntaxNode> Parser::ParseParameter() + { + RefPtr<ParameterSyntaxNode> parameter = new ParameterSyntaxNode(); + parameter->modifiers = ParseModifiers(this); + + DeclaratorInfo declaratorInfo; + declaratorInfo.typeSpec = ParseType(); + + InitDeclarator initDeclarator = ParseInitDeclarator(this); + UnwrapDeclarator(initDeclarator, &declaratorInfo); + + // Assume it is a variable-like declarator + CompleteVarDecl(this, parameter, declaratorInfo); + return parameter; + } + + RefPtr<ExpressionSyntaxNode> Parser::ParseType() + { + auto typeSpec = parseTypeSpec(this); + if( typeSpec.decl ) + { + AddMember(currentScope, typeSpec.decl); + } + auto typeExpr = typeSpec.expr; + + while (LookAheadToken(TokenType::LBracket)) + { + RefPtr<IndexExpressionSyntaxNode> arrType = new IndexExpressionSyntaxNode(); + arrType->Position = typeExpr->Position; + arrType->BaseExpression = typeExpr; + ReadToken(TokenType::LBracket); + if (!LookAheadToken(TokenType::RBracket)) + { + arrType->IndexExpression = ParseExpression(); + } + ReadToken(TokenType::RBracket); + typeExpr = arrType; + } + + return typeExpr; + } + + + + TypeExp Parser::ParseTypeExp() + { + return TypeExp(ParseType()); + } + + enum class Associativity + { + Left, Right + }; + + + + Associativity GetAssociativityFromLevel(Precedence level) + { + if (level == Precedence::Assignment) + return Associativity::Right; + else + return Associativity::Left; + } + + + + + Precedence GetOpLevel(Parser* parser, TokenType type) + { + switch(type) + { + case TokenType::Comma: + return Precedence::Comma; + case TokenType::OpAssign: + case TokenType::OpMulAssign: + case TokenType::OpDivAssign: + case TokenType::OpAddAssign: + case TokenType::OpSubAssign: + case TokenType::OpModAssign: + case TokenType::OpShlAssign: + case TokenType::OpShrAssign: + case TokenType::OpOrAssign: + case TokenType::OpAndAssign: + case TokenType::OpXorAssign: + return Precedence::Assignment; + case TokenType::OpOr: + return Precedence::LogicalOr; + case TokenType::OpAnd: + return Precedence::LogicalAnd; + case TokenType::OpBitOr: + return Precedence::BitOr; + case TokenType::OpBitXor: + return Precedence::BitXor; + case TokenType::OpBitAnd: + return Precedence::BitAnd; + case TokenType::OpEql: + case TokenType::OpNeq: + return Precedence::EqualityComparison; + case TokenType::OpGreater: + case TokenType::OpGeq: + // Don't allow these ops inside a generic argument + if (parser->genericDepth > 0) return Precedence::Invalid; + case TokenType::OpLeq: + case TokenType::OpLess: + return Precedence::RelationalComparison; + case TokenType::OpRsh: + // Don't allow this op inside a generic argument + if (parser->genericDepth > 0) return Precedence::Invalid; + case TokenType::OpLsh: + return Precedence::BitShift; + case TokenType::OpAdd: + case TokenType::OpSub: + return Precedence::Additive; + case TokenType::OpMul: + case TokenType::OpDiv: + case TokenType::OpMod: + return Precedence::Multiplicative; + default: + return Precedence::Invalid; + } + } + + Operator GetOpFromToken(Token & token) + { + switch(token.Type) + { + case TokenType::Comma: + return Operator::Sequence; + case TokenType::OpAssign: + return Operator::Assign; + case TokenType::OpAddAssign: + return Operator::AddAssign; + case TokenType::OpSubAssign: + return Operator::SubAssign; + case TokenType::OpMulAssign: + return Operator::MulAssign; + case TokenType::OpDivAssign: + return Operator::DivAssign; + case TokenType::OpModAssign: + return Operator::ModAssign; + case TokenType::OpShlAssign: + return Operator::LshAssign; + case TokenType::OpShrAssign: + return Operator::RshAssign; + case TokenType::OpOrAssign: + return Operator::OrAssign; + case TokenType::OpAndAssign: + return Operator::AddAssign; + case TokenType::OpXorAssign: + return Operator::XorAssign; + case TokenType::OpOr: + return Operator::Or; + case TokenType::OpAnd: + return Operator::And; + case TokenType::OpBitOr: + return Operator::BitOr; + case TokenType::OpBitXor: + return Operator::BitXor; + case TokenType::OpBitAnd: + return Operator::BitAnd; + case TokenType::OpEql: + return Operator::Eql; + case TokenType::OpNeq: + return Operator::Neq; + case TokenType::OpGeq: + return Operator::Geq; + case TokenType::OpLeq: + return Operator::Leq; + case TokenType::OpGreater: + return Operator::Greater; + case TokenType::OpLess: + return Operator::Less; + case TokenType::OpLsh: + return Operator::Lsh; + case TokenType::OpRsh: + return Operator::Rsh; + case TokenType::OpAdd: + return Operator::Add; + case TokenType::OpSub: + return Operator::Sub; + case TokenType::OpMul: + return Operator::Mul; + case TokenType::OpDiv: + return Operator::Div; + case TokenType::OpMod: + return Operator::Mod; + case TokenType::OpInc: + return Operator::PostInc; + case TokenType::OpDec: + return Operator::PostDec; + case TokenType::OpNot: + return Operator::Not; + case TokenType::OpBitNot: + return Operator::BitNot; + default: + throw "Illegal TokenType."; + } + } + + static RefPtr<ExpressionSyntaxNode> parseOperator(Parser* parser) + { + Token opToken; + switch(parser->tokenReader.PeekTokenType()) + { + case TokenType::QuestionMark: + opToken = parser->ReadToken(); + opToken.Content = "?:"; + break; + + default: + opToken = parser->ReadToken(); + break; + } + + auto opExpr = new VarExpressionSyntaxNode(); + opExpr->Variable = opToken.Content; + opExpr->scope = parser->currentScope; + opExpr->Position = opToken.Position; + + return opExpr; + + } + + RefPtr<ExpressionSyntaxNode> Parser::ParseExpression(Precedence level) + { + if (level == Precedence::Prefix) + return ParseLeafExpression(); + if (level == Precedence::TernaryConditional) + { + // parse select clause + auto condition = ParseExpression(Precedence(level + 1)); + if (LookAheadToken(TokenType::QuestionMark)) + { + RefPtr<SelectExpressionSyntaxNode> select = new SelectExpressionSyntaxNode(); + FillPosition(select.Ptr()); + + select->Arguments.Add(condition); + + select->FunctionExpr = parseOperator(this); + + select->Arguments.Add(ParseExpression(level)); + ReadToken(TokenType::Colon); + select->Arguments.Add(ParseExpression(level)); + return select; + } + else + return condition; + } + else + { + if (GetAssociativityFromLevel(level) == Associativity::Left) + { + auto left = ParseExpression(Precedence(level + 1)); + while (GetOpLevel(this, tokenReader.PeekTokenType()) == level) + { + RefPtr<OperatorExpressionSyntaxNode> tmp = new InfixExpr(); + tmp->FunctionExpr = parseOperator(this); + + tmp->Arguments.Add(left); + FillPosition(tmp.Ptr()); + tmp->Arguments.Add(ParseExpression(Precedence(level + 1))); + left = tmp; + } + return left; + } + else + { + auto left = ParseExpression(Precedence(level + 1)); + if (GetOpLevel(this, tokenReader.PeekTokenType()) == level) + { + RefPtr<OperatorExpressionSyntaxNode> tmp = new InfixExpr(); + tmp->Arguments.Add(left); + FillPosition(tmp.Ptr()); + tmp->FunctionExpr = parseOperator(this); + tmp->Arguments.Add(ParseExpression(level)); + left = tmp; + } + return left; + } + } + } + + RefPtr<ExpressionSyntaxNode> Parser::ParseLeafExpression() + { + RefPtr<ExpressionSyntaxNode> rs; + if (LookAheadToken(TokenType::OpInc) || + LookAheadToken(TokenType::OpDec) || + LookAheadToken(TokenType::OpNot) || + LookAheadToken(TokenType::OpBitNot) || + LookAheadToken(TokenType::OpSub)) + { + RefPtr<OperatorExpressionSyntaxNode> unaryExpr = new PrefixExpr(); + FillPosition(unaryExpr.Ptr()); + unaryExpr->FunctionExpr = parseOperator(this); + unaryExpr->Arguments.Add(ParseLeafExpression()); + rs = unaryExpr; + return rs; + } + + if (LookAheadToken(TokenType::LParent)) + { + ReadToken(TokenType::LParent); + RefPtr<ExpressionSyntaxNode> expr; + if (peekTypeName(this) && LookAheadToken(TokenType::RParent, 1)) + { + RefPtr<TypeCastExpressionSyntaxNode> tcexpr = new TypeCastExpressionSyntaxNode(); + FillPosition(tcexpr.Ptr()); + tcexpr->TargetType = ParseTypeExp(); + ReadToken(TokenType::RParent); + tcexpr->Expression = ParseExpression(Precedence::Multiplicative); // Note(tfoley): need to double-check this + expr = tcexpr; + } + else + { + expr = ParseExpression(); + ReadToken(TokenType::RParent); + } + rs = expr; + } + else if( LookAheadToken(TokenType::LBrace) ) + { + RefPtr<InitializerListExpr> initExpr = new InitializerListExpr(); + FillPosition(initExpr.Ptr()); + + // Initializer list + ReadToken(TokenType::LBrace); + + List<RefPtr<ExpressionSyntaxNode>> exprs; + + for(;;) + { + if(AdvanceIfMatch(this, TokenType::RBrace)) + break; + + auto expr = ParseArgExpr(); + if( expr ) + { + initExpr->args.Add(expr); + } + + if(AdvanceIfMatch(this, TokenType::RBrace)) + break; + + ReadToken(TokenType::Comma); + } + rs = initExpr; + } + + else if (LookAheadToken(TokenType::IntLiterial) || + LookAheadToken(TokenType::DoubleLiterial)) + { + RefPtr<ConstantExpressionSyntaxNode> constExpr = new ConstantExpressionSyntaxNode(); + auto token = tokenReader.AdvanceToken(); + FillPosition(constExpr.Ptr()); + if (token.Type == TokenType::IntLiterial) + { + constExpr->ConstType = ConstantExpressionSyntaxNode::ConstantType::Int; + constExpr->IntValue = StringToInt(token.Content); + } + else if (token.Type == TokenType::DoubleLiterial) + { + constExpr->ConstType = ConstantExpressionSyntaxNode::ConstantType::Float; + constExpr->FloatValue = (FloatingPointLiteralValue) StringToDouble(token.Content); + } + rs = constExpr; + } + else if (LookAheadToken("true") || LookAheadToken("false")) + { + RefPtr<ConstantExpressionSyntaxNode> constExpr = new ConstantExpressionSyntaxNode(); + auto token = tokenReader.AdvanceToken(); + FillPosition(constExpr.Ptr()); + constExpr->ConstType = ConstantExpressionSyntaxNode::ConstantType::Bool; + constExpr->IntValue = token.Content == "true" ? 1 : 0; + rs = constExpr; + } + else if (LookAheadToken(TokenType::Identifier)) + { + RefPtr<VarExpressionSyntaxNode> varExpr = new VarExpressionSyntaxNode(); + varExpr->scope = currentScope.Ptr(); + FillPosition(varExpr.Ptr()); + auto token = ReadToken(TokenType::Identifier); + varExpr->Variable = token.Content; + rs = varExpr; + } + + while (!tokenReader.IsAtEnd() && + (LookAheadToken(TokenType::OpInc) || + LookAheadToken(TokenType::OpDec) || + LookAheadToken(TokenType::Dot) || + LookAheadToken(TokenType::LBracket) || + LookAheadToken(TokenType::LParent))) + { + if (LookAheadToken(TokenType::OpInc)) + { + RefPtr<OperatorExpressionSyntaxNode> unaryExpr = new PostfixExpr(); + FillPosition(unaryExpr.Ptr()); + unaryExpr->FunctionExpr = parseOperator(this); + unaryExpr->Arguments.Add(rs); + rs = unaryExpr; + } + else if (LookAheadToken(TokenType::OpDec)) + { + RefPtr<OperatorExpressionSyntaxNode> unaryExpr = new PostfixExpr(); + FillPosition(unaryExpr.Ptr()); + unaryExpr->FunctionExpr = parseOperator(this); + unaryExpr->Arguments.Add(rs); + rs = unaryExpr; + } + else if (LookAheadToken(TokenType::LBracket)) + { + RefPtr<IndexExpressionSyntaxNode> indexExpr = new IndexExpressionSyntaxNode(); + indexExpr->BaseExpression = rs; + FillPosition(indexExpr.Ptr()); + ReadToken(TokenType::LBracket); + indexExpr->IndexExpression = ParseExpression(); + ReadToken(TokenType::RBracket); + rs = indexExpr; + } + else if (LookAheadToken(TokenType::LParent)) + { + RefPtr<InvokeExpressionSyntaxNode> invokeExpr = new InvokeExpressionSyntaxNode(); + invokeExpr->FunctionExpr = rs; + FillPosition(invokeExpr.Ptr()); + ReadToken(TokenType::LParent); + while (!tokenReader.IsAtEnd()) + { + if (!LookAheadToken(TokenType::RParent)) + invokeExpr->Arguments.Add(ParseArgExpr()); + else + { + break; + } + if (!LookAheadToken(TokenType::Comma)) + break; + ReadToken(TokenType::Comma); + } + ReadToken(TokenType::RParent); + rs = invokeExpr; + } + else if (LookAheadToken(TokenType::Dot)) + { + RefPtr<MemberExpressionSyntaxNode> memberExpr = new MemberExpressionSyntaxNode(); + memberExpr->scope = currentScope.Ptr(); + FillPosition(memberExpr.Ptr()); + memberExpr->BaseExpression = rs; + ReadToken(TokenType::Dot); + memberExpr->MemberName = ReadToken(TokenType::Identifier).Content; + rs = memberExpr; + } + } + if (!rs) + { + sink->diagnose(tokenReader.PeekLoc(), Diagnostics::syntaxError); + } + return rs; + } + + // Parse a source file into an existing translation unit + void parseSourceFile( + ProgramSyntaxNode* translationUnitSyntax, + CompileOptions& options, + TokenSpan const& tokens, + DiagnosticSink* sink, + String const& fileName, + RefPtr<Scope> const&outerScope) + { + Parser parser(options, tokens, sink, fileName, outerScope); + return parser.parseSourceFile(translationUnitSyntax); + } + } +}
\ No newline at end of file diff --git a/source/slang/parser.h b/source/slang/parser.h new file mode 100644 index 000000000..90af69158 --- /dev/null +++ b/source/slang/parser.h @@ -0,0 +1,23 @@ +#ifndef RASTER_RENDERER_PARSER_H +#define RASTER_RENDERER_PARSER_H + +#include "lexer.h" +#include "compiler.h" +#include "syntax.h" + +namespace Slang +{ + namespace Compiler + { + // Parse a source file into an existing translation unit + void parseSourceFile( + ProgramSyntaxNode* translationUnitSyntax, + CompileOptions& options, + TokenSpan const& tokens, + DiagnosticSink* sink, + String const& fileName, + RefPtr<Scope> const&outerScope); + } +} + +#endif
\ No newline at end of file diff --git a/source/slang/preprocessor.cpp b/source/slang/preprocessor.cpp new file mode 100644 index 000000000..cdde2591d --- /dev/null +++ b/source/slang/preprocessor.cpp @@ -0,0 +1,2032 @@ +// Preprocessor.cpp +#include "Preprocessor.h" + +#include "Diagnostics.h" +#include "Lexer.h" + +// Needed so that we can construct modifier syntax +// to represent GLSL directives +#include "Syntax.h" + +#include <assert.h> + +using namespace CoreLib; + +// This file provides an implementation of a simple C-style preprocessor. +// It does not aim for 100% compatibility with any particular preprocessor +// specification, but the goal is to have it accept the most common +// idioms for using the preprocessor, found in shader code in the wild. + + +namespace Slang{ namespace Compiler { + +// State of a preprocessor conditional, which can change when +// we encounter directives like `#elif` or `#endif` +enum class PreprocessorConditionalState +{ + Before, // We have not yet seen a branch with a `true` condition. + During, // We are inside the branch with a `true` condition. + After, // We have already seen the branch with a `true` condition. +}; + +// Represents a preprocessor conditional that we are currently +// nested inside. +struct PreprocessorConditional +{ + // The next outer conditional in the current file/stream, or NULL. + PreprocessorConditional* parent; + + // The directive token that started the conditional (an `#if` or `#ifdef`) + Token ifToken; + + // The `#else` directive token, if one has been seen (otherwise `TokenType::Unknown`) + Token elseToken; + + // The state of the conditional + PreprocessorConditionalState state; +}; + +struct PreprocessorMacro; + +struct PreprocessorEnvironment +{ + // The "outer" environment, to be used if lookup in this env fails + PreprocessorEnvironment* parent = NULL; + + // Macros defined in this environment + Dictionary<String, PreprocessorMacro*> macros; + + ~PreprocessorEnvironment(); +}; + +// Input tokens can either come from source text, or from macro expansion. +// In general, input streams can be nested, so we have to keep a conceptual +// stack of input. + +// A stream of input tokens to be consumed +struct PreprocessorInputStream +{ + // The next input stream up the stack, if any. + PreprocessorInputStream* parent; + + // The deepest preprocessor conditional active for this stream. + PreprocessorConditional* conditional; + + // Environment to use when looking up macros + PreprocessorEnvironment* environment; + + // Reader for pre-tokenized input + TokenReader tokenReader; + + // If we are clobbering source locations with `#line`, then + // the state is tracked here: + + // Are we overriding source locations? + bool isOverridingSourceLoc; + + // What is the file name we are overriding to? + String overrideFileName; + + // What is the relative offset to apply to any line numbers? + int overrideLineOffset; + + // Destructor is virtual so that we can clean up + // after concrete subtypes. + virtual ~PreprocessorInputStream() = default; +}; + +struct SourceTextInputStream : PreprocessorInputStream +{ + // The pre-tokenized input + TokenList lexedTokens; +}; + +struct MacroExpansion : PreprocessorInputStream +{ + // The macro we will expand + PreprocessorMacro* macro; +}; + +struct ObjectLikeMacroExpansion : MacroExpansion +{ +}; + +struct FunctionLikeMacroExpansion : MacroExpansion +{ + // Environment for macro arguments + PreprocessorEnvironment argumentEnvironment; +}; + +// An enumeration for the diferent types of macros +enum class PreprocessorMacroFlavor +{ + ObjectLike, + FunctionArg, + FunctionLike, +}; + +// In the current design (which we may want to re-consider), +// a macro is a specialized flavor of input stream, that +// captures the token list in its expansion, and then +// can be "played back." +struct PreprocessorMacro +{ + // The name under which the macro was `#define`d + Token nameToken; + + // Parameters of the macro, in case of a function-like macro + List<Token> params; + + // The tokens that make up the macro body + TokenList tokens; + + // The flavor of macro + PreprocessorMacroFlavor flavor; + + // The environment in which this macro needs to be expanded. + // For ordinary macros this will be the global environment, + // while for function-like macro arguments, it will be + // the environment of the macro invocation. + PreprocessorEnvironment* environment; +}; + +// State of the preprocessor +struct Preprocessor +{ + // diagnostics sink to use when writing messages + DiagnosticSink* sink; + + // An external callback interface to use when looking + // for files in a `#include` directive + IncludeHandler* includeHandler; + + // Current input stream (top of the stack of input) + PreprocessorInputStream* inputStream; + + // Currently-defined macros + PreprocessorEnvironment globalEnv; + + // A pre-allocated token that can be returned to + // represent end-of-input situations. + Token endOfFileToken; + + // Syntax for the program we are trying to parse + ProgramSyntaxNode* syntax; +}; + +// Convenience routine to access the diagnostic sink +static DiagnosticSink* GetSink(Preprocessor* preprocessor) +{ + return preprocessor->sink; +} + +// +// Forward declarations +// + +static void DestroyConditional(PreprocessorConditional* conditional); +static void DestroyMacro(Preprocessor* preprocessor, PreprocessorMacro* macro); + +// +// Basic Input Handling +// + +// Create a fresh input stream +static void InitializeInputStream(Preprocessor* preprocessor, PreprocessorInputStream* inputStream) +{ + inputStream->parent = NULL; + inputStream->conditional = NULL; + inputStream->environment = &preprocessor->globalEnv; +} + +// Destroy an input stream +static void DestroyInputStream(Preprocessor* /*preprocessor*/, PreprocessorInputStream* inputStream) +{ + delete inputStream; +} + +// Create an input stream to represent a pre-tokenized input file. +// TODO(tfoley): pre-tokenizing files isn't going to work in the long run. +static PreprocessorInputStream* CreateInputStreamForSource(Preprocessor* preprocessor, CoreLib::String const& source, CoreLib::String const& fileName) +{ + SourceTextInputStream* inputStream = new SourceTextInputStream(); + InitializeInputStream(preprocessor, inputStream); + + // Use existing `Lexer` to generate a token stream. + Lexer lexer(fileName, source, GetSink(preprocessor)); + inputStream->lexedTokens = lexer.lexAllTokens(); + inputStream->tokenReader = TokenReader(inputStream->lexedTokens); + + return inputStream; +} + + + +static void PushInputStream(Preprocessor* preprocessor, PreprocessorInputStream* inputStream) +{ + inputStream->parent = preprocessor->inputStream; + preprocessor->inputStream = inputStream; +} + +// Called when we reach the end of an input stream. +// Performs some validation and then destroys the input stream if required. +static void EndInputStream(Preprocessor* preprocessor, PreprocessorInputStream* inputStream) +{ + // If there are any conditionals that weren't completed, then it is an error + if (inputStream->conditional) + { + PreprocessorConditional* conditional = inputStream->conditional; + + GetSink(preprocessor)->diagnose(conditional->ifToken.Position, Diagnostics::endOfFileInPreprocessorConditional); + + while (conditional) + { + PreprocessorConditional* parent = conditional->parent; + DestroyConditional(conditional); + conditional = parent; + } + } + + DestroyInputStream(preprocessor, inputStream); +} + +// Potentially clobber source location information based on `#line` +static Token PossiblyOverrideSourceLoc(PreprocessorInputStream* inputStream, Token const& token) +{ + Token result = token; + if( inputStream->isOverridingSourceLoc ) + { + result.Position.FileName = inputStream->overrideFileName; + result.Position.Line += inputStream->overrideLineOffset; + } + return result; +} + +// Consume one token from an input stream +static Token AdvanceRawToken(PreprocessorInputStream* inputStream) +{ + return PossiblyOverrideSourceLoc(inputStream, inputStream->tokenReader.AdvanceToken()); +} + +// Peek one token from an input stream +static Token PeekRawToken(PreprocessorInputStream* inputStream) +{ + return PossiblyOverrideSourceLoc(inputStream, inputStream->tokenReader.PeekToken()); +} + +// Peek one token type from an input stream +static TokenType PeekRawTokenType(PreprocessorInputStream* inputStream) +{ + return inputStream->tokenReader.PeekTokenType(); +} + + +// Read one token in "raw" mode (meaning don't expand macros) +static Token AdvanceRawToken(Preprocessor* preprocessor) +{ + for (;;) + { + // Look at the input stream on top of the stack + PreprocessorInputStream* inputStream = preprocessor->inputStream; + + // If there isn't one, then there is no more input left to read. + if (!inputStream) + { + return preprocessor->endOfFileToken; + } + + // The top-most input stream may be at its end + if (PeekRawTokenType(inputStream) == TokenType::EndOfFile) + { + // If there is another stream remaining, switch to it + if (inputStream->parent) + { + preprocessor->inputStream = inputStream->parent; + EndInputStream(preprocessor, inputStream); + continue; + } + else + { + // HACK(tfoley): A place to fall into debugger... + int f = 0; + } + } + + // Everything worked, so read a token from the top-most stream + return AdvanceRawToken(inputStream); + } +} + +// Return the next token in "raw" mode, but don't advance the +// current token state. +static Token PeekRawToken(Preprocessor* preprocessor) +{ + // We need to find the strema that `advanceRawToken` would read from. + PreprocessorInputStream* inputStream = preprocessor->inputStream; + for (;;) + { + if (!inputStream) + { + // No more input streams left to read + return preprocessor->endOfFileToken; + } + + // The top-most input stream may be at its end, so + // look one entry up the stack (don't actually pop + // here, since we are just peeking) + if (PeekRawTokenType(inputStream) == TokenType::EndOfFile) + { + if (inputStream->parent) + { + inputStream = inputStream->parent; + continue; + } + else + { + // HACK(tfoley): A place to fall into debugger... + int f = 0; + } + } + + // Everything worked, so the token we just peeked is fine. + return PeekRawToken(inputStream); + } +} + +// Without advancing preprocessor state, look *two* raw tokens ahead +// (This is only needed in order to determine when we are possibly +// expanding a function-style macro) +TokenType PeekSecondRawTokenType(Preprocessor* preprocessor) +{ + // We need to find the strema that `advanceRawToken` would read from. + PreprocessorInputStream* inputStream = preprocessor->inputStream; + int count = 1; + for (;;) + { + if (!inputStream) + { + // No more input streams left to read + return TokenType::EndOfFile; + } + + // The top-most input stream may be at its end, so + // look one entry up the stack (don't actually pop + // here, since we are just peeking) + + TokenReader reader = inputStream->tokenReader; + if (reader.PeekTokenType() == TokenType::EndOfFile) + { + inputStream = inputStream->parent; + continue; + } + + if (count) + { + count--; + + // Note: we are advancing our temporary + // copy of the token reader + reader.AdvanceToken(); + if (reader.PeekTokenType() == TokenType::EndOfFile) + { + inputStream = inputStream->parent; + continue; + } + } + + // Everything worked, so peek a token from the top-most stream + return reader.PeekTokenType(); + } +} + + +// Get the location of the current (raw) token +static CodePosition PeekLoc(Preprocessor* preprocessor) +{ + return PeekRawToken(preprocessor).Position; +} + +// Get the `TokenType` of the current (raw) token +static TokenType PeekRawTokenType(Preprocessor* preprocessor) +{ + return PeekRawToken(preprocessor).Type; +} + +// +// Macros +// + +// Create a macro +static PreprocessorMacro* CreateMacro(Preprocessor* preprocessor) +{ + // TODO(tfoley): Allocate these more intelligently. + // For example, consider pooling them on the preprocessor. + + PreprocessorMacro* macro = new PreprocessorMacro(); + macro->flavor = PreprocessorMacroFlavor::ObjectLike; + macro->environment = &preprocessor->globalEnv; + return macro; +} + +// Destroy a macro +static void DestroyMacro(Preprocessor* /*preprocessor*/, PreprocessorMacro* macro) +{ + delete macro; +} + + +// Find the currently-defined macro of the given name, or return NULL +static PreprocessorMacro* LookupMacro(PreprocessorEnvironment* environment, String const& name) +{ + for(PreprocessorEnvironment* e = environment; e; e = e->parent) + { + PreprocessorMacro* macro = NULL; + if (e->macros.TryGetValue(name, macro)) + return macro; + } + + return NULL; +} + +static PreprocessorEnvironment* GetCurrentEnvironment(Preprocessor* preprocessor) +{ + PreprocessorInputStream* inputStream = preprocessor->inputStream; + return inputStream ? inputStream->environment : &preprocessor->globalEnv; +} + +static PreprocessorMacro* LookupMacro(Preprocessor* preprocessor, String const& name) +{ + return LookupMacro(GetCurrentEnvironment(preprocessor), name); +} + +// A macro is "busy" if it is currently being used for expansion. +// A macro cannot be expanded again while busy, to avoid infinite recursion. +static bool IsMacroBusy(PreprocessorMacro* /*macro*/) +{ + // TODO: need to implement this correctly + // + // The challenge here is that we are implementing expansion + // for argumenst to function-like macros in a "lazy" fashion. + // + // The letter of the spec is that we should macro expand + // each argument *before* substitution, and then go and + // macro-expand the substituted body. This means that we + // can invoke a macro as part of an argument to an + // invocation of the same macro: + // + // FOO( 1, FOO(22), 333 ); + // + // In our implementation, the "inner" invocation of `FOO` + // gets expanded at the point where it gets referenced + // in the body of the "outer" invocation of `FOO`. + // Doing things this way leads to greatly simplified + // code for handling expansion. + // + // A proper implementation of `IsMacroBusy` needs to + // take context into account, so that it bans recursive + // use of a macro when it occurs (indirectly) through + // the *body* of the expansion, but not when it occcurs + // only through an *argument*. + return false; +} + +// +// Reading Tokens With Expansion +// + +static void InitializeMacroExpansion( + Preprocessor* preprocessor, + MacroExpansion* expansion, + PreprocessorMacro* macro) +{ + InitializeInputStream(preprocessor, expansion); + expansion->environment = macro->environment; + expansion->macro = macro; + expansion->tokenReader = TokenReader(macro->tokens); +} + +static void PushMacroExpansion( + Preprocessor* preprocessor, + MacroExpansion* expansion) +{ + PushInputStream(preprocessor, expansion); +} + +static void AddEndOfStreamToken( + Preprocessor* preprocessor, + PreprocessorMacro* macro) +{ + Token token = PeekRawToken(preprocessor); + token.Type = TokenType::EndOfFile; + macro->tokens.mTokens.Add(token); +} + +// Check whether the current token on the given input stream should be +// treated as a macro invocation, and if so set up state for expanding +// that macro. +static void MaybeBeginMacroExpansion( + Preprocessor* preprocessor ) +{ + // We iterate because the first token in the expansion of one + // macro may be another macro invocation. + for (;;) + { + // Look at the next token ahead of us + Token const& token = PeekRawToken(preprocessor); + + // Not an identifier? Can't be a macro. + if (token.Type != TokenType::Identifier) + return; + + // Look for a macro with the given name. + String name = token.Content; + PreprocessorMacro* macro = LookupMacro(preprocessor, name); + + // Not a macro? Can't be an invocation. + if (!macro) + return; + + // If the macro is busy (already being expanded), + // don't try to trigger recursive expansion + if (IsMacroBusy(macro)) + return; + + // A function-style macro invocation should only match + // if the token *after* the identifier is `(`. This + // requires more lookahead than we usually have/need + if (macro->flavor == PreprocessorMacroFlavor::FunctionLike) + { + if(PeekSecondRawTokenType(preprocessor) != TokenType::LParent) + return; + + // Consume the token that triggered macro expansion + AdvanceRawToken(preprocessor); + + // Consume the opening `(` + Token leftParen = AdvanceRawToken(preprocessor); + + FunctionLikeMacroExpansion* expansion = new FunctionLikeMacroExpansion(); + InitializeMacroExpansion(preprocessor, expansion, macro); + expansion->argumentEnvironment.parent = &preprocessor->globalEnv; + expansion->environment = &expansion->argumentEnvironment; + + // Try to read any arguments present. + int paramCount = macro->params.Count(); + int argIndex = 0; + + switch (PeekRawTokenType(preprocessor)) + { + case TokenType::EndOfFile: + case TokenType::RParent: + // No arguments. + break; + + default: + // At least one argument + while(argIndex < paramCount) + { + // Read an argument + + // Create the argument, represented as a special flavor of macro + PreprocessorMacro* arg = CreateMacro(preprocessor); + arg->flavor = PreprocessorMacroFlavor::FunctionArg; + arg->environment = GetCurrentEnvironment(preprocessor); + + // Associate the new macro with its parameter name + Token paramToken = macro->params[argIndex]; + String const& paramName = paramToken.Content; + arg->nameToken = paramToken; + expansion->argumentEnvironment.macros[paramName] = arg; + argIndex++; + + // Read tokens for the argument + + // We track the nesting depth, since we don't break + // arguments on a `,` nested in balanced parentheses + // + int nesting = 0; + for (;;) + { + switch (PeekRawTokenType(preprocessor)) + { + case TokenType::EndOfFile: + // if we reach the end of the file, + // then we have an error, and need to + // bail out + AddEndOfStreamToken(preprocessor, arg); + goto doneWithAllArguments; + + case TokenType::RParent: + // If we see a right paren when we aren't nested + // then we are at the end of an argument + if (nesting == 0) + { + AddEndOfStreamToken(preprocessor, arg); + goto doneWithAllArguments; + } + // Otherwise we decrease our nesting depth, add + // the token, and keep going + nesting--; + break; + + case TokenType::Comma: + // If we see a comma when we aren't nested + // then we are at the end of an argument + if (nesting == 0) + { + AddEndOfStreamToken(preprocessor, arg); + AdvanceRawToken(preprocessor); + goto doneWithArgument; + } + // Otherwise we add it as a normal token + break; + + case TokenType::LParent: + // If we see a left paren then we need to + // increase our tracking of nesting + nesting++; + break; + + default: + break; + } + + // Add the token and continue parsing. + arg->tokens.mTokens.Add(AdvanceRawToken(preprocessor)); + } + doneWithArgument: {} + // We've parsed an argument and should move onto + // the next one. + } + break; + } + doneWithAllArguments: + // TODO: handle possible varargs + + // Expect closing right paren + if (PeekRawTokenType(preprocessor) == TokenType::RParent) + { + AdvanceRawToken(preprocessor); + } + else + { + GetSink(preprocessor)->diagnose(PeekLoc(preprocessor), Diagnostics::expectedTokenInMacroArguments, TokenType::RParent, PeekRawTokenType(preprocessor)); + } + + int argCount = argIndex; + if (argCount != paramCount) + { + // TODO: diagnose + throw 99; + } + + // We are ready to expand. + PushMacroExpansion(preprocessor, expansion); + } + else + { + // Consume the token that triggered macro expansion + AdvanceRawToken(preprocessor); + + // Object-like macros are the easy case. + ObjectLikeMacroExpansion* expansion = new ObjectLikeMacroExpansion(); + InitializeMacroExpansion(preprocessor, expansion, macro); + PushMacroExpansion(preprocessor, expansion); + } + } +} + +// Read one token with macro-expansion enabled. +static Token AdvanceToken(Preprocessor* preprocessor) +{ +top: + // Check whether we need to macro expand at the cursor. + MaybeBeginMacroExpansion(preprocessor); + + // Read a raw token (now that expansion has been triggered) + Token token = AdvanceRawToken(preprocessor); + + // Check if we need to perform token pasting + if (PeekRawTokenType(preprocessor) != TokenType::PoundPound) + { + // If we aren't token pasting, then we are done + return token; + } + else + { + // We are pasting tokens, which could get messy + + StringBuilder sb; + sb << token.Content; + + while (PeekRawTokenType(preprocessor) == TokenType::PoundPound) + { + // Consume the `##` + AdvanceRawToken(preprocessor); + + // Possibly macro-expand the next token + MaybeBeginMacroExpansion(preprocessor); + + // Read the next raw token (now that expansion has been triggered) + Token nextToken = AdvanceRawToken(preprocessor); + + sb << nextToken.Content; + } + + // Now re-lex the input + PreprocessorInputStream* inputStream = CreateInputStreamForSource(preprocessor, sb.ProduceString(), "token paste"); + if (inputStream->tokenReader.GetCount() != 1) + { + // We expect a token paste to produce a single token + // TODO(tfoley): emit a diagnostic here + } + + PushInputStream(preprocessor, inputStream); + goto top; + } +} + +// Read one token with macro-expansion enabled. +// +// Note that because triggering macro expansion may +// involve changing the input-stream state, this +// operation *can* have side effects. +static Token PeekToken(Preprocessor* preprocessor) +{ + // Check whether we need to macro expand at the cursor. + MaybeBeginMacroExpansion(preprocessor); + + // Peek a raw token (now that expansion has been triggered) + return PeekRawToken(preprocessor); + + // TODO: need a plan for how to handle token pasting + // here without it being onerous. Would be nice if we + // didn't have to re-do pasting on a "peek"... +} + +// Peek the type of the next token, including macro expansion. +static TokenType PeekTokenType(Preprocessor* preprocessor) +{ + return PeekToken(preprocessor).Type; +} + +// +// Preprocessor Directives +// + +// When reading a preprocessor directive, we use a context +// to wrap the direct preprocessor routines defines so far. +// +// One of the most important things the directive context +// does is give us a convenient way to read tokens with +// a guarantee that we won't read past the end of a line. +struct PreprocessorDirectiveContext +{ + // The preprocessor that is parsing the directive. + Preprocessor* preprocessor; + + // The directive token (e.g., the `if` in `#if`). + // Useful for reference in diagnostic messages. + Token directiveToken; + + // Has any kind of parse error been encountered in + // the directive so far? + bool parseError; + + // Have we done the necessary checks at the end + // of the directive already? + bool haveDoneEndOfDirectiveChecks; +}; + +// Get the token for the preprocessor directive being parsed. +inline Token const& GetDirective(PreprocessorDirectiveContext* context) +{ + return context->directiveToken; +} + +// Get the name of the directive being parsed. +inline String const& GetDirectiveName(PreprocessorDirectiveContext* context) +{ + return context->directiveToken.Content; +} + +// Get the location of the directive being parsed. +inline CodePosition const& GetDirectiveLoc(PreprocessorDirectiveContext* context) +{ + return context->directiveToken.Position; +} + +// Wrapper to get the diagnostic sink in the context of a directive. +static inline DiagnosticSink* GetSink(PreprocessorDirectiveContext* context) +{ + return GetSink(context->preprocessor); +} + +// Wrapper to get a "current" location when parsing a directive +static CodePosition PeekLoc(PreprocessorDirectiveContext* context) +{ + return PeekLoc(context->preprocessor); +} + +// Wrapper to look up a macro in the context of a directive. +static PreprocessorMacro* LookupMacro(PreprocessorDirectiveContext* context, String const& name) +{ + return LookupMacro(context->preprocessor, name); +} + +// Determine if we have read everthing on the directive's line. +static bool IsEndOfLine(PreprocessorDirectiveContext* context) +{ + return PeekRawToken(context->preprocessor).Type == TokenType::EndOfDirective; +} + +// Peek one raw token in a directive, without going past the end of the line. +static Token PeekRawToken(PreprocessorDirectiveContext* context) +{ + return PeekRawToken(context->preprocessor); +} + +// Read one raw token in a directive, without going past the end of the line. +static Token AdvanceRawToken(PreprocessorDirectiveContext* context) +{ + if (IsEndOfLine(context)) + return PeekRawToken(context); + return AdvanceRawToken(context->preprocessor); +} + +// Peek next raw token type, without going past the end of the line. +static TokenType PeekRawTokenType(PreprocessorDirectiveContext* context) +{ + return PeekRawTokenType(context->preprocessor); +} + +// Read one token, with macro-expansion, without going past the end of the line. +static Token AdvanceToken(PreprocessorDirectiveContext* context) +{ + if (IsEndOfLine(context)) + return PeekRawToken(context); + return AdvanceToken(context->preprocessor); +} + +// Peek one token, with macro-expansion, without going past the end of the line. +static Token PeekToken(PreprocessorDirectiveContext* context) +{ + if (IsEndOfLine(context)) + context->preprocessor->endOfFileToken; + return PeekToken(context->preprocessor); +} + +// Peek next token type, with macro-expansion, without going past the end of the line. +static TokenType PeekTokenType(PreprocessorDirectiveContext* context) +{ + if (IsEndOfLine(context)) + return TokenType::EndOfDirective; + return PeekTokenType(context->preprocessor); +} + +// Skip to the end of the line (useful for recovering from errors in a directive) +static void SkipToEndOfLine(PreprocessorDirectiveContext* context) +{ + while(!IsEndOfLine(context)) + { + AdvanceRawToken(context); + } +} + +static bool ExpectRaw(PreprocessorDirectiveContext* context, TokenType tokenType, DiagnosticInfo const& diagnostic, Token* outToken = NULL) +{ + if (PeekRawTokenType(context) != tokenType) + { + // Only report the first parse error within a directive + if (!context->parseError) + { + GetSink(context)->diagnose(PeekLoc(context), diagnostic, tokenType, GetDirectiveName(context)); + } + context->parseError = true; + return false; + } + Token const& token = AdvanceRawToken(context); + if (outToken) + *outToken = token; + return true; +} + +static bool Expect(PreprocessorDirectiveContext* context, TokenType tokenType, DiagnosticInfo const& diagnostic, Token* outToken = NULL) +{ + if (PeekTokenType(context) != tokenType) + { + // Only report the first parse error within a directive + if (!context->parseError) + { + GetSink(context)->diagnose(PeekLoc(context), diagnostic, tokenType, GetDirectiveName(context)); + context->parseError = true; + } + return false; + } + Token const& token = AdvanceToken(context); + if (outToken) + *outToken = token; + return true; +} + + + +// +// Preprocessor Conditionals +// + +// Determine whether the current preprocessor state means we +// should be skipping tokens. +static bool IsSkipping(Preprocessor* preprocessor) +{ + PreprocessorInputStream* inputStream = preprocessor->inputStream; + if (!inputStream) return false; + + // If we are not inside a preprocessor conditional, then don't skip + PreprocessorConditional* conditional = inputStream->conditional; + if (!conditional) return false; + + // skip tokens unless the conditional is inside its `true` case + return conditional->state != PreprocessorConditionalState::During; +} + +// Wrapper for use inside directives +static inline bool IsSkipping(PreprocessorDirectiveContext* context) +{ + return IsSkipping(context->preprocessor); +} + +// Create a preprocessor conditional +static PreprocessorConditional* CreateConditional(Preprocessor* /*preprocessor*/) +{ + // TODO(tfoley): allocate these more intelligently (for example, + // pool them on the `Preprocessor`. + return new PreprocessorConditional(); +} + +// Destroy a preprocessor conditional. +static void DestroyConditional(PreprocessorConditional* conditional) +{ + delete conditional; +} + +// Start a preprocessor conditional, with an initial enable/disable state. +static void BeginConditional(PreprocessorDirectiveContext* context, bool enable) +{ + Preprocessor* preprocessor = context->preprocessor; + PreprocessorInputStream* inputStream = preprocessor->inputStream; + assert(inputStream); + + PreprocessorConditional* conditional = CreateConditional(preprocessor); + + conditional->ifToken = context->directiveToken; + + // Set state of this condition appropriately. + // + // Default to the "haven't yet seen a `true` branch" state. + PreprocessorConditionalState state = PreprocessorConditionalState::Before; + // + // If we are nested inside a `false` branch of another condition, then + // we never want to enable, so we act as if we already *saw* the `true` branch. + // + if (IsSkipping(preprocessor)) state = PreprocessorConditionalState::After; + // + // Similarly, if we ran into any parse errors when dealing with the + // opening directive, then things are probably screwy and we should just + // skip all the branches. + if (IsSkipping(preprocessor)) state = PreprocessorConditionalState::After; + // + // Otherwise, if our condition was true, then set us to be inside the `true` branch + else if (enable) state = PreprocessorConditionalState::During; + + conditional->state = state; + + // Push conditional onto the stack + conditional->parent = inputStream->conditional; + inputStream->conditional = conditional; +} + +// +// Preprocessor Conditional Expressions +// + +// Conditional expressions are always of type `int` +typedef int PreprocessorExpressionValue; + +// Forward-declaretion +static PreprocessorExpressionValue ParseAndEvaluateExpression(PreprocessorDirectiveContext* context); + +// Parse a unary (prefix) expression inside of a preprocessor directive. +static PreprocessorExpressionValue ParseAndEvaluateUnaryExpression(PreprocessorDirectiveContext* context) +{ + switch (PeekTokenType(context)) + { + // handle prefix unary ops + case TokenType::OpSub: + AdvanceToken(context); + return -ParseAndEvaluateUnaryExpression(context); + case TokenType::OpNot: + AdvanceToken(context); + return !ParseAndEvaluateUnaryExpression(context); + case TokenType::OpBitNot: + AdvanceToken(context); + return ~ParseAndEvaluateUnaryExpression(context); + + // handle parenthized sub-expression + case TokenType::LParent: + { + Token leftParen = AdvanceToken(context); + PreprocessorExpressionValue value = ParseAndEvaluateExpression(context); + if (!Expect(context, TokenType::RParent, Diagnostics::expectedTokenInPreprocessorExpression)) + { + GetSink(context)->diagnose(leftParen.Position, Diagnostics::seeOpeningToken, leftParen); + } + return value; + } + + case TokenType::IntLiterial: + return StringToInt(AdvanceToken(context).Content); + + case TokenType::Identifier: + { + Token token = AdvanceToken(context); + if (token.Content == "defined") + { + // handle `defined(someName)` + + // Possibly parse a `(` + Token leftParen; + if (PeekRawTokenType(context) == TokenType::LParent) + { + leftParen = AdvanceRawToken(context); + } + + // Expect an identifier + Token nameToken; + if (!ExpectRaw(context, TokenType::Identifier, Diagnostics::expectedTokenInDefinedExpression, &nameToken)) + { + return 0; + } + String name = nameToken.Content; + + // If we saw an opening `(`, then expect one to close + if (leftParen.Type != TokenType::Unknown) + { + if(!ExpectRaw(context, TokenType::RParent, Diagnostics::expectedTokenInDefinedExpression)) + { + GetSink(context)->diagnose(leftParen.Position, Diagnostics::seeOpeningToken, leftParen); + return 0; + } + } + + return LookupMacro(context, name) != NULL; + } + + // An identifier here means it was not defined as a macro (or + // it is defined, but as a function-like macro. These should + // just evaluate to zero (possibly with a warning) + return 0; + } + + default: + GetSink(context)->diagnose(PeekLoc(context), Diagnostics::syntaxErrorInPreprocessorExpression); + return 0; + } +} + +// Determine the precedence level of an infix operator +// for use in parsing preprocessor conditionals. +static int GetInfixOpPrecedence(Token const& opToken) +{ + // If token is on another line, it is not part of the + // expression + if (opToken.flags & TokenFlag::AtStartOfLine) + return -1; + + // otherwise we look at the token type to figure + // out what precednece it should be parse with + switch (opToken.Type) + { + default: + // tokens that aren't infix operators should + // cause us to stop parsing an expression + return -1; + + case TokenType::OpMul: return 10; + case TokenType::OpDiv: return 10; + case TokenType::OpMod: return 10; + + case TokenType::OpAdd: return 9; + case TokenType::OpSub: return 9; + + case TokenType::OpLsh: return 8; + case TokenType::OpRsh: return 8; + + case TokenType::OpLess: return 7; + case TokenType::OpGreater: return 7; + case TokenType::OpLeq: return 7; + case TokenType::OpGeq: return 7; + + case TokenType::OpEql: return 6; + case TokenType::OpNeq: return 6; + + case TokenType::OpBitAnd: return 5; + case TokenType::OpBitOr: return 4; + case TokenType::OpBitXor: return 3; + case TokenType::OpAnd: return 2; + case TokenType::OpOr: return 1; + } +}; + +// Evaluate one infix operation in a preprocessor +// conditional expression +static PreprocessorExpressionValue EvaluateInfixOp( + PreprocessorDirectiveContext* context, + Token const& opToken, + PreprocessorExpressionValue left, + PreprocessorExpressionValue right) +{ + switch (opToken.Type) + { + default: +// SLANG_INTERNAL_ERROR(getSink(preprocessor), opToken); + return 0; + break; + + case TokenType::OpMul: return left * right; + case TokenType::OpDiv: + { + if (right == 0) + { + if (!context->parseError) + { + GetSink(context)->diagnose(opToken.Position, Diagnostics::divideByZeroInPreprocessorExpression); + } + return 0; + } + return left / right; + } + case TokenType::OpMod: + { + if (right == 0) + { + if (!context->parseError) + { + GetSink(context)->diagnose(opToken.Position, Diagnostics::divideByZeroInPreprocessorExpression); + } + return 0; + } + return left % right; + } + case TokenType::OpAdd: return left + right; + case TokenType::OpSub: return left - right; + case TokenType::OpLsh: return left << right; + case TokenType::OpRsh: return left >> right; + case TokenType::OpLess: return left < right ? 1 : 0; + case TokenType::OpGreater: return left > right ? 1 : 0; + case TokenType::OpLeq: return left <= right ? 1 : 0; + case TokenType::OpGeq: return left <= right ? 1 : 0; + case TokenType::OpEql: return left == right ? 1 : 0; + case TokenType::OpNeq: return left != right ? 1 : 0; + case TokenType::OpBitAnd: return left & right; + case TokenType::OpBitOr: return left | right; + case TokenType::OpBitXor: return left ^ right; + case TokenType::OpAnd: return left && right; + case TokenType::OpOr: return left || right; + } +} + +// Parse the rest of an infix preprocessor expression with +// precedence greater than or equal to the given `precedence` argument. +// The value of the left-hand-side expression is provided as +// an argument. +// This is used to form a simple recursive-descent expression parser. +static PreprocessorExpressionValue ParseAndEvaluateInfixExpressionWithPrecedence( + PreprocessorDirectiveContext* context, + PreprocessorExpressionValue left, + int precedence) +{ + for (;;) + { + // Look at the next token, and see if it is an operator of + // high enough precedence to be included in our expression + Token opToken = PeekToken(context); + int opPrecedence = GetInfixOpPrecedence(opToken); + + // If it isn't an operator of high enough precendece, we are done. + if(opPrecedence < precedence) + break; + + // Otherwise we need to consume the operator token. + AdvanceToken(context); + + // Next we parse a right-hand-side expression by starting with + // a unary expression and absorbing and many infix operators + // as possible with strictly higher precedence than the operator + // we found above. + PreprocessorExpressionValue right = ParseAndEvaluateUnaryExpression(context); + for (;;) + { + // Look for an operator token + Token rightOpToken = PeekToken(context); + int rightOpPrecedence = GetInfixOpPrecedence(rightOpToken); + + // If no operator was found, or the operator wasn't high + // enough precedence to fold into the right-hand-side, + // exit this loop. + if (rightOpPrecedence <= opPrecedence) + break; + + // Now invoke the parser recursively, passing in our + // existing right-hand side to form an even larger one. + right = ParseAndEvaluateInfixExpressionWithPrecedence( + context, + right, + rightOpPrecedence); + } + + // Now combine the left- and right-hand sides using + // the operator we found above. + left = EvaluateInfixOp(context, opToken, left, right); + } + return left; +} + +// Parse a complete (infix) preprocessor expression, and return its value +static PreprocessorExpressionValue ParseAndEvaluateExpression(PreprocessorDirectiveContext* context) +{ + // First read in the left-hand side (or the whole expression in the unary case) + PreprocessorExpressionValue value = ParseAndEvaluateUnaryExpression(context); + + // Try to read in trailing infix operators with correct precedence + return ParseAndEvaluateInfixExpressionWithPrecedence(context, value, 0); +} + +// Handle a `#if` directive +static void HandleIfDirective(PreprocessorDirectiveContext* context) +{ + // Parse a preprocessor expression. + PreprocessorExpressionValue value = ParseAndEvaluateExpression(context); + + // Begin a preprocessor block, enabled based on the expression. + BeginConditional(context, value != 0); +} + +// Handle a `#ifdef` directive +static void HandleIfDefDirective(PreprocessorDirectiveContext* context) +{ + // Expect a raw identifier, so we can check if it is defined + Token nameToken; + if(!ExpectRaw(context, TokenType::Identifier, Diagnostics::expectedTokenInPreprocessorDirective, &nameToken)) + return; + String name = nameToken.Content; + + // Check if the name is defined. + BeginConditional(context, LookupMacro(context, name) != NULL); +} + +// Handle a `#ifndef` directive +static void HandleIfNDefDirective(PreprocessorDirectiveContext* context) +{ + // Expect a raw identifier, so we can check if it is defined + Token nameToken; + if(!ExpectRaw(context, TokenType::Identifier, Diagnostics::expectedTokenInPreprocessorDirective, &nameToken)) + return; + String name = nameToken.Content; + + // Check if the name is defined. + BeginConditional(context, LookupMacro(context, name) == NULL); +} + +// Handle a `#else` directive +static void HandleElseDirective(PreprocessorDirectiveContext* context) +{ + PreprocessorInputStream* inputStream = context->preprocessor->inputStream; + assert(inputStream); + + // if we aren't inside a conditional, then error + PreprocessorConditional* conditional = inputStream->conditional; + if (!conditional) + { + GetSink(context)->diagnose(GetDirectiveLoc(context), Diagnostics::directiveWithoutIf, GetDirectiveName(context)); + return; + } + + // if we've already seen a `#else`, then it is an error + if (conditional->elseToken.Type != TokenType::Unknown) + { + GetSink(context)->diagnose(GetDirectiveLoc(context), Diagnostics::directiveAfterElse, GetDirectiveName(context)); + GetSink(context)->diagnose(conditional->elseToken.Position, Diagnostics::seeDirective); + return; + } + conditional->elseToken = context->directiveToken; + + switch (conditional->state) + { + case PreprocessorConditionalState::Before: + conditional->state = PreprocessorConditionalState::During; + break; + + case PreprocessorConditionalState::During: + conditional->state = PreprocessorConditionalState::After; + break; + + default: + break; + } +} + +// Handle a `#elif` directive +static void HandleElifDirective(PreprocessorDirectiveContext* context) +{ + // HACK(tfoley): handle an empty `elif` like an `else` directive + // + // This is the behavior expected by at least one input program. + // We will eventually want to be pedantic about this. + // even if t + if (PeekRawTokenType(context) == TokenType::EndOfDirective) + { + GetSink(context)->diagnose(GetDirectiveLoc(context), Diagnostics::directiveExpectsExpression, GetDirectiveName(context)); + HandleElseDirective(context); + return; + } + + PreprocessorExpressionValue value = ParseAndEvaluateExpression(context); + + PreprocessorInputStream* inputStream = context->preprocessor->inputStream; + assert(inputStream); + + // if we aren't inside a conditional, then error + PreprocessorConditional* conditional = inputStream->conditional; + if (!conditional) + { + GetSink(context)->diagnose(GetDirectiveLoc(context), Diagnostics::directiveWithoutIf, GetDirectiveName(context)); + return; + } + + // if we've already seen a `#else`, then it is an error + if (conditional->elseToken.Type != TokenType::Unknown) + { + GetSink(context)->diagnose(GetDirectiveLoc(context), Diagnostics::directiveAfterElse, GetDirectiveName(context)); + GetSink(context)->diagnose(conditional->elseToken.Position, Diagnostics::seeDirective); + return; + } + + switch (conditional->state) + { + case PreprocessorConditionalState::Before: + if(value) + conditional->state = PreprocessorConditionalState::During; + break; + + case PreprocessorConditionalState::During: + conditional->state = PreprocessorConditionalState::After; + break; + + default: + break; + } +} + +// Handle a `#endif` directive +static void HandleEndIfDirective(PreprocessorDirectiveContext* context) +{ + PreprocessorInputStream* inputStream = context->preprocessor->inputStream; + assert(inputStream); + + // if we aren't inside a conditional, then error + PreprocessorConditional* conditional = inputStream->conditional; + if (!conditional) + { + GetSink(context)->diagnose(GetDirectiveLoc(context), Diagnostics::directiveWithoutIf, GetDirectiveName(context)); + return; + } + + inputStream->conditional = conditional->parent; + DestroyConditional(conditional); +} + +// Helper routine to check that we find the end of a directive where +// we expect it. +// +// Most directives do not need to call this directly, since we have +// a catch-all case in the main `HandleDirective()` funciton. +// The `#include` case will call it directly to avoid complications +// when it switches the input stream. +static void expectEndOfDirective(PreprocessorDirectiveContext* context) +{ + if(context->haveDoneEndOfDirectiveChecks) + return; + + context->haveDoneEndOfDirectiveChecks = true; + + if (!IsEndOfLine(context)) + { + // If we already saw a previous parse error, then don't + // emit another one for the same directive. + if (!context->parseError) + { + GetSink(context)->diagnose(PeekLoc(context), Diagnostics::unexpectedTokensAfterDirective, GetDirectiveName(context)); + } + SkipToEndOfLine(context); + } + + // Clear out the end-of-directive token + AdvanceRawToken(context->preprocessor); +} + + +// Handle a `#include` directive +static void HandleIncludeDirective(PreprocessorDirectiveContext* context) +{ + Token pathToken; + if(!Expect(context, TokenType::StringLiterial, Diagnostics::expectedTokenInPreprocessorDirective, &pathToken)) + return; + + String path = getFileNameTokenValue(pathToken); + + // TODO(tfoley): make this robust in presence of `#line` + String pathIncludedFrom = GetDirectiveLoc(context).FileName; + String foundPath; + String foundSource; + + + IncludeHandler* includeHandler = context->preprocessor->includeHandler; + if (!includeHandler) + { + GetSink(context)->diagnose(pathToken.Position, Diagnostics::includeFailed, path); + GetSink(context)->diagnose(pathToken.Position, Diagnostics::noIncludeHandlerSpecified); + return; + } + if (!includeHandler->TryToFindIncludeFile(path, pathIncludedFrom, &foundPath, &foundSource)) + { + GetSink(context)->diagnose(pathToken.Position, Diagnostics::includeFailed, path); + return; + } + + // Do all checking related to the end of this directive before we push a new stream, + // just to avoid complications where that check would need to deal with + // a switch of input stream + expectEndOfDirective(context); + + // Push the new file onto our stack of input streams + // TODO(tfoley): check if we have made our include stack too deep + PreprocessorInputStream* inputStream = CreateInputStreamForSource(context->preprocessor, foundSource, foundPath); + inputStream->parent = context->preprocessor->inputStream; + context->preprocessor->inputStream = inputStream; +} + +// Handle a `#define` directive +static void HandleDefineDirective(PreprocessorDirectiveContext* context) +{ + Token nameToken; + if (!Expect(context, TokenType::Identifier, Diagnostics::expectedTokenInPreprocessorDirective, &nameToken)) + return; + String name = nameToken.Content; + + PreprocessorMacro* macro = CreateMacro(context->preprocessor); + macro->nameToken = nameToken; + + PreprocessorMacro* oldMacro = LookupMacro(&context->preprocessor->globalEnv, name); + if (oldMacro) + { + GetSink(context)->diagnose(nameToken.Position, Diagnostics::macroRedefinition, name); + GetSink(context)->diagnose(oldMacro->nameToken.Position, Diagnostics::seePreviousDefinitionOf, name); + + DestroyMacro(context->preprocessor, oldMacro); + } + context->preprocessor->globalEnv.macros[name] = macro; + + // If macro name is immediately followed (with no space) by `(`, + // then we have a function-like macro + if (PeekRawTokenType(context) == TokenType::LParent) + { + if (!(PeekRawToken(context).flags & TokenFlag::AfterWhitespace)) + { + // This is a function-like macro, so we need to remember that + // and start capturing parameters + macro->flavor = PreprocessorMacroFlavor::FunctionLike; + + AdvanceRawToken(context); + + // If there are any parameters, parse them + if (PeekRawTokenType(context) != TokenType::RParent) + { + for (;;) + { + // TODO: handle elipsis (`...`) for varags + + // A macro parameter name should be a raw identifier + Token paramToken; + if (!ExpectRaw(context, TokenType::Identifier, Diagnostics::expectedTokenInMacroParameters, ¶mToken)) + break; + + // TODO(tfoley): some validation on parameter name. + // Certain names (e.g., `defined` and `__VA_ARGS__` + // are not allowed to be used as macros or parameters). + + // Add the parameter to the macro being deifned + macro->params.Add(paramToken); + + // If we see `)` then we are done with arguments + if (PeekRawTokenType(context) == TokenType::RParent) + break; + + ExpectRaw(context, TokenType::Comma, Diagnostics::expectedTokenInMacroParameters); + } + } + + ExpectRaw(context, TokenType::RParent, Diagnostics::expectedTokenInMacroParameters); + } + } + + // consume tokens until end-of-line + for(;;) + { + Token token = AdvanceRawToken(context); + if( token.Type == TokenType::EndOfDirective ) + { + // Last token on line will be turned into a conceptual end-of-file + // token for the sub-stream that the macro expands into. + token.Type = TokenType::EndOfFile; + macro->tokens.mTokens.Add(token); + break; + } + + // In the ordinary case, we just add the token to the definition + macro->tokens.mTokens.Add(token); + } +} + +// Handle a `#undef` directive +static void HandleUndefDirective(PreprocessorDirectiveContext* context) +{ + Token nameToken; + if (!Expect(context, TokenType::Identifier, Diagnostics::expectedTokenInPreprocessorDirective, &nameToken)) + return; + String name = nameToken.Content; + + PreprocessorEnvironment* env = &context->preprocessor->globalEnv; + PreprocessorMacro* macro = LookupMacro(env, name); + if (macro != NULL) + { + // name was defined, so remove it + env->macros.Remove(name); + + DestroyMacro(context->preprocessor, macro); + } + else + { + // name wasn't defined + GetSink(context)->diagnose(nameToken.Position, Diagnostics::macroNotDefined, name); + } +} + +// Handle a `#warning` directive +static void HandleWarningDirective(PreprocessorDirectiveContext* context) +{ + // TODO: read rest of line without actual tokenization + GetSink(context)->diagnose(GetDirectiveLoc(context), Diagnostics::userDefinedWarning, "user-defined warning"); + SkipToEndOfLine(context); +} + +// Handle a `#error` directive +static void HandleErrorDirective(PreprocessorDirectiveContext* context) +{ + // TODO: read rest of line without actual tokenization + GetSink(context)->diagnose(GetDirectiveLoc(context), Diagnostics::userDefinedError, "user-defined warning"); + SkipToEndOfLine(context); +} + +// Handle a `#line` directive +static void HandleLineDirective(PreprocessorDirectiveContext* context) +{ + int line = 0; + if (PeekTokenType(context) == TokenType::IntLiterial) + { + line = StringToInt(AdvanceToken(context).Content); + } + else if (PeekTokenType(context) == TokenType::Identifier + && PeekToken(context).Content == "default") + { + AdvanceToken(context); + + // Stop overiding soure locations. + context->preprocessor->inputStream->isOverridingSourceLoc = false; + return; + } + else + { + GetSink(context)->diagnose(PeekLoc(context), Diagnostics::expected2TokensInPreprocessorDirective, + TokenType::IntLiterial, + "default", + GetDirectiveName(context)); + context->parseError = true; + return; + } + + CodePosition directiveLoc = GetDirectiveLoc(context); + + String file; + if (PeekTokenType(context) == TokenType::EndOfDirective) + { + file = directiveLoc.FileName; + } + else if (PeekTokenType(context) == TokenType::StringLiterial) + { + file = AdvanceToken(context).Content; + } + else if (PeekTokenType(context) == TokenType::IntLiterial) + { + // Note(tfoley): GLSL allows the "source string" to be indicated by an integer + // TODO(tfoley): Figure out a better way to handle this, if it matters + file = AdvanceToken(context).Content; + } + else + { + Expect(context, TokenType::StringLiterial, Diagnostics::expectedTokenInPreprocessorDirective); + return; + } + + PreprocessorInputStream* inputStream = context->preprocessor->inputStream; + + inputStream->isOverridingSourceLoc = true; + inputStream->overrideFileName = file; + inputStream->overrideLineOffset = line - (directiveLoc.Line + 1); +} + +// Handle a `#pragma` directive +static void HandlePragmaDirective(PreprocessorDirectiveContext* context) +{ + // TODO(tfoley): figure out which pragmas to parse, + // and which to pass along + SkipToEndOfLine(context); +} + +// Handle a `#version` directive +static void handleGLSLVersionDirective(PreprocessorDirectiveContext* context) +{ + Token versionNumberToken; + if(!ExpectRaw( + context, + TokenType::IntLiterial, + Diagnostics::expectedTokenInPreprocessorDirective, + &versionNumberToken)) + { + return; + } + + Token glslProfileToken; + if(PeekTokenType(context) == TokenType::Identifier) + { + glslProfileToken = AdvanceToken(context); + } + + // Need to construct a representation taht we can hook into our compilation result + + auto modifier = new GLSLVersionDirective(); + modifier->versionNumberToken = versionNumberToken; + modifier->glslProfileToken = glslProfileToken; + + // Attach the modifier to the program we are parsing! + + addModifier( + context->preprocessor->syntax, + modifier); +} + +// Handle a `#extension` directive, e.g., +// +// #extension some_extension_name : enable +// +static void handleGLSLExtensionDirective(PreprocessorDirectiveContext* context) +{ + Token extensionNameToken; + if(!ExpectRaw( + context, + TokenType::Identifier, + Diagnostics::expectedTokenInPreprocessorDirective, + &extensionNameToken)) + { + return; + } + + if( !ExpectRaw(context, TokenType::Colon, Diagnostics::expectedTokenInPreprocessorDirective) ) + { + return; + } + + Token dispositionToken; + if(!ExpectRaw( + context, + TokenType::Identifier, + Diagnostics::expectedTokenInPreprocessorDirective, + &dispositionToken)) + { + return; + } + + // Need to construct a representation taht we can hook into our compilation result + + auto modifier = new GLSLExtensionDirective(); + modifier->extensionNameToken = extensionNameToken; + modifier->dispositionToken = dispositionToken; + + // Attach the modifier to the program we are parsing! + + addModifier( + context->preprocessor->syntax, + modifier); +} + +// Handle an invalid directive +static void HandleInvalidDirective(PreprocessorDirectiveContext* context) +{ + GetSink(context)->diagnose(GetDirectiveLoc(context), Diagnostics::unknownPreprocessorDirective, GetDirectiveName(context)); + SkipToEndOfLine(context); +} + +// Callback interface used by preprocessor directives +typedef void (*PreprocessorDirectiveCallback)(PreprocessorDirectiveContext* context); + +enum PreprocessorDirectiveFlag : unsigned int +{ + // Should this directive be handled even when skipping disbaled code? + ProcessWhenSkipping = 1 << 0, +}; + +// Information about a specific directive +struct PreprocessorDirective +{ + // Name of the directive + char const* name; + + // Callback to handle the directive + PreprocessorDirectiveCallback callback; + + unsigned int flags; +}; + +// A simple array of all the directives we know how to handle. +// TODO(tfoley): considering making this into a real hash map, +// and then make it easy-ish for users of the codebase to add +// their own directives as desired. +static const PreprocessorDirective kDirectives[] = +{ + { "if", &HandleIfDirective, ProcessWhenSkipping }, + { "ifdef", &HandleIfDefDirective, ProcessWhenSkipping }, + { "ifndef", &HandleIfNDefDirective, ProcessWhenSkipping }, + { "else", &HandleElseDirective, ProcessWhenSkipping }, + { "elif", &HandleElifDirective, ProcessWhenSkipping }, + { "endif", &HandleEndIfDirective, ProcessWhenSkipping }, + + { "include", &HandleIncludeDirective, 0 }, + { "define", &HandleDefineDirective, 0 }, + { "undef", &HandleUndefDirective, 0 }, + { "warning", &HandleWarningDirective, 0 }, + { "error", &HandleErrorDirective, 0 }, + { "line", &HandleLineDirective, 0 }, + { "pragma", &HandlePragmaDirective, 0 }, + + // TODO(tfoley): These are specific to GLSL, and probably + // shouldn't be enabled for HLSL or Slang + { "version", &handleGLSLVersionDirective, 0 }, + { "extension", &handleGLSLExtensionDirective, 0 }, + + { NULL, NULL }, +}; + +static const PreprocessorDirective kInvalidDirective = { + NULL, &HandleInvalidDirective, 0, +}; + +// Look up the directive with the given name. +static PreprocessorDirective const* FindDirective(String const& name) +{ + char const* nameStr = name.Buffer(); + for (int ii = 0; kDirectives[ii].name; ++ii) + { + if (strcmp(kDirectives[ii].name, nameStr) != 0) + continue; + + return &kDirectives[ii]; + } + + return &kInvalidDirective; +} + +// Process a directive, where the preprocessor has already consumed the +// `#` token that started the directive line. +static void HandleDirective(PreprocessorDirectiveContext* context) +{ + // Try to read the directive name. + context->directiveToken = PeekRawToken(context); + + TokenType directiveTokenType = GetDirective(context).Type; + + // An empty directive is allowed, and ignored. + if (directiveTokenType == TokenType::EndOfDirective) + { + return; + } + // Otherwise the directive name had better be an identifier + else if (directiveTokenType != TokenType::Identifier) + { + GetSink(context)->diagnose(GetDirectiveLoc(context), Diagnostics::expectedPreprocessorDirectiveName); + SkipToEndOfLine(context); + return; + } + + // Consume the directive name token. + AdvanceRawToken(context); + + // Look up the handler for the directive. + PreprocessorDirective const* directive = FindDirective(GetDirectiveName(context)); + + // If we are skipping disabled code, and the directive is not one + // of the small number that need to run even in that case, skip it. + if (IsSkipping(context) && !(directive->flags & PreprocessorDirectiveFlag::ProcessWhenSkipping)) + { + SkipToEndOfLine(context); + return; + } + + // Apply the directive-specific callback + (directive->callback)(context); + + // We expect the directive callback to consume the entire line, so if + // it hasn't that is a parse error. + expectEndOfDirective(context); +} + +// Read one token using the full preprocessor, with all its behaviors. +static Token ReadToken(Preprocessor* preprocessor) +{ + for (;;) + { + // Look at the next raw token in the input. + Token const& token = PeekRawToken(preprocessor); + + // If we have a directive (`#` at start of line) then handle it + if ((token.Type == TokenType::Pound) && (token.flags & TokenFlag::AtStartOfLine)) + { + // Skip the `#` + AdvanceRawToken(preprocessor); + + // Create a context for parsing the directive + PreprocessorDirectiveContext directiveContext; + directiveContext.preprocessor = preprocessor; + directiveContext.parseError = false; + directiveContext.haveDoneEndOfDirectiveChecks = false; + + // Parse and handle the directive + HandleDirective(&directiveContext); + continue; + } + + // otherwise, if we are currently in a skipping mode, then skip tokens + if (IsSkipping(preprocessor)) + { + AdvanceRawToken(preprocessor); + continue; + } + + // otherwise read a token, which may involve macro expansion + return AdvanceToken(preprocessor); + } +} + +// intialize a preprocessor context, using the given sink for errros +static void InitializePreprocessor( + Preprocessor* preprocessor, + DiagnosticSink* sink) +{ + preprocessor->sink = sink; + preprocessor->includeHandler = NULL; + preprocessor->endOfFileToken.Type = TokenType::EndOfFile; + preprocessor->endOfFileToken.flags = TokenFlag::AtStartOfLine; +} + +// clean up after an environment +PreprocessorEnvironment::~PreprocessorEnvironment() +{ + for (auto pair : this->macros) + { + DestroyMacro(NULL, pair.Value); + } +} + +// finalize a preprocessor and free any memory still in use +static void FinalizePreprocessor( + Preprocessor* preprocessor) +{ + // Clear out any waiting input streams + PreprocessorInputStream* input = preprocessor->inputStream; + while (input) + { + PreprocessorInputStream* parent = input->parent; + DestroyInputStream(preprocessor, input); + input = parent; + } + +#if 0 + // clean up any macros that were allocated + for (auto pair : preprocessor->globalEnv.macros) + { + DestroyMacro(preprocessor, pair.Value); + } +#endif +} + +// Add a simple macro definition from a string (e.g., for a +// `-D` option passed on the command line +static void DefineMacro( + Preprocessor* preprocessor, + String const& key, + String const& value) +{ + String fileName = "command line"; + PreprocessorMacro* macro = CreateMacro(preprocessor); + + // Use existing `Lexer` to generate a token stream. + Lexer lexer(fileName, value, GetSink(preprocessor)); + macro->tokens = lexer.lexAllTokens(); + macro->nameToken = Token(TokenType::Identifier, key, 0, 0, 0, fileName); + + PreprocessorMacro* oldMacro = NULL; + if (preprocessor->globalEnv.macros.TryGetValue(key, oldMacro)) + { + DestroyMacro(preprocessor, oldMacro); + } + + preprocessor->globalEnv.macros[key] = macro; +} + +// read the entire input into tokens +static TokenList ReadAllTokens( + Preprocessor* preprocessor) +{ + TokenList tokens; + for (;;) + { + Token token = ReadToken(preprocessor); + + tokens.mTokens.Add(token); + + // Note: we include the EOF token in the list, + // since that is expected by the `TokenList` type. + if (token.Type == TokenType::EndOfFile) + break; + } + return tokens; +} + +TokenList preprocessSource( + CoreLib::String const& source, + CoreLib::String const& fileName, + DiagnosticSink* sink, + IncludeHandler* includeHandler, + CoreLib::Dictionary<CoreLib::String, CoreLib::String> defines, + ProgramSyntaxNode* syntax) +{ + Preprocessor preprocessor; + InitializePreprocessor(&preprocessor, sink); + preprocessor.syntax = syntax; + + preprocessor.includeHandler = includeHandler; + for (auto p : defines) + { + DefineMacro(&preprocessor, p.Key, p.Value); + } + + // create an initial input stream based on the provided buffer + preprocessor.inputStream = CreateInputStreamForSource(&preprocessor, source, fileName); + + TokenList tokens = ReadAllTokens(&preprocessor); + + FinalizePreprocessor(&preprocessor); + + // debugging: build the pre-processed source back together +#if 0 + StringBuilder sb; + for (auto t : tokens) + { + if (t.flags & TokenFlag::AtStartOfLine) + { + sb << "\n"; + } + else if (t.flags & TokenFlag::AfterWhitespace) + { + sb << " "; + } + + sb << t.Content; + } + + String s = sb.ProduceString(); +#endif + + return tokens; +} + +}} diff --git a/source/slang/preprocessor.h b/source/slang/preprocessor.h new file mode 100644 index 000000000..ab72f3f87 --- /dev/null +++ b/source/slang/preprocessor.h @@ -0,0 +1,35 @@ +// Preprocessor.h +#ifndef SLANG_PREPROCESSOR_H_INCLUDED +#define SLANG_PREPROCESSOR_H_INCLUDED + +#include "../core/basic.h" +#include "../slang/lexer.h" + +namespace Slang{ namespace Compiler { + +class DiagnosticSink; +class ProgramSyntaxNode; + +// Callback interface for the preprocessor to use when looking +// for files in `#include` directives. +struct IncludeHandler +{ + virtual bool TryToFindIncludeFile( + CoreLib::String const& pathToInclude, + CoreLib::String const& pathIncludedFrom, + CoreLib::String* outFoundPath, + CoreLib::String* outFoundSource) = 0; +}; + +// Take a string of source code and preprocess it into a list of tokens. +TokenList preprocessSource( + CoreLib::String const& source, + CoreLib::String const& fileName, + DiagnosticSink* sink, + IncludeHandler* includeHandler, + CoreLib::Dictionary<CoreLib::String, CoreLib::String> defines, + ProgramSyntaxNode* syntax); + +}} + +#endif diff --git a/source/slang/profile-defs.h b/source/slang/profile-defs.h new file mode 100644 index 000000000..76ba476bb --- /dev/null +++ b/source/slang/profile-defs.h @@ -0,0 +1,123 @@ +// + +// Define all the various language "profiles" we want to support. + +#ifndef LANGUAGE +#define LANGUAGE(TAG, NAME) /* emptry */ +#endif + +#ifndef LANGUAGE_ALIAS +#define LANGUAGE_ALIAS(TAG, NAME) /* empty */ +#endif + +#ifndef PROFILE_FAMILY +#define PROFILE_FAMILY(TAG) /* empty */ +#endif + +#ifndef PROFILE_VERSION +#define PROFILE_VERSION(TAG, FAMILY) /* empty */ +#endif + +#ifndef PROFILE_STAGE +#define PROFILE_STAGE(TAG, NAME, VAL) /* empty */ +#endif + +#ifndef PROFILE_STAGE_ALIAS +#define PROFILE_STAGE_ALIAS(TAG, NAME) /* empty */ +#endif + + +#ifndef PROFILE +#define PROFILE(TAG, NAME, STAGE, VERSION) /* empty */ +#endif + +#ifndef PROFILE_ALIAS +#define PROFILE_ALIAS(TAG, NAME) /* empty */ +#endif + +// Source and destination languages + +LANGUAGE(HLSL, hlsl) +LANGUAGE(DXBytecode, dxbc) +LANGUAGE(DXBytecodeAssembly,dxbc_asm) +LANGUAGE(DXIL, dxil) +LANGUAGE(DXILAssembly, dxil_asm) +LANGUAGE(GLSL, glsl) +LANGUAGE(GLSL_ES, glsl_es) +LANGUAGE(GLSL_VK, glsl_vk) +LANGUAGE(SPIRV, spirv) +LANGUAGE(SPIRV_GL, spirv_gl) + +LANGUAGE_ALIAS(GLSL, glsl_gl) +LANGUAGE_ALIAS(SPIRV, spirv_vk) + + +// Pipeline stages to target +PROFILE_STAGE(Vertex, vertex, SLANG_STAGE_VERTEX) +PROFILE_STAGE(Hull, hull, SLANG_STAGE_HULL) +PROFILE_STAGE(Domain, domain, SLANG_STAGE_DOMAIN) +PROFILE_STAGE(Geometry, geometry, SLANG_STAGE_GEOMETRY) +PROFILE_STAGE(Fragment, fragment, SLANG_STAGE_FRAGMENT) +PROFILE_STAGE(Compute, compute, SLANG_STAGE_COMPUTE) + +PROFILE_STAGE_ALIAS(Fragment, pixel) + +// Profile families + +PROFILE_FAMILY(DX) +PROFILE_FAMILY(GLSL) +PROFILE_FAMILY(SPRIV) + +// Profile versions + + +PROFILE_VERSION(DX_4_0, DX) +PROFILE_VERSION(DX_4_0_Level_9_0, DX) +PROFILE_VERSION(DX_4_0_Level_9_1, DX) +PROFILE_VERSION(DX_4_0_Level_9_3, DX) +PROFILE_VERSION(DX_4_1, DX) +PROFILE_VERSION(DX_5_0, DX) + +PROFILE_VERSION(GLSL, GLSL) + + +// Specific profiles + +PROFILE(DX_Compute_4_0, cs_4_0, Compute, DX_4_0) +PROFILE(DX_Compute_4_1, cs_4_1, Compute, DX_4_1) +PROFILE(DX_Compute_5_0, cs_5_0, Compute, DX_5_0) +PROFILE(DX_Domain_5_0, ds_5_0, Domain, DX_5_0) +PROFILE(DX_Geometry_4_0, gs_4_0, Geometry, DX_4_0) +PROFILE(DX_Geometry_4_1, gs_4_1, Geometry, DX_4_1) +PROFILE(DX_Geometry_5_0, gs_5_0, Geometry, DX_5_0) +PROFILE(DX_Hull_5_0, hs_5_0, Hull, DX_5_0) +PROFILE(DX_Fragment_4_0, ps_4_0, Fragment, DX_4_0) +PROFILE(DX_Fragment_4_0_Level_9_0, ps_4_0_level_9_0, Fragment, DX_4_0_Level_9_0) +PROFILE(DX_Fragment_4_0_Level_9_1, ps_4_0_level_9_1, Fragment, DX_4_0_Level_9_1) +PROFILE(DX_Fragment_4_0_Level_9_3, ps_4_0_level_9_3, Fragment, DX_4_0_Level_9_3) +PROFILE(DX_Fragment_4_1, ps_4_1, Fragment, DX_4_1) +PROFILE(DX_Fragment_5_0, ps_5_0, Fragment, DX_5_0) +PROFILE(DX_Vertex_4_0, vs_4_0, Vertex, DX_4_0) +PROFILE(DX_Vertex_4_0_Level_9_0, vs_4_0_level_9_0, Vertex, DX_4_0_Level_9_0) +PROFILE(DX_Vertex_4_0_Level_9_1, vs_4_0_level_9_1, Vertex, DX_4_0_Level_9_1) +PROFILE(DX_Vertex_4_0_Level_9_3, vs_4_0_level_9_3, Vertex, DX_4_0_Level_9_3) +PROFILE(DX_Vertex_4_1, vs_4_1, Vertex, DX_4_1) +PROFILE(DX_Vertex_5_0, vs_5_0, Vertex, DX_5_0) + +// + +PROFILE(GLSL_Compute, glsl_compute, Compute, GLSL) +PROFILE(GLSL_Vertex, glsl_vertex, Vertex, GLSL) +PROFILE(GLSL_Fragment, glsl_fragment, Fragment, GLSL) +PROFILE(GLSL_Geometry, glsl_geometry, Geometry, GLSL) +PROFILE(GLSL_TessControl, glsl_tess_control, Hull, GLSL) +PROFILE(GLSL_TessEval, glsl_tess_eval, Domain, GLSL) + +#undef LANGUAGE +#undef LANGUAGE_ALIAS +#undef PROFILE_FAMILY +#undef PROFILE_VERSION +#undef PROFILE_STAGE +#undef PROFILE_STAGE_ALIAS +#undef PROFILE +#undef PROFILE_ALIAS diff --git a/source/slang/profile.cpp b/source/slang/profile.cpp new file mode 100644 index 000000000..923dc2841 --- /dev/null +++ b/source/slang/profile.cpp @@ -0,0 +1,20 @@ +// profile.cpp +#include "Profile.h" + + +namespace Slang { +namespace Compiler { + + +ProfileFamily getProfileFamily(ProfileVersion version) +{ + switch( version ) + { + default: return ProfileFamily::Unknown; + +#define PROFILE_VERSION(TAG, FAMILY) case ProfileVersion::TAG: return ProfileFamily::FAMILY; +#include "profile-defs.h" + } +} + +}} diff --git a/source/slang/profile.h b/source/slang/profile.h new file mode 100644 index 000000000..31465c38c --- /dev/null +++ b/source/slang/profile.h @@ -0,0 +1,84 @@ +#ifndef SLANG_PROFILE_H_INCLUDED +#define SLANG_PROFILE_H_INCLUDED + +#include "../core/basic.h" +#include "../../slang.h" + +namespace Slang +{ + namespace Compiler + { + // Flavors of translation unit + enum class SourceLanguage : SlangSourceLanguage + { + Unknown = SLANG_SOURCE_LANGUAGE_UNKNOWN, // should not occur + Slang = SLANG_SOURCE_LANGUAGE_SLANG, + HLSL = SLANG_SOURCE_LANGUAGE_HLSL, + GLSL = SLANG_SOURCE_LANGUAGE_GLSL, + + // A separate PACKAGE of Slang code that has been imported + ImportedSlangCode, + }; + + // TODO(tfoley): This should merge with the above... + enum class Language + { + Unknown, +#define LANGUAGE(TAG, NAME) TAG, +#include "profile-defs.h" + }; + + enum class ProfileFamily + { + Unknown, +#define PROFILE_FAMILY(TAG) TAG, +#include "profile-defs.h" + }; + + enum class ProfileVersion + { + Unknown, +#define PROFILE_VERSION(TAG, FAMILY) TAG, +#include "profile-defs.h" + }; + + enum class Stage : SlangStage + { + Unknown = SLANG_STAGE_NONE, +#define PROFILE_STAGE(TAG, NAME, VAL) TAG = VAL, +#include "profile-defs.h" + }; + + ProfileFamily getProfileFamily(ProfileVersion version); + + struct Profile + { + typedef uint32_t RawVal; + enum : RawVal + { + Unknown, + +#define PROFILE(TAG, NAME, STAGE, VERSION) TAG = (uint32_t(Stage::STAGE) << 16) | uint32_t(ProfileVersion::VERSION), +#include "profile-defs.h" + }; + + Profile() {} + Profile(RawVal raw) + : raw(raw) + {} + + bool operator==(Profile const& other) const { return raw == other.raw; } + bool operator!=(Profile const& other) const { return raw != other.raw; } + + Stage GetStage() const { return Stage((uint32_t(raw) >> 16) & 0xFFFF); } + ProfileVersion GetVersion() const { return ProfileVersion(uint32_t(raw) & 0xFFFF); } + ProfileFamily getFamily() const { return getProfileFamily(GetVersion()); } + + static Profile LookUp(char const* name); + + RawVal raw = Unknown; + }; + } +} + +#endif diff --git a/source/slang/reflection.cpp b/source/slang/reflection.cpp new file mode 100644 index 000000000..1a56a8c1f --- /dev/null +++ b/source/slang/reflection.cpp @@ -0,0 +1,1404 @@ +// reflection.cpp +#include "reflection.h" + +#include "compiler.h" +#include "type-layout.h" + +#include <assert.h> + +// Implementation to back public-facing reflection API + +using namespace Slang; +using namespace Slang::Compiler; + + +// Conversion routines to help with strongly-typed reflection API + +static inline ExpressionType* convert(SlangReflectionType* type) +{ + return (ExpressionType*) type; +} + +static inline SlangReflectionType* convert(ExpressionType* type) +{ + return (SlangReflectionType*) type; +} + +static inline TypeLayout* convert(SlangReflectionTypeLayout* type) +{ + return (TypeLayout*) type; +} + +static inline SlangReflectionTypeLayout* convert(TypeLayout* type) +{ + return (SlangReflectionTypeLayout*) type; +} + +static inline VarDeclBase* convert(SlangReflectionVariable* var) +{ + return (VarDeclBase*) var; +} + +static inline SlangReflectionVariable* convert(VarDeclBase* var) +{ + return (SlangReflectionVariable*) var; +} + +static inline VarLayout* convert(SlangReflectionVariableLayout* var) +{ + return (VarLayout*) var; +} + +static inline SlangReflectionVariableLayout* convert(VarLayout* var) +{ + return (SlangReflectionVariableLayout*) var; +} + +static inline EntryPointLayout* convert(SlangReflectionEntryPoint* entryPoint) +{ + return (EntryPointLayout*) entryPoint; +} + +static inline SlangReflectionEntryPoint* convert(EntryPointLayout* entryPoint) +{ + return (SlangReflectionEntryPoint*) entryPoint; +} + + +static inline ProgramLayout* convert(SlangReflection* program) +{ + return (ProgramLayout*) program; +} + +static inline SlangReflection* convert(ProgramLayout* program) +{ + return (SlangReflection*) program; +} + +// Type Reflection + + +SLANG_API SlangTypeKind spReflectionType_GetKind(SlangReflectionType* inType) +{ + auto type = convert(inType); + if(!type) return SLANG_TYPE_KIND_NONE; + + // TODO(tfoley: Don't emit the same type more than once... + + if (auto basicType = type->As<BasicExpressionType>()) + { + return SLANG_TYPE_KIND_SCALAR; + } + else if (auto vectorType = type->As<VectorExpressionType>()) + { + return SLANG_TYPE_KIND_VECTOR; + } + else if (auto matrixType = type->As<MatrixExpressionType>()) + { + return SLANG_TYPE_KIND_MATRIX; + } + else if (auto constantBufferType = type->As<ConstantBufferType>()) + { + return SLANG_TYPE_KIND_CONSTANT_BUFFER; + } + else if (auto samplerStateType = type->As<SamplerStateType>()) + { + return SLANG_TYPE_KIND_SAMPLER_STATE; + } + else if (auto textureType = type->As<TextureType>()) + { + return SLANG_TYPE_KIND_RESOURCE; + } + + // TODO: need a better way to handle this stuff... +#define CASE(TYPE) \ + else if(type->As<TYPE>()) do { \ + return SLANG_TYPE_KIND_RESOURCE; \ + } while(0) + + CASE(HLSLBufferType); + CASE(HLSLRWBufferType); + CASE(HLSLBufferType); + CASE(HLSLRWBufferType); + CASE(HLSLStructuredBufferType); + CASE(HLSLRWStructuredBufferType); + CASE(HLSLAppendStructuredBufferType); + CASE(HLSLConsumeStructuredBufferType); + CASE(HLSLByteAddressBufferType); + CASE(HLSLRWByteAddressBufferType); + CASE(UntypedBufferResourceType); +#undef CASE + + else if (auto arrayType = type->As<ArrayExpressionType>()) + { + return SLANG_TYPE_KIND_ARRAY; + } + else if( auto declRefType = type->As<DeclRefType>() ) + { + auto declRef = declRefType->declRef; + if( auto structDeclRef = declRef.As<StructDeclRef>() ) + { + return SLANG_TYPE_KIND_STRUCT; + } + } + + assert(!"unexpected"); + return SLANG_TYPE_KIND_NONE; +} + +SLANG_API unsigned int spReflectionType_GetFieldCount(SlangReflectionType* inType) +{ + auto type = convert(inType); + if(!type) return 0; + + // TODO: maybe filter based on kind + + if(auto declRefType = dynamic_cast<DeclRefType*>(type)) + { + auto declRef = declRefType->declRef; + if( auto structDeclRef = declRef.As<StructDeclRef>()) + { + return structDeclRef.GetFields().Count(); + } + } + + return 0; +} + +SLANG_API SlangReflectionVariable* spReflectionType_GetFieldByIndex(SlangReflectionType* inType, unsigned index) +{ + auto type = convert(inType); + if(!type) return nullptr; + + // TODO: maybe filter based on kind + + if(auto declRefType = dynamic_cast<DeclRefType*>(type)) + { + auto declRef = declRefType->declRef; + if( auto structDeclRef = declRef.As<StructDeclRef>()) + { + auto fieldDeclRef = structDeclRef.GetFields().ToArray()[index]; + return (SlangReflectionVariable*) fieldDeclRef.GetDecl(); + } + } + + return nullptr; +} + +SLANG_API size_t spReflectionType_GetElementCount(SlangReflectionType* inType) +{ + auto type = convert(inType); + if(!type) return 0; + + if(auto arrayType = dynamic_cast<ArrayExpressionType*>(type)) + { + return GetIntVal(arrayType->ArrayLength); + } + else if( auto vectorType = dynamic_cast<VectorExpressionType*>(type)) + { + return GetIntVal(vectorType->elementCount); + } + + return 0; +} + +SLANG_API SlangReflectionType* spReflectionType_GetElementType(SlangReflectionType* inType) +{ + auto type = convert(inType); + if(!type) return nullptr; + + if(auto arrayType = dynamic_cast<ArrayExpressionType*>(type)) + { + return (SlangReflectionType*) arrayType->BaseType.Ptr(); + } + else if( auto constantBufferType = dynamic_cast<ConstantBufferType*>(type)) + { + return convert(constantBufferType->elementType.Ptr()); + } + else if( auto vectorType = dynamic_cast<VectorExpressionType*>(type)) + { + return convert(vectorType->elementType.Ptr()); + } + else if( auto matrixType = dynamic_cast<MatrixExpressionType*>(type)) + { + return convert(matrixType->getElementType()); + } + + return nullptr; +} + +SLANG_API unsigned int spReflectionType_GetRowCount(SlangReflectionType* inType) +{ + auto type = convert(inType); + if(!type) return 0; + + if(auto matrixType = dynamic_cast<MatrixExpressionType*>(type)) + { + return GetIntVal(matrixType->getRowCount()); + } + else if(auto vectorType = dynamic_cast<VectorExpressionType*>(type)) + { + return 1; + } + else if( auto basicType = dynamic_cast<BasicExpressionType*>(type) ) + { + return 1; + } + + return 0; +} + +SLANG_API unsigned int spReflectionType_GetColumnCount(SlangReflectionType* inType) +{ + auto type = convert(inType); + if(!type) return 0; + + if(auto matrixType = dynamic_cast<MatrixExpressionType*>(type)) + { + return GetIntVal(matrixType->getColumnCount()); + } + else if(auto vectorType = dynamic_cast<VectorExpressionType*>(type)) + { + return GetIntVal(vectorType->elementCount); + } + else if( auto basicType = dynamic_cast<BasicExpressionType*>(type) ) + { + return 1; + } + + return 0; +} + +SLANG_API SlangScalarType spReflectionType_GetScalarType(SlangReflectionType* inType) +{ + auto type = convert(inType); + if(!type) return 0; + + if(auto matrixType = dynamic_cast<MatrixExpressionType*>(type)) + { + type = matrixType->getElementType(); + } + else if(auto vectorType = dynamic_cast<VectorExpressionType*>(type)) + { + type = vectorType->elementType.Ptr(); + } + + if(auto basicType = dynamic_cast<BasicExpressionType*>(type)) + { + switch (basicType->BaseType) + { +#define CASE(BASE, TAG) \ + case BaseType::BASE: return SLANG_SCALAR_TYPE_##TAG + + CASE(Void, VOID); + CASE(Int, INT32); + CASE(Float, FLOAT32); + CASE(UInt, UINT32); + CASE(Bool, BOOL); + CASE(UInt64, UINT64); + +#undef CASE + + default: + assert(!"unexpected"); + return SLANG_SCALAR_TYPE_NONE; + break; + } + } + + return SLANG_SCALAR_TYPE_NONE; +} + +SLANG_API SlangResourceShape spReflectionType_GetResourceShape(SlangReflectionType* inType) +{ + auto type = convert(inType); + if(!type) return 0; + + while(auto arrayType = type->As<ArrayExpressionType>()) + { + type = arrayType->BaseType.Ptr(); + } + + if(auto textureType = type->As<TextureType>()) + { + return textureType->getShape(); + } + + // TODO: need a better way to handle this stuff... +#define CASE(TYPE, SHAPE, ACCESS) \ + else if(type->As<TYPE>()) do { \ + return SHAPE; \ + } while(0) + + CASE(HLSLBufferType, SLANG_TEXTURE_BUFFER, SLANG_RESOURCE_ACCESS_READ); + CASE(HLSLRWBufferType, SLANG_TEXTURE_BUFFER, SLANG_RESOURCE_ACCESS_READ_WRITE); + CASE(HLSLBufferType, SLANG_TEXTURE_BUFFER, SLANG_RESOURCE_ACCESS_READ); + CASE(HLSLRWBufferType, SLANG_TEXTURE_BUFFER, SLANG_RESOURCE_ACCESS_READ_WRITE); + CASE(HLSLStructuredBufferType, SLANG_STRUCTURED_BUFFER, SLANG_RESOURCE_ACCESS_READ); + CASE(HLSLRWStructuredBufferType, SLANG_STRUCTURED_BUFFER, SLANG_RESOURCE_ACCESS_READ_WRITE); + CASE(HLSLAppendStructuredBufferType, SLANG_STRUCTURED_BUFFER, SLANG_RESOURCE_ACCESS_APPEND); + CASE(HLSLConsumeStructuredBufferType, SLANG_STRUCTURED_BUFFER, SLANG_RESOURCE_ACCESS_CONSUME); + CASE(HLSLByteAddressBufferType, SLANG_BYTE_ADDRESS_BUFFER, SLANG_RESOURCE_ACCESS_READ); + CASE(HLSLRWByteAddressBufferType, SLANG_BYTE_ADDRESS_BUFFER, SLANG_RESOURCE_ACCESS_READ_WRITE); + CASE(UntypedBufferResourceType, SLANG_BYTE_ADDRESS_BUFFER, SLANG_RESOURCE_ACCESS_READ); +#undef CASE + + return SLANG_RESOURCE_NONE; +} + +SLANG_API SlangResourceAccess spReflectionType_GetResourceAccess(SlangReflectionType* inType) +{ + auto type = convert(inType); + if(!type) return 0; + + while(auto arrayType = type->As<ArrayExpressionType>()) + { + type = arrayType->BaseType.Ptr(); + } + + if(auto textureType = type->As<TextureType>()) + { + return textureType->getAccess(); + } + + // TODO: need a better way to handle this stuff... +#define CASE(TYPE, SHAPE, ACCESS) \ + else if(type->As<TYPE>()) do { \ + return ACCESS; \ + } while(0) + + CASE(HLSLBufferType, SLANG_TEXTURE_BUFFER, SLANG_RESOURCE_ACCESS_READ); + CASE(HLSLRWBufferType, SLANG_TEXTURE_BUFFER, SLANG_RESOURCE_ACCESS_READ_WRITE); + CASE(HLSLBufferType, SLANG_TEXTURE_BUFFER, SLANG_RESOURCE_ACCESS_READ); + CASE(HLSLRWBufferType, SLANG_TEXTURE_BUFFER, SLANG_RESOURCE_ACCESS_READ_WRITE); + CASE(HLSLStructuredBufferType, SLANG_STRUCTURED_BUFFER, SLANG_RESOURCE_ACCESS_READ); + CASE(HLSLRWStructuredBufferType, SLANG_STRUCTURED_BUFFER, SLANG_RESOURCE_ACCESS_READ_WRITE); + CASE(HLSLAppendStructuredBufferType, SLANG_STRUCTURED_BUFFER, SLANG_RESOURCE_ACCESS_APPEND); + CASE(HLSLConsumeStructuredBufferType, SLANG_STRUCTURED_BUFFER, SLANG_RESOURCE_ACCESS_CONSUME); + CASE(HLSLByteAddressBufferType, SLANG_BYTE_ADDRESS_BUFFER, SLANG_RESOURCE_ACCESS_READ); + CASE(HLSLRWByteAddressBufferType, SLANG_BYTE_ADDRESS_BUFFER, SLANG_RESOURCE_ACCESS_READ_WRITE); + CASE(UntypedBufferResourceType, SLANG_BYTE_ADDRESS_BUFFER, SLANG_RESOURCE_ACCESS_READ); +#undef CASE + + return SLANG_RESOURCE_ACCESS_NONE; +} + +SLANG_API SlangReflectionType* spReflectionType_GetResourceResultType(SlangReflectionType* inType) +{ + auto type = convert(inType); + if(!type) return nullptr; + + while(auto arrayType = type->As<ArrayExpressionType>()) + { + type = arrayType->BaseType.Ptr(); + } + + if (auto textureType = type->As<TextureType>()) + { + return convert(textureType->elementType.Ptr()); + } + + // TODO: need a better way to handle this stuff... +#define CASE(TYPE, SHAPE, ACCESS) \ + else if(type->As<TYPE>()) do { \ + return convert(type->As<TYPE>()->elementType.Ptr()); \ + } while(0) + + CASE(HLSLBufferType, SLANG_TEXTURE_BUFFER, SLANG_RESOURCE_ACCESS_READ); + CASE(HLSLRWBufferType, SLANG_TEXTURE_BUFFER, SLANG_RESOURCE_ACCESS_READ_WRITE); + CASE(HLSLBufferType, SLANG_TEXTURE_BUFFER, SLANG_RESOURCE_ACCESS_READ); + CASE(HLSLRWBufferType, SLANG_TEXTURE_BUFFER, SLANG_RESOURCE_ACCESS_READ_WRITE); + + // TODO: structured buffer needs to expose type layout! + + CASE(HLSLStructuredBufferType, SLANG_STRUCTURED_BUFFER, SLANG_RESOURCE_ACCESS_READ); + CASE(HLSLRWStructuredBufferType, SLANG_STRUCTURED_BUFFER, SLANG_RESOURCE_ACCESS_READ_WRITE); + CASE(HLSLAppendStructuredBufferType, SLANG_STRUCTURED_BUFFER, SLANG_RESOURCE_ACCESS_APPEND); + CASE(HLSLConsumeStructuredBufferType, SLANG_STRUCTURED_BUFFER, SLANG_RESOURCE_ACCESS_CONSUME); +#undef CASE + + return nullptr; +} + +// Type Layout Reflection + +SLANG_API SlangReflectionType* spReflectionTypeLayout_GetType(SlangReflectionTypeLayout* inTypeLayout) +{ + auto typeLayout = convert(inTypeLayout); + if(!typeLayout) return nullptr; + + return (SlangReflectionType*) typeLayout->type.Ptr(); +} + +SLANG_API size_t spReflectionTypeLayout_GetSize(SlangReflectionTypeLayout* inTypeLayout, SlangParameterCategory category) +{ + auto typeLayout = convert(inTypeLayout); + if(!typeLayout) return 0; + + auto info = typeLayout->FindResourceInfo(LayoutResourceKind(category)); + if(!info) return 0; + + return info->count; +} + +SLANG_API SlangReflectionVariableLayout* spReflectionTypeLayout_GetFieldByIndex(SlangReflectionTypeLayout* inTypeLayout, unsigned index) +{ + auto typeLayout = convert(inTypeLayout); + if(!typeLayout) return nullptr; + + if(auto structTypeLayout = dynamic_cast<StructTypeLayout*>(typeLayout)) + { + return (SlangReflectionVariableLayout*) structTypeLayout->fields[index].Ptr(); + } + + return nullptr; +} + +SLANG_API size_t spReflectionTypeLayout_GetElementStride(SlangReflectionTypeLayout* inTypeLayout, SlangParameterCategory category) +{ + auto typeLayout = convert(inTypeLayout); + if(!typeLayout) return 0; + + if( auto arrayTypeLayout = dynamic_cast<ArrayTypeLayout*>(typeLayout)) + { + if(category == SLANG_PARAMETER_CATEGORY_UNIFORM) + { + return arrayTypeLayout->uniformStride; + } + else + { + auto elementTypeLayout = arrayTypeLayout->elementTypeLayout; + auto info = elementTypeLayout->FindResourceInfo(LayoutResourceKind(category)); + if(!info) return 0; + return info->count; + } + } + + return 0; +} + +SLANG_API SlangReflectionTypeLayout* spReflectionTypeLayout_GetElementTypeLayout(SlangReflectionTypeLayout* inTypeLayout) +{ + auto typeLayout = convert(inTypeLayout); + if(!typeLayout) return nullptr; + + if( auto arrayTypeLayout = dynamic_cast<ArrayTypeLayout*>(typeLayout)) + { + return (SlangReflectionTypeLayout*) arrayTypeLayout->elementTypeLayout.Ptr(); + } + else if( auto constantBufferTypeLayout = dynamic_cast<ParameterBlockTypeLayout*>(typeLayout)) + { + return convert(constantBufferTypeLayout->elementTypeLayout.Ptr()); + } + else if( auto structuredBufferTypeLayout = dynamic_cast<StructuredBufferTypeLayout*>(typeLayout)) + { + return convert(structuredBufferTypeLayout->elementTypeLayout.Ptr()); + } + + return nullptr; +} + +static SlangParameterCategory getParameterCategory( + LayoutResourceKind kind) +{ + return SlangParameterCategory(kind); +} + +static SlangParameterCategory getParameterCategory( + TypeLayout* typeLayout) +{ + auto resourceInfoCount = typeLayout->resourceInfos.Count(); + if(resourceInfoCount == 1) + { + return getParameterCategory(typeLayout->resourceInfos[0].kind); + } + else if(resourceInfoCount == 0) + { + // TODO: can this ever happen? + return SLANG_PARAMETER_CATEGORY_NONE; + } + return SLANG_PARAMETER_CATEGORY_MIXED; +} + +SLANG_API SlangParameterCategory spReflectionTypeLayout_GetParameterCategory(SlangReflectionTypeLayout* inTypeLayout) +{ + auto typeLayout = convert(inTypeLayout); + if(!typeLayout) return SLANG_PARAMETER_CATEGORY_NONE; + + return getParameterCategory(typeLayout); +} + +SLANG_API unsigned spReflectionTypeLayout_GetCategoryCount(SlangReflectionTypeLayout* inTypeLayout) +{ + auto typeLayout = convert(inTypeLayout); + if(!typeLayout) return 0; + + return (unsigned) typeLayout->resourceInfos.Count(); +} + +SLANG_API SlangParameterCategory spReflectionTypeLayout_GetCategoryByIndex(SlangReflectionTypeLayout* inTypeLayout, unsigned index) +{ + auto typeLayout = convert(inTypeLayout); + if(!typeLayout) return SLANG_PARAMETER_CATEGORY_NONE; + + return typeLayout->resourceInfos[index].kind; +} + +// Variable Reflection + +SLANG_API char const* spReflectionVariable_GetName(SlangReflectionVariable* inVar) +{ + auto var = convert(inVar); + if(!var) return nullptr; + + // If the variable is one that has an "external" name that is supposed + // to be exposed for reflection, then report it here + if(auto reflectionNameMod = var->FindModifier<ParameterBlockReflectionName>()) + return reflectionNameMod->nameToken.Content.Buffer(); + + return var->getName().Buffer(); +} + +SLANG_API SlangReflectionType* spReflectionVariable_GetType(SlangReflectionVariable* inVar) +{ + auto var = convert(inVar); + if(!var) return nullptr; + + return convert(var->getType()); +} + +// Variable Layout Reflection + +SLANG_API SlangReflectionVariable* spReflectionVariableLayout_GetVariable(SlangReflectionVariableLayout* inVarLayout) +{ + auto varLayout = convert(inVarLayout); + if(!varLayout) return nullptr; + + return convert(varLayout->varDecl.GetDecl()); +} + +SLANG_API SlangReflectionTypeLayout* spReflectionVariableLayout_GetTypeLayout(SlangReflectionVariableLayout* inVarLayout) +{ + auto varLayout = convert(inVarLayout); + if(!varLayout) return nullptr; + + return convert(varLayout->getTypeLayout()); +} + +SLANG_API size_t spReflectionVariableLayout_GetOffset(SlangReflectionVariableLayout* inVarLayout, SlangParameterCategory category) +{ + auto varLayout = convert(inVarLayout); + if(!varLayout) return 0; + + auto info = varLayout->FindResourceInfo(LayoutResourceKind(category)); + if(!info) return 0; + + return info->index; +} + +SLANG_API size_t spReflectionVariableLayout_GetSpace(SlangReflectionVariableLayout* inVarLayout, SlangParameterCategory category) +{ + auto varLayout = convert(inVarLayout); + if(!varLayout) return 0; + + auto info = varLayout->FindResourceInfo(LayoutResourceKind(category)); + if(!info) return 0; + + return info->space; +} + + +// Shader Parameter Reflection + +SLANG_API unsigned spReflectionParameter_GetBindingIndex(SlangReflectionParameter* inVarLayout) +{ + auto varLayout = convert(inVarLayout); + if(!varLayout) return 0; + + if(varLayout->resourceInfos.Count() > 0) + { + return (unsigned) varLayout->resourceInfos[0].index; + } + + return 0; +} + +SLANG_API unsigned spReflectionParameter_GetBindingSpace(SlangReflectionParameter* inVarLayout) +{ + auto varLayout = convert(inVarLayout); + if(!varLayout) return 0; + + if(varLayout->resourceInfos.Count() > 0) + { + return (unsigned) varLayout->resourceInfos[0].space; + } + + return 0; +} + +// Entry Point Reflection + +SLANG_API SlangStage spReflectionEntryPoint_getStage(SlangReflectionEntryPoint* inEntryPoint) +{ + auto entryPointLayout = convert(inEntryPoint); + + if(!entryPointLayout) return SLANG_STAGE_NONE; + + return SlangStage(entryPointLayout->profile.GetStage()); +} + +SLANG_API void spReflectionEntryPoint_getComputeThreadGroupSize( + SlangReflectionEntryPoint* inEntryPoint, + SlangUInt axisCount, + SlangUInt* outSizeAlongAxis) +{ + auto entryPointLayout = convert(inEntryPoint); + + if(!entryPointLayout) return; + if(!axisCount) return; + if(!outSizeAlongAxis) return; + + auto entryPointFunc = entryPointLayout->entryPoint; + if(!entryPointFunc) return; + + auto numThreadsAttribute = entryPointFunc->FindModifier<HLSLNumThreadsAttribute>(); + if(!numThreadsAttribute) return; + + if(axisCount > 0) outSizeAlongAxis[0] = numThreadsAttribute->x; + if(axisCount > 1) outSizeAlongAxis[1] = numThreadsAttribute->y; + if(axisCount > 2) outSizeAlongAxis[2] = numThreadsAttribute->z; + for( SlangUInt aa = 3; aa < axisCount; ++aa ) + { + outSizeAlongAxis[aa] = 1; + } +} + + +// Shader Reflection + +SLANG_API unsigned spReflection_GetParameterCount(SlangReflection* inProgram) +{ + auto program = convert(inProgram); + if(!program) return 0; + + auto globalLayout = program->globalScopeLayout; + if(auto globalConstantBufferLayout = globalLayout.As<ParameterBlockTypeLayout>()) + { + globalLayout = globalConstantBufferLayout->elementTypeLayout; + } + + if(auto globalStructLayout = globalLayout.As<StructTypeLayout>()) + { + return globalStructLayout->fields.Count(); + } + + return 0; +} + +SLANG_API SlangReflectionParameter* spReflection_GetParameterByIndex(SlangReflection* inProgram, unsigned index) +{ + auto program = convert(inProgram); + if(!program) return nullptr; + + auto globalLayout = program->globalScopeLayout; + if(auto globalConstantBufferLayout = globalLayout.As<ParameterBlockTypeLayout>()) + { + globalLayout = globalConstantBufferLayout->elementTypeLayout; + } + + if(auto globalStructLayout = globalLayout.As<StructTypeLayout>()) + { + return convert(globalStructLayout->fields[index].Ptr()); + } + + return nullptr; +} + +SLANG_API SlangUInt spReflection_getEntryPointCount(SlangReflection* inProgram) +{ + auto program = convert(inProgram); + if(!program) return 0; + + return SlangUInt(program->entryPoints.Count()); +} + +SLANG_API SlangReflectionEntryPoint* spReflection_getEntryPointByIndex(SlangReflection* inProgram, SlangUInt index) +{ + auto program = convert(inProgram); + if(!program) return 0; + + return convert(program->entryPoints[(int) index].Ptr()); +} + + + + + + + + + + + + + + + + + + + + +namespace Slang { +namespace Compiler { + + + + + + + +// Debug helper code: dump reflection data after generation + +struct PrettyWriter +{ + StringBuilder sb; + bool startOfLine = true; + int indent = 0; +}; + +static void adjust(PrettyWriter& writer) +{ + if (!writer.startOfLine) + return; + + int indent = writer.indent; + for (int ii = 0; ii < indent; ++ii) + writer.sb << " "; + + writer.startOfLine = false; +} + +static void indent(PrettyWriter& writer) +{ + writer.indent++; +} + +static void dedent(PrettyWriter& writer) +{ + writer.indent--; +} + +static void write(PrettyWriter& writer, char const* text) +{ + // TODO: can do this more efficiently... + char const* cursor = text; + for(;;) + { + char c = *cursor++; + if (!c) break; + + if (c == '\n') + { + writer.startOfLine = true; + } + else + { + adjust(writer); + } + + writer.sb << c; + } +} + +static void write(PrettyWriter& writer, UInt val) +{ + adjust(writer); + writer.sb << ((unsigned int) val); +} + +static void emitReflectionVarInfoJSON(PrettyWriter& writer, slang::VariableReflection* var); +static void emitReflectionTypeLayoutJSON(PrettyWriter& writer, slang::TypeLayoutReflection* type); +static void emitReflectionTypeJSON(PrettyWriter& writer, slang::TypeReflection* type); + +static void emitReflectionVarBindingInfoJSON( + PrettyWriter& writer, + SlangParameterCategory category, + UInt index, + UInt count, + UInt space = 0) +{ + if( category == SLANG_PARAMETER_CATEGORY_UNIFORM ) + { + write(writer,"\"kind\": \"uniform\""); + write(writer, ", "); + write(writer,"\"offset\": "); + write(writer, index); + write(writer, ", "); + write(writer, "\"size\": "); + write(writer, count); + } + else + { + write(writer, "\"kind\": \""); + switch( category ) + { + #define CASE(NAME, KIND) case SLANG_PARAMETER_CATEGORY_##NAME: write(writer, #KIND); break + CASE(CONSTANT_BUFFER, constantBuffer); + CASE(SHADER_RESOURCE, shaderResource); + CASE(UNORDERED_ACCESS, unorderedAccess); + CASE(VERTEX_INPUT, vertexInput); + CASE(FRAGMENT_OUTPUT, fragmentOutput); + CASE(SAMPLER_STATE, samplerState); + #undef CASE + + default: + write(writer, "unknown"); + assert(!"unexpected"); + break; + } + write(writer, "\""); + if( space ) + { + write(writer, ", "); + write(writer, "\"space\": "); + write(writer, space); + } + write(writer, ", "); + write(writer, "\"index\": "); + write(writer, index); + if( count != 1) + { + write(writer, ", "); + write(writer, "\"count\": "); + write(writer, count); + } + } +} + +static void emitReflectionVarBindingInfoJSON( + PrettyWriter& writer, + slang::VariableLayoutReflection* var) +{ + auto typeLayout = var->getTypeLayout(); + auto categoryCount = var->getCategoryCount(); + + if( categoryCount != 1 ) + { + write(writer,"\"bindings\": [\n"); + } + else + { + write(writer,"\"binding\": "); + } + indent(writer); + + for(uint32_t cc = 0; cc < categoryCount; ++cc ) + { + auto category = var->getCategoryByIndex(cc); + auto index = var->getOffset(category); + auto space = var->getBindingSpace(category); + auto count = typeLayout->getSize(category); + + if (cc != 0) write(writer, ",\n"); + + write(writer,"{"); + emitReflectionVarBindingInfoJSON( + writer, + category, + index, + count, + space); + write(writer,"}"); + } + + dedent(writer); + if( categoryCount != 1 ) + { + write(writer,"\n]"); + } +} + +static void emitReflectionNameInfoJSON( + PrettyWriter& writer, + char const* name) +{ + // TODO: deal with escaping special characters if/when needed + write(writer, "\"name\": \""); + write(writer, name); + write(writer, "\""); +} + +static void emitReflectionVarLayoutJSON( + PrettyWriter& writer, + slang::VariableLayoutReflection* var) +{ + write(writer, "{\n"); + indent(writer); + + emitReflectionNameInfoJSON(writer, var->getName()); + write(writer, ",\n"); + + write(writer, "\"type\": "); + emitReflectionTypeLayoutJSON(writer, var->getTypeLayout()); + write(writer, ",\n"); + + emitReflectionVarBindingInfoJSON(writer, var); + + dedent(writer); + write(writer, "\n}"); +} + +static void emitReflectionScalarTypeInfoJSON( + PrettyWriter& writer, + SlangScalarType scalarType) +{ + write(writer, "\"scalarType\": \""); + switch (scalarType) + { + default: + write(writer, "unknown"); + assert(!"unexpected"); + break; +#define CASE(TAG, ID) case slang::TypeReflection::ScalarType::TAG: write(writer, #ID); break + CASE(Void, void); + CASE(Bool, bool); + CASE(Int32, int32); + CASE(UInt32, uint32); + CASE(Int64, int64); + CASE(UInt64, uint64); + CASE(Float16, float16); + CASE(Float32, float32); + CASE(Float64, float64); +#undef CASE + } + write(writer, "\""); +} + +static void emitReflectionTypeInfoJSON( + PrettyWriter& writer, + slang::TypeReflection* type) +{ + switch( type->getKind() ) + { + case SLANG_TYPE_KIND_SAMPLER_STATE: + write(writer, "\"kind\": \"samplerState\""); + break; + + case SLANG_TYPE_KIND_RESOURCE: + { + auto shape = type->getResourceShape(); + auto access = type->getResourceAccess(); + write(writer, "\"kind\": \"resource\""); + write(writer, ",\n"); + write(writer, "\"baseShape\": \""); + switch (shape & SLANG_RESOURCE_BASE_SHAPE_MASK) + { + default: + write(writer, "unknown"); + assert(!"unexpected"); + break; + +#define CASE(SHAPE, NAME) case SLANG_##SHAPE: write(writer, #NAME); break + CASE(TEXTURE_1D, texture1D); + CASE(TEXTURE_2D, texture2D); + CASE(TEXTURE_3D, texture3D); + CASE(TEXTURE_CUBE, textureCube); + CASE(TEXTURE_BUFFER, textureBuffer); + CASE(STRUCTURED_BUFFER, structuredBuffer); + CASE(BYTE_ADDRESS_BUFFER, byteAddressBuffer); +#undef CASE + } + write(writer, "\""); + if (shape & SLANG_TEXTURE_ARRAY_FLAG) + { + write(writer, ",\n"); + write(writer, "\"array\": true"); + } + if (shape & SLANG_TEXTURE_MULTISAMPLE_FLAG) + { + write(writer, ",\n"); + write(writer, "\"multisample\": true"); + } + + if( access != SLANG_RESOURCE_ACCESS_READ ) + { + write(writer, ",\n\"access\": \""); + switch(access) + { + default: + write(writer, "unknown"); + assert(!"unexpected"); + break; + + case SLANG_RESOURCE_ACCESS_READ: + break; + + case SLANG_RESOURCE_ACCESS_READ_WRITE: write(writer, "readWrite"); break; + case SLANG_RESOURCE_ACCESS_RASTER_ORDERED: write(writer, "rasterOrdered"); break; + case SLANG_RESOURCE_ACCESS_APPEND: write(writer, "append"); break; + case SLANG_RESOURCE_ACCESS_CONSUME: write(writer, "consume"); break; + } + write(writer, "\""); + } + } + break; + + case SLANG_TYPE_KIND_CONSTANT_BUFFER: + write(writer, "\"kind\": \"constantBuffer\""); + write(writer, ",\n"); + write(writer, "\"elementType\": "); + emitReflectionTypeJSON( + writer, + type->getElementType()); + break; + + case SLANG_TYPE_KIND_SCALAR: + write(writer, "\"kind\": \"scalar\""); + write(writer, ",\n"); + emitReflectionScalarTypeInfoJSON( + writer, + type->getScalarType()); + break; + + case SLANG_TYPE_KIND_VECTOR: + write(writer, "\"kind\": \"vector\""); + write(writer, ",\n"); + write(writer, "\"elementCount\": "); + write(writer, type->getElementCount()); + write(writer, ",\n"); + write(writer, "\"elementType\": "); + emitReflectionTypeJSON( + writer, + type->getElementType()); + break; + + case SLANG_TYPE_KIND_MATRIX: + write(writer, "\"kind\": \"matrix\""); + write(writer, ",\n"); + write(writer, "\"rowCount\": "); + write(writer, type->getRowCount()); + write(writer, ",\n"); + write(writer, "\"columnCount\": "); + write(writer, type->getColumnCount()); + write(writer, ",\n"); + write(writer, "\"elementType\": "); + emitReflectionTypeJSON( + writer, + type->getElementType()); + break; + + case SLANG_TYPE_KIND_ARRAY: + { + auto arrayType = type; + write(writer, "\"kind\": \"array\""); + write(writer, ",\n"); + write(writer, "\"elementCount\": "); + write(writer, arrayType->getElementCount()); + write(writer, ",\n"); + write(writer, "\"elementType\": "); + emitReflectionTypeJSON(writer, arrayType->getElementType()); + } + break; + + case SLANG_TYPE_KIND_STRUCT: + { + write(writer, "\"kind\": \"struct\",\n"); + write(writer, "\"fields\": [\n"); + indent(writer); + + auto structType = type; + auto fieldCount = structType->getFieldCount(); + for( uint32_t ff = 0; ff < fieldCount; ++ff ) + { + if (ff != 0) write(writer, ",\n"); + emitReflectionVarInfoJSON( + writer, + structType->getFieldByIndex(ff)); + } + dedent(writer); + write(writer, "\n]"); + } + break; + + default: + assert(!"unimplemented"); + break; + } +} + +static void emitReflectionTypeLayoutInfoJSON( + PrettyWriter& writer, + slang::TypeLayoutReflection* typeLayout) +{ + switch( typeLayout->getKind() ) + { + default: + emitReflectionTypeInfoJSON(writer, typeLayout->getType()); + break; + + case SLANG_TYPE_KIND_ARRAY: + { + auto arrayTypeLayout = typeLayout; + auto elementTypeLayout = arrayTypeLayout->getElementTypeLayout(); + write(writer, "\"kind\": \"array\""); + write(writer, ",\n"); + write(writer, "\"elementCount\": "); + write(writer, arrayTypeLayout->getElementCount()); + write(writer, ",\n"); + write(writer, "\"elementType\": "); + emitReflectionTypeLayoutJSON( + writer, + elementTypeLayout); + if (arrayTypeLayout->getSize(SLANG_PARAMETER_CATEGORY_UNIFORM) != 0) + { + write(writer, ",\n"); + write(writer, "\"uniformStride\": "); + write(writer, arrayTypeLayout->getElementStride(SLANG_PARAMETER_CATEGORY_UNIFORM)); + } + } + break; + + case SLANG_TYPE_KIND_STRUCT: + { + write(writer, "\"kind\": \"struct\",\n"); + write(writer, "\"fields\": [\n"); + indent(writer); + + auto structTypeLayout = typeLayout; + auto fieldCount = structTypeLayout->getFieldCount(); + for( uint32_t ff = 0; ff < fieldCount; ++ff ) + { + if (ff != 0) write(writer, ",\n"); + emitReflectionVarLayoutJSON( + writer, + structTypeLayout->getFieldByIndex(ff)); + } + dedent(writer); + write(writer, "\n]"); + } + break; + + case SLANG_TYPE_KIND_CONSTANT_BUFFER: + write(writer, "\"kind\": \"constantBuffer\""); + write(writer, ",\n"); + write(writer, "\"elementType\": "); + emitReflectionTypeLayoutJSON( + writer, + typeLayout->getElementTypeLayout()); + break; + + } + + // TODO: emit size info for types +} + +static void emitReflectionTypeLayoutJSON( + PrettyWriter& writer, + slang::TypeLayoutReflection* typeLayout) +{ + write(writer, "{\n"); + indent(writer); + emitReflectionTypeLayoutInfoJSON(writer, typeLayout); + dedent(writer); + write(writer, "\n}"); +} + +static void emitReflectionTypeJSON( + PrettyWriter& writer, + slang::TypeReflection* type) +{ + write(writer, "{\n"); + indent(writer); + emitReflectionTypeInfoJSON(writer, type); + dedent(writer); + write(writer, "\n}"); +} + +static void emitReflectionVarInfoJSON( + PrettyWriter& writer, + slang::VariableReflection* var) +{ + emitReflectionNameInfoJSON(writer, var->getName()); + write(writer, ",\n"); + + write(writer, "\"type\": "); + emitReflectionTypeJSON(writer, var->getType()); +} + +#if 0 +static void emitReflectionBindingInfoJSON( + PrettyWriter& writer, + + ReflectionParameterNode* param) +{ + auto info = ¶m->binding; + + if( info->category == SLANG_PARAMETER_CATEGORY_MIXED ) + { + write(writer,"\"bindings\": [\n"); + indent(writer); + + ReflectionSize bindingCount = info->bindingCount; + assert(bindingCount); + ReflectionParameterBindingInfo* bindings = info->bindings; + for( ReflectionSize bb = 0; bb < bindingCount; ++bb ) + { + if (bb != 0) write(writer, ",\n"); + + write(writer,"{"); + auto& binding = bindings[bb]; + emitReflectionVarBindingInfoJSON( + writer, + binding.category, + binding.index, + (ReflectionSize) param->GetTypeLayout()->GetSize(binding.category), + binding.space); + + write(writer,"}"); + } + dedent(writer); + write(writer,"\n]"); + } + else + { + write(writer,"\"binding\": {"); + indent(writer); + + emitReflectionVarBindingInfoJSON( + writer, + info->category, + info->index, + (ReflectionSize) param->GetTypeLayout()->GetSize(info->category), + info->space); + + dedent(writer); + write(writer,"}"); + } +} +#endif + +static void emitReflectionParamJSON( + PrettyWriter& writer, + slang::VariableLayoutReflection* param) +{ + write(writer, "{\n"); + indent(writer); + + emitReflectionNameInfoJSON(writer, param->getName()); + write(writer, ",\n"); + + emitReflectionVarBindingInfoJSON(writer, param); + write(writer, ",\n"); + + write(writer, "\"type\": "); + emitReflectionTypeLayoutJSON(writer, param->getTypeLayout()); + + dedent(writer); + write(writer, "\n}"); +} + +template<typename T> +struct Range +{ +public: + Range( + T begin, + T end) + : mBegin(begin) + , mEnd(end) + {} + + struct Iterator + { + public: + explicit Iterator(T value) + : mValue(value) + {} + + T operator*() const { return mValue; } + void operator++() { mValue++; } + + bool operator!=(Iterator const& other) + { + return mValue != other.mValue; + } + + private: + T mValue; + }; + + Iterator begin() const { return Iterator(mBegin); } + Iterator end() const { return Iterator(mEnd); } + +private: + T mBegin; + T mEnd; +}; + +template<typename T> +Range<T> range(T begin, T end) +{ + return Range<T>(begin, end); +} + +template<typename T> +Range<T> range(T end) +{ + return Range<T>(T(0), end); +} + +static void emitReflectionJSON( + PrettyWriter& writer, + slang::ShaderReflection* programReflection) +{ + write(writer, "{\n"); + indent(writer); + write(writer, "\"parameters\": [\n"); + indent(writer); + + auto parameterCount = programReflection->getParameterCount(); + for( auto pp : range(parameterCount) ) + { + if(pp != 0) write(writer, ",\n"); + + auto parameter = programReflection->getParameterByIndex(pp); + emitReflectionParamJSON(writer, parameter); + } + + dedent(writer); + write(writer, "\n]"); + dedent(writer); + write(writer, "\n}\n"); +} + +#if 0 +ReflectionBlob* ReflectionBlob::Create( + CollectionOfTranslationUnits* program) +{ + ReflectionGenerationContext context; + ReflectionBlob* blob = GenerateReflectionBlob(&context, program); +#if 0 + String debugDump = blob->emitAsJSON(); + OutputDebugStringA("REFLECTION BLOB\n"); + OutputDebugStringA(debugDump.begin()); +#endif + return blob; +} +#endif + +// JSON emit logic + + + +String emitReflectionJSON( + ProgramLayout* programLayout) +{ + auto programReflection = (slang::ShaderReflection*) programLayout; + + PrettyWriter writer; + emitReflectionJSON(writer, programReflection); + return writer.sb.ProduceString(); +} + +}} diff --git a/source/slang/reflection.h b/source/slang/reflection.h new file mode 100644 index 000000000..4d2c53084 --- /dev/null +++ b/source/slang/reflection.h @@ -0,0 +1,39 @@ +#ifndef SLANG_REFLECTION_H +#define SLANG_REFLECTION_H + +#include "../core/basic.h" +#include "syntax.h" + +#include "../../slang.h" + +namespace Slang { + +// TODO(tfoley): Need to move these somewhere universal + +typedef intptr_t Int; +typedef int64_t Int64; + +typedef uintptr_t UInt; +typedef uint64_t UInt64; + +namespace Compiler { + +class ProgramLayout; +class TypeLayout; + +String emitReflectionJSON( + ProgramLayout* programLayout); + +// + +SlangTypeKind getReflectionTypeKind(ExpressionType* type); + +SlangTypeKind getReflectionParameterCategory(TypeLayout* typeLayout); + +UInt getReflectionFieldCount(ExpressionType* type); +UInt getReflectionFieldByIndex(ExpressionType* type, UInt index); +UInt getReflectionFieldByIndex(TypeLayout* typeLayout, UInt index); + +}} + +#endif // SLANG_REFLECTION_H diff --git a/source/slang/slang-stdlib.cpp b/source/slang/slang-stdlib.cpp new file mode 100644 index 000000000..f74fcd603 --- /dev/null +++ b/source/slang/slang-stdlib.cpp @@ -0,0 +1,1855 @@ +// slang-stdlib.cpp + +#include "slang-stdlib.h" +#include "syntax.h" + +#define STRINGIZE(x) STRINGIZE2(x) +#define STRINGIZE2(x) #x +#define LINE_STRING STRINGIZE(__LINE__) + +enum { kLibIncludeStringLine = __LINE__+1 }; +const char * LibIncludeStringChunks[] = { R"( + +typedef uint UINT; + +__generic<T> __intrinsic(Assign) T operator=(out T left, T right); + +__generic<T,U> __intrinsic(Sequence) U operator,(T left, U right); + +__generic<T> __intrinsic(Select) T operator?:(bool condition, T ifTrue, T ifFalse); +__generic<T, let N : int> __intrinsic(Select) vector<T,N> operator?:(vector<bool,N> condition, vector<T,N> ifTrue, vector<T,N> ifFalse); + +__generic<T> __magic_type(HLSLAppendStructuredBufferType) struct AppendStructuredBuffer +{ + __intrinsic void Append(T value); + + __intrinsic void GetDimensions( + out uint numStructs, + out uint stride); +}; + +__generic<T> __magic_type(HLSLBufferType) struct Buffer +{ + __intrinsic void GetDimensions( + out uint dim); + + __intrinsic T Load(int location); + __intrinsic T Load(int location, out uint status); + + __intrinsic __subscript(uint index) -> T; +}; + +__magic_type(HLSLByteAddressBufferType) struct ByteAddressBuffer +{ + __intrinsic void GetDimensions( + out uint dim); + + __intrinsic uint Load(int location); + __intrinsic uint Load(int location, out uint status); + + __intrinsic uint2 Load2(int location); + __intrinsic uint2 Load2(int location, out uint status); + + __intrinsic uint3 Load3(int location); + __intrinsic uint3 Load3(int location, out uint status); + + __intrinsic uint4 Load4(int location); + __intrinsic uint4 Load4(int location, out uint status); +}; + +__generic<T> __magic_type(HLSLStructuredBufferType) struct StructuredBuffer +{ + __intrinsic void GetDimensions( + out uint numStructs, + out uint stride); + + __intrinsic T Load(int location); + __intrinsic T Load(int location, out uint status); + + __intrinsic __subscript(uint index) -> T; +}; + +__generic<T> __magic_type(HLSLConsumeStructuredBufferType) struct ConsumeStructuredBuffer +{ + __intrinsic T Consume(); + + __intrinsic void GetDimensions( + out uint numStructs, + out uint stride); +}; + +__generic<T> __magic_type(HLSLInputPatchType) struct InputPatch +{ + __intrinsic __subscript(uint index) -> T; +}; + +__generic<T> __magic_type(HLSLOutputPatchType) struct OutputPatch +{ + __intrinsic __subscript(uint index) -> T { set; } +}; + +__generic<T> __magic_type(HLSLRWBufferType) struct RWBuffer +{ + // Note(tfoley): duplication with declaration of `Buffer` + + __intrinsic void GetDimensions( + out uint dim); + + __intrinsic T Load(int location); + __intrinsic T Load(int location, out uint status); + + __intrinsic __subscript(uint index) -> T { get; set; } +}; + +__magic_type(HLSLRWByteAddressBufferType) struct RWByteAddressBuffer +{ + // Note(tfoley): supports alll operations from `ByteAddressBuffer` + // TODO(tfoley): can this be made a sub-type? + + __intrinsic void GetDimensions( + out uint dim); + + __intrinsic uint Load(int location); + __intrinsic uint Load(int location, out uint status); + + __intrinsic uint2 Load2(int location); + __intrinsic uint2 Load2(int location, out uint status); + + __intrinsic uint3 Load3(int location); + __intrinsic uint3 Load3(int location, out uint status); + + __intrinsic uint4 Load4(int location); + __intrinsic uint4 Load4(int location, out uint status); + + // Added operations: + + __intrinsic void InterlockedAdd( + UINT dest, + UINT value, + out UINT original_value); + __intrinsic void InterlockedAdd( + UINT dest, + UINT value); + + __intrinsic void InterlockedAnd( + UINT dest, + UINT value, + out UINT original_value); + __intrinsic void InterlockedAnd( + UINT dest, + UINT value); + + __intrinsic void InterlockedCompareExchange( + UINT dest, + UINT compare_value, + UINT value, + out UINT original_value); + __intrinsic void InterlockedCompareExchange( + UINT dest, + UINT compare_value, + UINT value); + + __intrinsic void InterlockedCompareStore( + UINT dest, + UINT compare_value, + UINT value); + __intrinsic void InterlockedCompareStore( + UINT dest, + UINT compare_value); + + __intrinsic void InterlockedExchange( + UINT dest, + UINT value, + out UINT original_value); + __intrinsic void InterlockedExchange( + UINT dest, + UINT value); + + __intrinsic void InterlockedMax( + UINT dest, + UINT value, + out UINT original_value); + __intrinsic void InterlockedMax( + UINT dest, + UINT value); + + __intrinsic void InterlockedMin( + UINT dest, + UINT value, + out UINT original_value); + __intrinsic void InterlockedMin( + UINT dest, + UINT value); + + __intrinsic void InterlockedOr( + UINT dest, + UINT value, + out UINT original_value); + __intrinsic void InterlockedOr( + UINT dest, + UINT value); + + __intrinsic void InterlockedXor( + UINT dest, + UINT value, + out UINT original_value); + __intrinsic void InterlockedXor( + UINT dest, + UINT value); + + __intrinsic void Store( + uint address, + uint value); + + __intrinsic void Store2( + uint address, + uint2 value); + + __intrinsic void Store3( + uint address, + uint3 value); + + __intrinsic void Store4( + uint address, + uint4 value); +}; + +__generic<T> __magic_type(HLSLRWStructuredBufferType) struct RWStructuredBuffer +{ + __intrinsic uint DecrementCounter(); + + __intrinsic void GetDimensions( + out uint numStructs, + out uint stride); + + __intrinsic void IncrementCounter(); + + __intrinsic T Load(int location); + __intrinsic T Load(int location, out uint status); + + __intrinsic __subscript(uint index) -> T { get; set; } +}; + +__generic<T> __magic_type(HLSLPointStreamType) struct PointStream {}; +__generic<T> __magic_type(HLSLLineStreamType) struct LineStream {}; +__generic<T> __magic_type(HLSLLineStreamType) struct TriangleStream {}; + +)", R"( + +// Note(tfoley): Trying to systematically add all the HLSL builtins + +// A type that can be used as an operand for builtins +__trait __BuiltinType {} + +// A type that can be used for arithmetic operations +__trait __BuiltinArithmeticType : __BuiltinType {} + +// A type that logically has a sign (positive/negative/zero) +__trait __BuiltinSignedArithmeticType : __BuiltinArithmeticType {} + +// A type that can represent integers +__trait __BuiltinIntegerType : __BuiltinArithmeticType {} + +// A type that can represent non-integers +__trait __BuiltinRealType : __BuiltinArithmeticType {} + +// A type that uses a floating-point representation +__trait __BuiltinFloatingPointType : __BuiltinRealType, __BuiltinSignedType {} + +// Try to terminate the current draw or dispatch call (HLSL SM 4.0) +__intrinsic void abort(); + +// Absolute value (HLSL SM 1.0) +__generic<T : __BuiltinSignedArithmeticType> __intrinsic T abs(T x); +__generic<T : __BuiltinSignedArithmeticType, let N : int> __intrinsic vector<T,N> abs(vector<T,N> x); +__generic<T : __BuiltinSignedArithmeticType, let N : int, let M : int> __intrinsic matrix<T,N,M> abs(matrix<T,N,M> x); + +// Inverse cosine (HLSL SM 1.0) +__generic<T : __BuiltinFloatingPointType> __intrinsic T acos(T x); +__generic<T : __BuiltinFloatingPointType, let N : int> __intrinsic vector<T,N> acos(vector<T,N> x); +__generic<T : __BuiltinFloatingPointType, let N : int, let M : int> __intrinsic matrix<T,N,M> acos(matrix<T,N,M> x); + +// Test if all components are non-zero (HLSL SM 1.0) +__generic<T : __BuiltinType> __intrinsic T all(T x); +__generic<T : __BuiltinType, let N : int> __intrinsic vector<T,N> all(vector<T,N> x); +__generic<T : __BuiltinType, let N : int, let M : int> __intrinsic matrix<T,N,M> all(matrix<T,N,M> x); + +// Barrier for writes to all memory spaces (HLSL SM 5.0) +__intrinsic void AllMemoryBarrier(); + +// Thread-group sync and barrier for writes to all memory spaces (HLSL SM 5.0) +__intrinsic void AllMemoryBarrierWithGroupSync(); + +// Test if any components is non-zero (HLSL SM 1.0) +__generic<T : __BuiltinType> __intrinsic T any(T x); +__generic<T : __BuiltinType, let N : int> __intrinsic vector<T,N> any(vector<T,N> x); +__generic<T : __BuiltinType, let N : int, let M : int> __intrinsic matrix<T,N,M> any(matrix<T,N,M> x); + + +// Reinterpret bits as a double (HLSL SM 5.0) +__intrinsic double asdouble(uint lowbits, uint highbits); + +// Reinterpret bits as a float (HLSL SM 4.0) +__intrinsic float asfloat( int x); +__intrinsic float asfloat(uint x); +__generic<let N : int> __intrinsic vector<float,N> asfloat(vector< int,N> x); +__generic<let N : int> __intrinsic vector<float,N> asfloat(vector<uint,N> x); +__generic<let N : int, let M : int> __intrinsic matrix<float,N,M> asfloat(matrix< int,N,M> x); +__generic<let N : int, let M : int> __intrinsic matrix<float,N,M> asfloat(matrix<uint,N,M> x); + + +// Inverse sine (HLSL SM 1.0) +__generic<T : __BuiltinFloatingPointType> __intrinsic T asin(T x); +__generic<T : __BuiltinFloatingPointType, let N : int> __intrinsic vector<T,N> asin(vector<T,N> x); +__generic<T : __BuiltinFloatingPointType, let N : int, let M : int> __intrinsic matrix<T,N,M> asin(matrix<T,N,M> x); + +// Reinterpret bits as an int (HLSL SM 4.0) +__intrinsic int asint(float x); +__intrinsic int asint(uint x); +__generic<let N : int> __intrinsic vector<int,N> asint(vector<float,N> x); +__generic<let N : int> __intrinsic vector<int,N> asint(vector<uint,N> x); +__generic<let N : int, let M : int> __intrinsic matrix<int,N,M> asint(matrix<float,N,M> x); +__generic<let N : int, let M : int> __intrinsic matrix<int,N,M> asint(matrix<uint,N,M> x); + +// Reinterpret bits of double as a uint (HLSL SM 5.0) +__intrinsic void asuint(double value, out uint lowbits, out uint highbits); + +// Reinterpret bits as a uint (HLSL SM 4.0) +__intrinsic uint asuint(float x); +__intrinsic uint asuint(int x); +__generic<let N : int> __intrinsic vector<uint,N> asuint(vector<float,N> x); +__generic<let N : int> __intrinsic vector<uint,N> asuint(vector<int,N> x); +__generic<let N : int, let M : int> __intrinsic matrix<uint,N,M> asuint(matrix<float,N,M> x); +__generic<let N : int, let M : int> __intrinsic matrix<uint,N,M> asuint(matrix<int,N,M> x); + +// Inverse tangent (HLSL SM 1.0) +__generic<T : __BuiltinFloatingPointType> __intrinsic T atan(T x); +__generic<T : __BuiltinFloatingPointType, let N : int> __intrinsic vector<T,N> atan(vector<T,N> x); +__generic<T : __BuiltinFloatingPointType, let N : int, let M : int> __intrinsic matrix<T,N,M> atan(matrix<T,N,M> x); + +__generic<T : __BuiltinFloatingPointType> __intrinsic T atan2(T y, T x); +__generic<T : __BuiltinFloatingPointType, let N : int> __intrinsic vector<T,N> atan2(vector<T,N> y, vector<T,N> x); +__generic<T : __BuiltinFloatingPointType, let N : int, let M : int> __intrinsic matrix<T,N,M> atan2(matrix<T,N,M> y, matrix<T,N,M> x); + +// Ceiling (HLSL SM 1.0) +__generic<T : __BuiltinFloatingPointType> __intrinsic T ceil(T x); +__generic<T : __BuiltinFloatingPointType, let N : int> __intrinsic vector<T,N> ceil(vector<T,N> x); +__generic<T : __BuiltinFloatingPointType, let N : int, let M : int> __intrinsic matrix<T,N,M> ceil(matrix<T,N,M> x); + + +// Check access status to tiled resource +__intrinsic bool CheckAccessFullyMapped(uint status); + +// Clamp (HLSL SM 1.0) +__generic<T : __BuiltinArithmeticType> __intrinsic T clamp(T x, T min, T max); +__generic<T : __BuiltinArithmeticType, let N : int> __intrinsic vector<T,N> clamp(vector<T,N> x, vector<T,N> min, vector<T,N> max); +__generic<T : __BuiltinArithmeticType, let N : int, let M : int> __intrinsic matrix<T,N,M> clamp(matrix<T,N,M> x, matrix<T,N,M> min, matrix<T,N,M> max); + +// Clip (discard) fragment conditionally +__generic<T : __BuiltinFloatingPointType> __intrinsic void clip(T x); +__generic<T : __BuiltinFloatingPointType, let N : int> __intrinsic void clip(vector<T,N> x); +__generic<T : __BuiltinFloatingPointType, let N : int, let M : int> __intrinsic void clip(matrix<T,N,M> x); + +// Cosine +__generic<T : __BuiltinFloatingPointType> __intrinsic T cos(T x); +__generic<T : __BuiltinFloatingPointType, let N : int> __intrinsic vector<T,N> cos(vector<T,N> x); +__generic<T : __BuiltinFloatingPointType, let N : int, let M : int> __intrinsic matrix<T,N,M> cos(matrix<T,N,M> x); + +// Hyperbolic cosine +__generic<T : __BuiltinFloatingPointType> __intrinsic T cosh(T x); +__generic<T : __BuiltinFloatingPointType, let N : int> __intrinsic vector<T,N> cosh(vector<T,N> x); +__generic<T : __BuiltinFloatingPointType, let N : int, let M : int> __intrinsic matrix<T,N,M> cosh(matrix<T,N,M> x); + +// Population count +__intrinsic uint countbits(uint value); + +// Cross product +__generic<T : __BuiltinArithmeticType> __intrinsic vector<T,3> cross(vector<T,3> x, vector<T,3> y); + +// Convert encoded color +__intrinsic int4 D3DCOLORtoUBYTE4(float4 x); + +// Partial-difference derivatives +__generic<T : __BuiltinFloatingPointType> __intrinsic T ddx(T x); +__generic<T : __BuiltinFloatingPointType, let N : int> __intrinsic vector<T,N> ddx(vector<T,N> x); +__generic<T : __BuiltinFloatingPointType, let N : int, let M : int> __intrinsic matrix<T,N,M> ddx(matrix<T,N,M> x); + +__generic<T : __BuiltinFloatingPointType> __intrinsic T ddx_coarse(T x); +__generic<T : __BuiltinFloatingPointType, let N : int> __intrinsic vector<T,N> ddx_coarse(vector<T,N> x); +__generic<T : __BuiltinFloatingPointType, let N : int, let M : int> __intrinsic matrix<T,N,M> ddx_coarse(matrix<T,N,M> x); + +__generic<T : __BuiltinFloatingPointType> __intrinsic T ddx_fine(T x); +__generic<T : __BuiltinFloatingPointType, let N : int> __intrinsic vector<T,N> ddx_fine(vector<T,N> x); +__generic<T : __BuiltinFloatingPointType, let N : int, let M : int> __intrinsic matrix<T,N,M> ddx_fine(matrix<T,N,M> x); + +__generic<T : __BuiltinFloatingPointType> __intrinsic T ddy(T x); +__generic<T : __BuiltinFloatingPointType, let N : int> __intrinsic vector<T,N> ddy(vector<T,N> x); +__generic<T : __BuiltinFloatingPointType, let N : int, let M : int> __intrinsic matrix<T,N,M> ddy(matrix<T,N,M> x); + +__generic<T : __BuiltinFloatingPointType> __intrinsic T ddy_coarse(T x); +__generic<T : __BuiltinFloatingPointType, let N : int> __intrinsic vector<T,N> ddy_coarse(vector<T,N> x); +__generic<T : __BuiltinFloatingPointType, let N : int, let M : int> __intrinsic matrix<T,N,M> ddy_coarse(matrix<T,N,M> x); + +__generic<T : __BuiltinFloatingPointType> __intrinsic T ddy_fine(T x); +__generic<T : __BuiltinFloatingPointType, let N : int> __intrinsic vector<T,N> ddy_fine(vector<T,N> x); +__generic<T : __BuiltinFloatingPointType, let N : int, let M : int> __intrinsic matrix<T,N,M> ddy_fine(matrix<T,N,M> x); + + +// Radians to degrees +__generic<T : __BuiltinFloatingPointType> __intrinsic T degrees(T x); +__generic<T : __BuiltinFloatingPointType, let N : int> __intrinsic vector<T,N> degrees(vector<T,N> x); +__generic<T : __BuiltinFloatingPointType, let N : int, let M : int> __intrinsic matrix<T,N,M> degrees(matrix<T,N,M> x); + +// Matrix determinant + +__generic<T : __BuiltinFloatingPointType, let N : int> __intrinsic T determinant(matrix<T,N,N> m); + +// Barrier for device memory +__intrinsic void DeviceMemoryBarrier(); +__intrinsic void DeviceMemoryBarrierWithGroupSync(); + +// Vector distance + +__generic<T : __BuiltinFloatingPointType, let N : int> __intrinsic T distance(vector<T,N> x, vector<T,N> y); + +// Vector dot product + +__generic<T : __BuiltinArithmeticType, let N : int> __intrinsic T dot(vector<T,N> x, vector<T,N> y); + +// Helper for computing distance terms for lighting (obsolete) + +__generic<T : __BuiltinFloatingPointType> __intrinsic vector<T,4> dst(vector<T,4> x, vector<T,4> y); + +// Error message + +// __intrinsic void errorf( string format, ... ); + +// Attribute evaluation + +__generic<T : __BuiltinArithmeticType> __intrinsic T EvaluateAttributeAtCentroid(T x); +__generic<T : __BuiltinArithmeticType, let N : int> __intrinsic vector<T,N> EvaluateAttributeAtCentroid(vector<T,N> x); +__generic<T : __BuiltinArithmeticType, let N : int, let M : int> __intrinsic matrix<T,N,M> EvaluateAttributeAtCentroid(matrix<T,N,M> x); + +__generic<T : __BuiltinArithmeticType> __intrinsic T EvaluateAttributeAtSample(T x, uint sampleindex); +__generic<T : __BuiltinArithmeticType, let N : int> __intrinsic vector<T,N> EvaluateAttributeAtSample(vector<T,N> x, uint sampleindex); +__generic<T : __BuiltinArithmeticType, let N : int, let M : int> __intrinsic matrix<T,N,M> EvaluateAttributeAtSample(matrix<T,N,M> x, uint sampleindex); + +__generic<T : __BuiltinArithmeticType> __intrinsic T EvaluateAttributeSnapped(T x, int2 offset); +__generic<T : __BuiltinArithmeticType, let N : int> __intrinsic vector<T,N> EvaluateAttributeSnapped(vector<T,N> x, int2 offset); +__generic<T : __BuiltinArithmeticType, let N : int, let M : int> __intrinsic matrix<T,N,M> EvaluateAttributeSnapped(matrix<T,N,M> x, int2 offset); + +// Base-e exponent +__generic<T : __BuiltinFloatingPointType> __intrinsic T exp(T x); +__generic<T : __BuiltinFloatingPointType, let N : int> __intrinsic vector<T,N> exp(vector<T,N> x); +__generic<T : __BuiltinFloatingPointType, let N : int, let M : int> __intrinsic matrix<T,N,M> exp(matrix<T,N,M> x); + +// Base-2 exponent +__generic<T : __BuiltinFloatingPointType> __intrinsic T exp2(T x); +__generic<T : __BuiltinFloatingPointType, let N : int> __intrinsic vector<T,N> exp2(vector<T,N> x); +__generic<T : __BuiltinFloatingPointType, let N : int, let M : int> __intrinsic matrix<T,N,M> exp2(matrix<T,N,M> x); + +// Convert 16-bit float stored in low bits of integer +__intrinsic float f16tof32(uint value); +__generic<let N : int> __intrinsic vector<float,N> f16tof32(vector<uint,N> value); + +// Convert to 16-bit float stored in low bits of integer +__intrinsic uint f32tof16(float value); +__generic<let N : int> __intrinsic vector<uint,N> f32tof16(vector<float,N> value); + +// Flip surface normal to face forward, if needed +__generic<T : __BuiltinFloatingPointType, let N : int> __intrinsic vector<T,N> faceforward(vector<T,N> n, vector<T,N> i, vector<T,N> ng); + +// Find first set bit starting at high bit and working down +__intrinsic int firstbithigh(int value); +__generic<let N : int> __intrinsic vector<int,N> firstbithigh(vector<int,N> value); + +__intrinsic uint firstbithigh(uint value); +__generic<let N : int> __intrinsic vector<uint,N> firstbithigh(vector<uint,N> value); + +// Find first set bit starting at low bit and working up +__intrinsic int firstbitlow(int value); +__generic<let N : int> __intrinsic vector<int,N> firstbitlow(vector<int,N> value); + +__intrinsic uint firstbitlow(uint value); +__generic<let N : int> __intrinsic vector<uint,N> firstbitlow(vector<uint,N> value); + +// Floor (HLSL SM 1.0) +__generic<T : __BuiltinFloatingPointType> __intrinsic T floor(T x); +__generic<T : __BuiltinFloatingPointType, let N : int> __intrinsic vector<T,N> floor(vector<T,N> x); +__generic<T : __BuiltinFloatingPointType, let N : int, let M : int> __intrinsic matrix<T,N,M> floor(matrix<T,N,M> x); + +// Fused multiply-add for doubles +__intrinsic double fma(double a, double b, double c); +__generic<let N : int> __intrinsic vector<double, N> fma(vector<double, N> a, vector<double, N> b, vector<double, N> c); +__generic<let N : int, let M : int> __intrinsic matrix<double,N,M> fma(matrix<double,N,M> a, matrix<double,N,M> b, matrix<double,N,M> c); + +// Floating point remainder of x/y +__generic<T : __BuiltinFloatingPointType> __intrinsic T fmod(T x, T y); +__generic<T : __BuiltinFloatingPointType, let N : int> __intrinsic vector<T,N> fmod(vector<T,N> x, vector<T,N> y); +__generic<T : __BuiltinFloatingPointType, let N : int, let M : int> __intrinsic matrix<T,N,M> fmod(matrix<T,N,M> x, matrix<T,N,M> y); + +// Fractional part +__generic<T : __BuiltinFloatingPointType> __intrinsic T frac(T x); +__generic<T : __BuiltinFloatingPointType, let N : int> __intrinsic vector<T,N> frac(vector<T,N> x); +__generic<T : __BuiltinFloatingPointType, let N : int, let M : int> __intrinsic matrix<T,N,M> frac(matrix<T,N,M> x); + +// Split float into mantissa and exponent +__generic<T : __BuiltinFloatingPointType> __intrinsic T frexp(T x, out T exp); +__generic<T : __BuiltinFloatingPointType, let N : int> __intrinsic vector<T,N> frexp(vector<T,N> x, out vector<T,N> exp); +__generic<T : __BuiltinFloatingPointType, let N : int, let M : int> __intrinsic matrix<T,N,M> frexp(matrix<T,N,M> x, out matrix<T,N,M> exp); + +// Texture filter width +__generic<T : __BuiltinFloatingPointType> __intrinsic T fwidth(T x); +__generic<T : __BuiltinFloatingPointType, let N : int> __intrinsic vector<T,N> fwidth(vector<T,N> x); +__generic<T : __BuiltinFloatingPointType, let N : int, let M : int> __intrinsic matrix<T,N,M> fwidth(matrix<T,N,M> x); + +)", R"( + +// Get number of samples in render target +__intrinsic uint GetRenderTargetSampleCount(); + +// Get position of given sample +__intrinsic float2 GetRenderTargetSamplePosition(int Index); + +// Group memory barrier +__intrinsic void GroupMemoryBarrier(); +__intrinsic void GroupMemoryBarrierWithGroupSync(); + +// Atomics +__intrinsic void InterlockedAdd(in out int dest, int value, out int original_value); +__intrinsic void InterlockedAdd(in out uint dest, uint value, out uint original_value); + +__intrinsic void InterlockedAnd(in out int dest, int value, out int original_value); +__intrinsic void InterlockedAnd(in out uint dest, uint value, out uint original_value); + +__intrinsic void InterlockedCompareExchange(in out int dest, int compare_value, int value, out int original_value); +__intrinsic void InterlockedCompareExchange(in out uint dest, uint compare_value, uint value, out uint original_value); + +__intrinsic void InterlockedCompareStore(in out int dest, int compare_value, int value); +__intrinsic void InterlockedCompareStore(in out uint dest, uint compare_value, uint value); + +__intrinsic void InterlockedExchange(in out int dest, int value, out int original_value); +__intrinsic void InterlockedExchange(in out uint dest, uint value, out uint original_value); + +__intrinsic void InterlockedMax(in out int dest, int value, out int original_value); +__intrinsic void InterlockedMax(in out uint dest, uint value, out uint original_value); + +__intrinsic void InterlockedMin(in out int dest, int value, out int original_value); +__intrinsic void InterlockedMin(in out uint dest, uint value, out uint original_value); + +__intrinsic void InterlockedOr(in out int dest, int value, out int original_value); +__intrinsic void InterlockedOr(in out uint dest, uint value, out uint original_value); + +__intrinsic void InterlockedXor(in out int dest, int value, out int original_value); +__intrinsic void InterlockedXor(in out uint dest, uint value, out uint original_value); + +// Is floating-point value finite? +__generic<T : __BuiltinFloatingPointType> __intrinsic bool isfinite(T x); +__generic<T : __BuiltinFloatingPointType, let N : int> __intrinsic vector<bool,N> isfinite(vector<T,N> x); +__generic<T : __BuiltinFloatingPointType, let N : int, let M : int> __intrinsic matrix<bool,N,M> isfinite(matrix<T,N,M> x); + +// Is floating-point value infinite? +__generic<T : __BuiltinFloatingPointType> __intrinsic bool isinf(T x); +__generic<T : __BuiltinFloatingPointType, let N : int> __intrinsic vector<bool,N> isinf(vector<T,N> x); +__generic<T : __BuiltinFloatingPointType, let N : int, let M : int> __intrinsic matrix<bool,N,M> isinf(matrix<T,N,M> x); + +// Is floating-point value not-a-number? +__generic<T : __BuiltinFloatingPointType> __intrinsic bool isnan(T x); +__generic<T : __BuiltinFloatingPointType, let N : int> __intrinsic vector<bool,N> isnan(vector<T,N> x); +__generic<T : __BuiltinFloatingPointType, let N : int, let M : int> __intrinsic matrix<bool,N,M> isnan(matrix<T,N,M> x); + +// Construct float from mantissa and exponent +__generic<T : __BuiltinFloatingPointType> __intrinsic T ldexp(T x, T exp); +__generic<T : __BuiltinFloatingPointType, let N : int> __intrinsic vector<T,N> ldexp(vector<T,N> x, vector<T,N> exp); +__generic<T : __BuiltinFloatingPointType, let N : int, let M : int> __intrinsic matrix<T,N,M> ldexp(matrix<T,N,M> x, matrix<T,N,M> exp); + +// Vector length +__generic<T : __BuiltinFloatingPointType, let N : int> __intrinsic T length(vector<T,N> x); + +// Linear interpolation +__generic<T : __BuiltinFloatingPointType> __intrinsic T lerp(T x, T y, T s); +__generic<T : __BuiltinFloatingPointType, let N : int> __intrinsic vector<T,N> lerp(vector<T,N> x, vector<T,N> y, vector<T,N> s); +__generic<T : __BuiltinFloatingPointType, let N : int, let M : int> __intrinsic matrix<T,N,M> lerp(matrix<T,N,M> x, matrix<T,N,M> y, matrix<T,N,M> s); + +// Legacy lighting function (obsolete) +__intrinsic float4 lit(float n_dot_l, float n_dot_h, float m); + +// Base-e logarithm +__generic<T : __BuiltinFloatingPointType> __intrinsic T log(T x); +__generic<T : __BuiltinFloatingPointType, let N : int> __intrinsic vector<T,N> log(vector<T,N> x); +__generic<T : __BuiltinFloatingPointType, let N : int, let M : int> __intrinsic matrix<T,N,M> log(matrix<T,N,M> x); + +// Base-10 logarithm +__generic<T : __BuiltinFloatingPointType> __intrinsic T log10(T x); +__generic<T : __BuiltinFloatingPointType, let N : int> __intrinsic vector<T,N> log10(vector<T,N> x); +__generic<T : __BuiltinFloatingPointType, let N : int, let M : int> __intrinsic matrix<T,N,M> log10(matrix<T,N,M> x); + +// Base-2 logarithm +__generic<T : __BuiltinFloatingPointType> __intrinsic T log2(T x); +__generic<T : __BuiltinFloatingPointType, let N : int> __intrinsic vector<T,N> log2(vector<T,N> x); +__generic<T : __BuiltinFloatingPointType, let N : int, let M : int> __intrinsic matrix<T,N,M> log2(matrix<T,N,M> x); + +// multiply-add +__generic<T : __BuiltinArithmeticType> __intrinsic T mad(T mvalue, T avalue, T bvalue); +__generic<T : __BuiltinArithmeticType, let N : int> __intrinsic vector<T,N> mad(vector<T,N> mvalue, vector<T,N> avalue, vector<T,N> bvalue); +__generic<T : __BuiltinArithmeticType, let N : int, let M : int> __intrinsic matrix<T,N,M> mad(matrix<T,N,M> mvalue, matrix<T,N,M> avalue, matrix<T,N,M> bvalue); + +// maximum +__generic<T : __BuiltinArithmeticType> __intrinsic T max(T x, T y); +__generic<T : __BuiltinArithmeticType, let N : int> __intrinsic vector<T,N> max(vector<T,N> x, vector<T,N> y); +__generic<T : __BuiltinArithmeticType, let N : int, let M : int> __intrinsic matrix<T,N,M> max(matrix<T,N,M> x, matrix<T,N,M> y); + +// minimum +__generic<T : __BuiltinArithmeticType> __intrinsic T min(T x, T y); +__generic<T : __BuiltinArithmeticType, let N : int> __intrinsic vector<T,N> min(vector<T,N> x, vector<T,N> y); +__generic<T : __BuiltinArithmeticType, let N : int, let M : int> __intrinsic matrix<T,N,M> min(matrix<T,N,M> x, matrix<T,N,M> y); + +// split into integer and fractional parts (both with same sign) +__generic<T : __BuiltinFloatingPointType> __intrinsic T modf(T x, out T ip); +__generic<T : __BuiltinFloatingPointType, let N : int> __intrinsic vector<T,N> modf(vector<T,N> x, out vector<T,N> ip); +__generic<T : __BuiltinFloatingPointType, let N : int, let M : int> __intrinsic matrix<T,N,M> modf(matrix<T,N,M> x, out matrix<T,N,M> ip); + +// msad4 (whatever that is) +__intrinsic uint4 msad4(uint reference, uint2 source, uint4 accum); + +// General inner products + +// scalar-scalar +__generic<T : __BuiltinArithmeticType> __intrinsic(Mul_Scalar_Scalar) T mul(T x, T y); + +// scalar-vector and vector-scalar +__generic<T : __BuiltinArithmeticType, let N : int> __intrinsic(Mul_Vector_Scalar) vector<T,N> mul(vector<T,N> x, T y); +__generic<T : __BuiltinArithmeticType, let N : int> __intrinsic(Mul_Scalar_Vector) vector<T,N> mul(T x, vector<T,N> y); + +// scalar-matrix and matrix-scalar +__generic<T : __BuiltinArithmeticType, let N : int, let M :int> __intrinsic(Mul_Matrix_Scalar) matrix<T,N,M> mul(matrix<T,N,M> x, T y); +__generic<T : __BuiltinArithmeticType, let N : int, let M :int> __intrinsic(Mul_Scalar_Matrix) matrix<T,N,M> mul(T x, matrix<T,N,M> y); + +// vector-vector (dot product) +__generic<T : __BuiltinArithmeticType, let N : int> __intrinsic(InnerProduct_Vector_Vector) T mul(vector<T,N> x, vector<T,N> y); + +// vector-matrix +__generic<T : __BuiltinArithmeticType, let N : int, let M : int> __intrinsic(InnerProduct_Vector_Matrix) vector<T,M> mul(vector<T,N> x, matrix<T,N,M> y); + +// matrix-vector +__generic<T : __BuiltinArithmeticType, let N : int, let M : int> __intrinsic(InnerProduct_Matrix_Vector) vector<T,N> mul(matrix<T,N,M> x, vector<T,M> y); + +// matrix-matrix +__generic<T : __BuiltinArithmeticType, let R : int, let N : int, let C : int> __intrinsic(InnerProduct_Matrix_Matrix) matrix<T,R,C> mul(matrix<T,R,N> x, matrix<T,N,C> y); + +// noise (deprecated) +__intrinsic float noise(float x); +__generic<let N : int> __intrinsic float noise(vector<float, N> x); + +// Normalize a vector +__generic<T : __BuiltinFloatingPointType, let N : int> __intrinsic vector<T,N> normalize(vector<T,N> x); + +// Raise to a power +__generic<T : __BuiltinFloatingPointType> __intrinsic T pow(T x, T y); +__generic<T : __BuiltinFloatingPointType, let N : int> __intrinsic vector<T,N> pow(vector<T,N> x, vector<T,N> y); +__generic<T : __BuiltinFloatingPointType, let N : int, let M : int> __intrinsic matrix<T,N,M> pow(matrix<T,N,M> x, matrix<T,N,M> y); + +// Output message + +// __intrinsic void printf( string format, ... ); + +// Tessellation factor fixup routines + +__intrinsic void Process2DQuadTessFactorsAvg( + in float4 RawEdgeFactors, + in float2 InsideScale, + out float4 RoundedEdgeTessFactors, + out float2 RoundedInsideTessFactors, + out float2 UnroundedInsideTessFactors); + +__intrinsic void Process2DQuadTessFactorsMax( + in float4 RawEdgeFactors, + in float2 InsideScale, + out float4 RoundedEdgeTessFactors, + out float2 RoundedInsideTessFactors, + out float2 UnroundedInsideTessFactors); + +__intrinsic void Process2DQuadTessFactorsMin( + in float4 RawEdgeFactors, + in float2 InsideScale, + out float4 RoundedEdgeTessFactors, + out float2 RoundedInsideTessFactors, + out float2 UnroundedInsideTessFactors); + +__intrinsic void ProcessIsolineTessFactors( + in float RawDetailFactor, + in float RawDensityFactor, + out float RoundedDetailFactor, + out float RoundedDensityFactor); + +__intrinsic void ProcessQuadTessFactorsAvg( + in float4 RawEdgeFactors, + in float InsideScale, + out float4 RoundedEdgeTessFactors, + out float2 RoundedInsideTessFactors, + out float2 UnroundedInsideTessFactors); + +__intrinsic void ProcessQuadTessFactorsMax( + in float4 RawEdgeFactors, + in float InsideScale, + out float4 RoundedEdgeTessFactors, + out float2 RoundedInsideTessFactors, + out float2 UnroundedInsideTessFactors); + +__intrinsic void ProcessQuadTessFactorsMin( + in float4 RawEdgeFactors, + in float InsideScale, + out float4 RoundedEdgeTessFactors, + out float2 RoundedInsideTessFactors, + out float2 UnroundedInsideTessFactors); + +__intrinsic void ProcessTriTessFactorsAvg( + in float3 RawEdgeFactors, + in float InsideScale, + out float3 RoundedEdgeTessFactors, + out float RoundedInsideTessFactor, + out float UnroundedInsideTessFactor); + +__intrinsic void ProcessTriTessFactorsMax( + in float3 RawEdgeFactors, + in float InsideScale, + out float3 RoundedEdgeTessFactors, + out float RoundedInsideTessFactor, + out float UnroundedInsideTessFactor); + +__intrinsic void ProcessTriTessFactorsMin( + in float3 RawEdgeFactors, + in float InsideScale, + out float3 RoundedEdgeTessFactors, + out float RoundedInsideTessFactors, + out float UnroundedInsideTessFactors); + +// Degrees to radians +__generic<T : __BuiltinFloatingPointType> __intrinsic T radians(T x); +__generic<T : __BuiltinFloatingPointType, let N : int> __intrinsic vector<T,N> radians(vector<T,N> x); +__generic<T : __BuiltinFloatingPointType, let N : int, let M : int> __intrinsic matrix<T,N,M> radians(matrix<T,N,M> x); + +// Approximate reciprocal +__generic<T : __BuiltinFloatingPointType> __intrinsic T rcp(T x); +__generic<T : __BuiltinFloatingPointType, let N : int> __intrinsic vector<T,N> rcp(vector<T,N> x); +__generic<T : __BuiltinFloatingPointType, let N : int, let M : int> __intrinsic matrix<T,N,M> rcp(matrix<T,N,M> x); + +// Reflect incident vector across plane with given normal +__generic<T : __BuiltinFloatingPointType, let N : int> +__intrinsic +vector<T,N> reflect(vector<T,N> i, vector<T,N> n); + +// Refract incident vector given surface normal and index of refraction +__generic<T : __BuiltinFloatingPointType, let N : int> +__intrinsic +vector<T,N> refract(vector<T,N> i, vector<T,N> n, float eta); + +// Reverse order of bits +__intrinsic uint reversebits(uint value); +__generic<let N : int> vector<uint,N> reversebits(vector<uint,N> value); + +// Round-to-nearest +__generic<T : __BuiltinFloatingPointType> __intrinsic T round(T x); +__generic<T : __BuiltinFloatingPointType, let N : int> __intrinsic vector<T,N> round(vector<T,N> x); +__generic<T : __BuiltinFloatingPointType, let N : int, let M : int> __intrinsic matrix<T,N,M> round(matrix<T,N,M> x); + +// Reciprocal of square root +__generic<T : __BuiltinFloatingPointType> __intrinsic T rsqrt(T x); +__generic<T : __BuiltinFloatingPointType, let N : int> __intrinsic vector<T,N> rsqrt(vector<T,N> x); +__generic<T : __BuiltinFloatingPointType, let N : int, let M : int> __intrinsic matrix<T,N,M> rsqrt(matrix<T,N,M> x); + +// Clamp value to [0,1] range +__generic<T : __BuiltinFloatingPointType> __intrinsic T saturate(T x); +__generic<T : __BuiltinFloatingPointType, let N : int> __intrinsic vector<T,N> saturate(vector<T,N> x); +__generic<T : __BuiltinFloatingPointType, let N : int, let M : int> __intrinsic matrix<T,N,M> saturate(matrix<T,N,M> x); + + +// Extract sign of value +__generic<T : __BuiltinSignedArithmeticType> __intrinsic int sign(T x); +__generic<T : __BuiltinSignedArithmeticType, let N : int> __intrinsic vector<int,N> sign(vector<T,N> x); +__generic<T : __BuiltinSignedArithmeticType, let N : int, let M : int> __intrinsic matrix<int,N,M> sign(matrix<T,N,M> x); + +)", R"( + + +// Sine +__generic<T : __BuiltinFloatingPointType> __intrinsic T sin(T x); +__generic<T : __BuiltinFloatingPointType, let N : int> __intrinsic vector<T,N> sin(vector<T,N> x); +__generic<T : __BuiltinFloatingPointType, let N : int, let M : int> __intrinsic matrix<T,N,M> sin(matrix<T,N,M> x); + +// Sine and cosine +__generic<T : __BuiltinFloatingPointType, let N : int> __intrinsic void sincos(T x, out T s, out T c); +__generic<T : __BuiltinFloatingPointType, let N : int> __intrinsic void sincos(vector<T,N> x, out vector<T,N> s, out vector<T,N> c); +__generic<T : __BuiltinFloatingPointType, let N : int, let M : int> __intrinsic void sincos(matrix<T,N,M> x, out matrix<T,N,M> s, out matrix<T,N,M> c); + +// Hyperbolic Sine +__generic<T : __BuiltinFloatingPointType> __intrinsic T sinh(T x); +__generic<T : __BuiltinFloatingPointType, let N : int> __intrinsic vector<T,N> sinh(vector<T,N> x); +__generic<T : __BuiltinFloatingPointType, let N : int, let M : int> __intrinsic matrix<T,N,M> sinh(matrix<T,N,M> x); + +// Smooth step (Hermite interpolation) +__generic<T : __BuiltinFloatingPointType> __intrinsic T smoothstep(T min, T max, T x); +__generic<T : __BuiltinFloatingPointType, let N : int> __intrinsic vector<T,N> smoothstep(vector<T,N> min, vector<T,N> max, vector<T,N> x); +__generic<T : __BuiltinFloatingPointType, let N : int, let M : int> __intrinsic matrix<T,N,M> smoothstep(matrix<T,N,M> min, matrix<T,N,M> max, matrix<T,N,M> x); + +// Square root +__generic<T : __BuiltinFloatingPointType> __intrinsic T sqrt(T x); +__generic<T : __BuiltinFloatingPointType, let N : int> __intrinsic vector<T,N> sqrt(vector<T,N> x); +__generic<T : __BuiltinFloatingPointType, let N : int, let M : int> __intrinsic matrix<T,N,M> sqrt(matrix<T,N,M> x); + +// Step function +__generic<T : __BuiltinFloatingPointType> __intrinsic T step(T y, T x); +__generic<T : __BuiltinFloatingPointType, let N : int> __intrinsic vector<T,N> step(vector<T,N> y, vector<T,N> x); +__generic<T : __BuiltinFloatingPointType, let N : int, let M : int> __intrinsic matrix<T,N,M> step(matrix<T,N,M> y, matrix<T,N,M> x); + +// Tangent +__generic<T : __BuiltinFloatingPointType> __intrinsic T tan(T x); +__generic<T : __BuiltinFloatingPointType, let N : int> __intrinsic vector<T,N> tan(vector<T,N> x); +__generic<T : __BuiltinFloatingPointType, let N : int, let M : int> __intrinsic matrix<T,N,M> tan(matrix<T,N,M> x); + +// Hyperbolic tangent +__generic<T : __BuiltinFloatingPointType> __intrinsic T tanh(T x); +__generic<T : __BuiltinFloatingPointType, let N : int> __intrinsic vector<T,N> tanh(vector<T,N> x); +__generic<T : __BuiltinFloatingPointType, let N : int, let M : int> __intrinsic matrix<T,N,M> tanh(matrix<T,N,M> x); + +// Legacy texture-fetch operations + +/* +__intrinsic float4 tex1D(sampler1D s, float t); +__intrinsic float4 tex1D(sampler1D s, float t, float ddx, float ddy); +__intrinsic float4 tex1Dbias(sampler1D s, float4 t); +__intrinsic float4 tex1Dgrad(sampler1D s, float t, float ddx, float ddy); +__intrinsic float4 tex1Dlod(sampler1D s, float4 t); +__intrinsic float4 tex1Dproj(sampler1D s, float4 t); + +__intrinsic float4 tex2D(sampler2D s, float2 t); +__intrinsic float4 tex2D(sampler2D s, float2 t, float2 ddx, float2 ddy); +__intrinsic float4 tex2Dbias(sampler2D s, float4 t); +__intrinsic float4 tex2Dgrad(sampler2D s, float2 t, float2 ddx, float2 ddy); +__intrinsic float4 tex2Dlod(sampler2D s, float4 t); +__intrinsic float4 tex2Dproj(sampler2D s, float4 t); + +__intrinsic float4 tex3D(sampler3D s, float3 t); +__intrinsic float4 tex3D(sampler3D s, float3 t, float3 ddx, float3 ddy); +__intrinsic float4 tex3Dbias(sampler3D s, float4 t); +__intrinsic float4 tex3Dgrad(sampler3D s, float3 t, float3 ddx, float3 ddy); +__intrinsic float4 tex3Dlod(sampler3D s, float4 t); +__intrinsic float4 tex3Dproj(sampler3D s, float4 t); + +__intrinsic float4 texCUBE(samplerCUBE s, float3 t); +__intrinsic float4 texCUBE(samplerCUBE s, float3 t, float3 ddx, float3 ddy); +__intrinsic float4 texCUBEbias(samplerCUBE s, float4 t); +__intrinsic float4 texCUBEgrad(samplerCUBE s, float3 t, float3 ddx, float3 ddy); +__intrinsic float4 texCUBElod(samplerCUBE s, float4 t); +__intrinsic float4 texCUBEproj(samplerCUBE s, float4 t); +*/ + +// Matrix transpose +__generic<T : __BuiltinType, let N : int, let M : int> __intrinsic matrix<T,M,N> transpose(matrix<T,N,M> x); + +// Truncate to integer +__generic<T : __BuiltinFloatingPointType> __intrinsic T trunc(T x); +__generic<T : __BuiltinFloatingPointType, let N : int> __intrinsic vector<T,N> trunc(vector<T,N> x); +__generic<T : __BuiltinFloatingPointType, let N : int, let M : int> __intrinsic matrix<T,N,M> trunc(matrix<T,N,M> x); + + +)", R"( + +// Shader model 6.0 stuff + +__intrinsic uint GlobalOrderedCountIncrement(uint countToAppendForThisLane); + +__generic<T : __BuiltinType> __intrinsic T QuadReadLaneAt(T sourceValue, int quadLaneID); +__generic<T : __BuiltinType, let N : int> __intrinsic vector<T,N> QuadReadLaneAt(vector<T,N> sourceValue, int quadLaneID); +__generic<T : __BuiltinType, let N : int, let M : int> __intrinsic matrix<T,N,M> QuadReadLaneAt(matrix<T,N,M> sourceValue, int quadLaneID); + +__generic<T : __BuiltinType> __intrinsic T QuadSwapX(T localValue); +__generic<T : __BuiltinType, let N : int> __intrinsic vector<T,N> QuadSwapX(vector<T,N> localValue); +__generic<T : __BuiltinType, let N : int, let M : int> __intrinsic matrix<T,N,M> QuadSwapX(matrix<T,N,M> localValue); + +__generic<T : __BuiltinType> __intrinsic T QuadSwapY(T localValue); +__generic<T : __BuiltinType, let N : int> __intrinsic vector<T,N> QuadSwapY(vector<T,N> localValue); +__generic<T : __BuiltinType, let N : int, let M : int> __intrinsic matrix<T,N,M> QuadSwapY(matrix<T,N,M> localValue); + +__generic<T : __BuiltinIntegerType> __intrinsic T WaveAllBitAnd(T expr); +__generic<T : __BuiltinIntegerType, let N : int> __intrinsic vector<T,N> WaveAllBitAnd(vector<T,N> expr); +__generic<T : __BuiltinIntegerType, let N : int, let M : int> __intrinsic matrix<T,N,M> WaveAllBitAnd(matrix<T,N,M> expr); + +__generic<T : __BuiltinIntegerType> __intrinsic T WaveAllBitOr(T expr); +__generic<T : __BuiltinIntegerType, let N : int> __intrinsic vector<T,N> WaveAllBitOr(vector<T,N> expr); +__generic<T : __BuiltinIntegerType, let N : int, let M : int> __intrinsic matrix<T,N,M> WaveAllBitOr(matrix<T,N,M> expr); + +__generic<T : __BuiltinIntegerType> __intrinsic T WaveAllBitXor(T expr); +__generic<T : __BuiltinIntegerType, let N : int> __intrinsic vector<T,N> WaveAllBitXor(vector<T,N> expr); +__generic<T : __BuiltinIntegerType, let N : int, let M : int> __intrinsic matrix<T,N,M> WaveAllBitXor(matrix<T,N,M> expr); + +__generic<T : __BuiltinArithmeticType> __intrinsic T WaveAllMax(T expr); +__generic<T : __BuiltinArithmeticType, let N : int> __intrinsic vector<T,N> WaveAllMax(vector<T,N> expr); +__generic<T : __BuiltinArithmeticType, let N : int, let M : int> __intrinsic matrix<T,N,M> WaveAllMax(matrix<T,N,M> expr); + +__generic<T : __BuiltinArithmeticType> __intrinsic T WaveAllMin(T expr); +__generic<T : __BuiltinArithmeticType, let N : int> __intrinsic vector<T,N> WaveAllMin(vector<T,N> expr); +__generic<T : __BuiltinArithmeticType, let N : int, let M : int> __intrinsic matrix<T,N,M> WaveAllMin(matrix<T,N,M> expr); + +__generic<T : __BuiltinArithmeticType> __intrinsic T WaveAllProduct(T expr); +__generic<T : __BuiltinArithmeticType, let N : int> __intrinsic vector<T,N> WaveAllProduct(vector<T,N> expr); +__generic<T : __BuiltinArithmeticType, let N : int, let M : int> __intrinsic matrix<T,N,M> WaveAllProduct(matrix<T,N,M> expr); + +__generic<T : __BuiltinArithmeticType> __intrinsic T WaveAllSum(T expr); +__generic<T : __BuiltinArithmeticType, let N : int> __intrinsic vector<T,N> WaveAllSum(vector<T,N> expr); +__generic<T : __BuiltinArithmeticType, let N : int, let M : int> __intrinsic matrix<T,N,M> WaveAllSum(matrix<T,N,M> expr); + +__intrinsic bool WaveAllEqual(bool expr); +__intrinsic bool WaveAllTrue(bool expr); +__intrinsic bool WaveAnyTrue(bool expr); + +uint64_t WaveBallot(bool expr); + +uint WaveGetLaneCount(); +uint WaveGetLaneIndex(); +uint WaveGetOrderedIndex(); + +bool WaveIsHelperLane(); + +bool WaveOnce(); + +__generic<T : __BuiltinArithmeticType> __intrinsic T WavePrefixProduct(T expr); +__generic<T : __BuiltinArithmeticType, let N : int> __intrinsic vector<T,N> WavePrefixProduct(vector<T,N> expr); +__generic<T : __BuiltinArithmeticType, let N : int, let M : int> __intrinsic matrix<T,N,M> WavePrefixProduct(matrix<T,N,M> expr); + +__generic<T : __BuiltinArithmeticType> __intrinsic T WavePrefixSum(T expr); +__generic<T : __BuiltinArithmeticType, let N : int> __intrinsic vector<T,N> WavePrefixSum(vector<T,N> expr); +__generic<T : __BuiltinArithmeticType, let N : int, let M : int> __intrinsic matrix<T,N,M> WavePrefixSum(matrix<T,N,M> expr); + +__generic<T : __BuiltinType> __intrinsic T WaveReadFirstLane(T expr); +__generic<T : __BuiltinType, let N : int> __intrinsic vector<T,N> WaveReadFirstLane(vector<T,N> expr); +__generic<T : __BuiltinType, let N : int, let M : int> __intrinsic matrix<T,N,M> WaveReadFirstLane(matrix<T,N,M> expr); + +__generic<T : __BuiltinType> __intrinsic T WaveReadLaneAt(T expr, int laneIndex); +__generic<T : __BuiltinType, let N : int> __intrinsic vector<T,N> WaveReadLaneAt(vector<T,N> expr, int laneIndex); +__generic<T : __BuiltinType, let N : int, let M : int> __intrinsic matrix<T,N,M> WaveReadLaneAt(matrix<T,N,M> expr, int laneIndex); + + +)", R"( + +// `typedef`s to help with the fact that HLSL has been sorta-kinda case insensitive at various points +typedef Texture2D texture2D; + +#line default +)" }; + + +using namespace CoreLib::Basic; + +namespace Slang +{ + namespace Compiler + { + static String stdlibPath; + + String getStdlibPath() + { + if(stdlibPath.Length() != 0) + return stdlibPath; + + StringBuilder pathBuilder; + for( auto cc = __FILE__; *cc; ++cc ) + { + switch( *cc ) + { + case '\n': + case '\t': + case '\\': + pathBuilder << "\\"; + default: + pathBuilder << *cc; + break; + } + } + stdlibPath = pathBuilder.ProduceString(); + + return stdlibPath; + } + + String SlangStdLib::code; + + enum + { + SINT_MASK = 1 << 0, + FLOAT_MASK = 1 << 1, + COMPARISON = 1 << 2, + BOOL_MASK = 1 << 3, + UINT_MASK = 1 << 4, + ASSIGNMENT = 1 << 5, + POSTFIX = 1 << 6, + + INT_MASK = SINT_MASK | UINT_MASK, + ARITHMETIC_MASK = INT_MASK | FLOAT_MASK, + LOGICAL_MASK = INT_MASK | BOOL_MASK, + ANY_MASK = INT_MASK | FLOAT_MASK | BOOL_MASK, + }; + + String SlangStdLib::GetCode() + { + if (code.Length() > 0) + return code; + StringBuilder sb; + + // generate operator overloads + + + + struct OpInfo { IntrinsicOp opCode; char const* opName; unsigned flags; }; + + OpInfo unaryOps[] = { + { IntrinsicOp::Neg, "-", ARITHMETIC_MASK }, + { IntrinsicOp::Not, "!", ANY_MASK }, + { IntrinsicOp::Not, "~", INT_MASK }, + { IntrinsicOp::PreInc, "++", ARITHMETIC_MASK | ASSIGNMENT }, + { IntrinsicOp::PreDec, "--", ARITHMETIC_MASK | ASSIGNMENT }, + { IntrinsicOp::PostInc, "++", ARITHMETIC_MASK | ASSIGNMENT | POSTFIX }, + { IntrinsicOp::PostDec, "--", ARITHMETIC_MASK | ASSIGNMENT | POSTFIX }, + }; + + OpInfo binaryOps[] = { + { IntrinsicOp::Add, "+", ARITHMETIC_MASK }, + { IntrinsicOp::Sub, "-", ARITHMETIC_MASK }, + { IntrinsicOp::Mul, "*", ARITHMETIC_MASK }, + { IntrinsicOp::Div, "/", ARITHMETIC_MASK }, + { IntrinsicOp::Mod, "%", INT_MASK }, + + { IntrinsicOp::And, "&&", LOGICAL_MASK }, + { IntrinsicOp::Or, "||", LOGICAL_MASK }, + + { IntrinsicOp::BitAnd, "&", LOGICAL_MASK }, + { IntrinsicOp::BitOr, "|", LOGICAL_MASK }, + { IntrinsicOp::BitXor, "^", LOGICAL_MASK }, + + { IntrinsicOp::Lsh, "<<", INT_MASK }, + { IntrinsicOp::Rsh, ">>", INT_MASK }, + + { IntrinsicOp::Eql, "==", ANY_MASK | COMPARISON }, + { IntrinsicOp::Neq, "!=", ANY_MASK | COMPARISON }, + + { IntrinsicOp::Greater, ">", ARITHMETIC_MASK | COMPARISON }, + { IntrinsicOp::Less, "<", ARITHMETIC_MASK | COMPARISON }, + { IntrinsicOp::Geq, ">=", ARITHMETIC_MASK | COMPARISON }, + { IntrinsicOp::Leq, "<=", ARITHMETIC_MASK | COMPARISON }, + + { IntrinsicOp::AddAssign, "+=", ASSIGNMENT | ARITHMETIC_MASK }, + { IntrinsicOp::SubAssign, "-=", ASSIGNMENT | ARITHMETIC_MASK }, + { IntrinsicOp::MulAssign, "*=", ASSIGNMENT | ARITHMETIC_MASK }, + { IntrinsicOp::DivAssign, "/=", ASSIGNMENT | ARITHMETIC_MASK }, + { IntrinsicOp::ModAssign, "%=", ASSIGNMENT | ARITHMETIC_MASK }, + { IntrinsicOp::AndAssign, "&=", ASSIGNMENT | LOGICAL_MASK }, + { IntrinsicOp::OrAssign, "|=", ASSIGNMENT | LOGICAL_MASK }, + { IntrinsicOp::XorAssign, "^=", ASSIGNMENT | LOGICAL_MASK }, + { IntrinsicOp::LshAssign, "<<=", ASSIGNMENT | INT_MASK }, + { IntrinsicOp::RshAssign, ">>=", ASSIGNMENT | INT_MASK }, + + + }; + + /* + String floatTypes[] = { "float", "float2", "float3", "float4" }; + String intTypes[] = { "int", "int2", "int3", "int4" }; + String uintTypes[] = { "uint", "uint2", "uint3", "uint4" }; + */ + + String path = getStdlibPath(); + + + +#define EMIT_LINE_DIRECTIVE() sb << "#line " << (__LINE__+1) << " \"" << path << "\"\n" + + // Generate declarations for all the base types + + static const struct { + char const* name; + BaseType tag; + unsigned flags; + } kBaseTypes[] = { + { "void", BaseType::Void, 0 }, + { "int", BaseType::Int, SINT_MASK }, + { "float", BaseType::Float, FLOAT_MASK }, + { "uint", BaseType::UInt, UINT_MASK }, + { "bool", BaseType::Bool, BOOL_MASK }, + { "uint64_t", BaseType::UInt64, UINT_MASK }, + }; + static const int kBaseTypeCount = sizeof(kBaseTypes) / sizeof(kBaseTypes[0]); + for (int tt = 0; tt < kBaseTypeCount; ++tt) + { + EMIT_LINE_DIRECTIVE(); + sb << "__builtin_type(" << int(kBaseTypes[tt].tag) << ") struct " << kBaseTypes[tt].name << "\n{\n"; + + // Declare trait conformances for this type + + sb << "__conforms __BuiltinType;\n"; + + switch( kBaseTypes[tt].tag ) + { + case BaseType::Float: + sb << "__conforms __BuiltinFloatingPointType;\n"; + sb << "__conforms __BuiltinRealType;\n"; + // fall through to: + case BaseType::Int: + sb << "__conforms __BuiltinSignedArithmeticType;\n"; + // fall through to: + case BaseType::UInt: + case BaseType::UInt64: + sb << "__conforms __BuiltinArithmeticType;\n"; + // fall through to: + case BaseType::Bool: + sb << "__conforms __BuiltinType;\n"; + break; + + default: + break; + } + + // Declare initializers to convert from various other types + for( int ss = 0; ss < kBaseTypeCount; ++ss ) + { + if( kBaseTypes[ss].tag == BaseType::Void ) + continue; + + EMIT_LINE_DIRECTIVE(); + sb << "__init(" << kBaseTypes[ss].name << " value);\n"; + } + + sb << "};\n"; + } + + // Declare ad hoc aliases for some types, just to get things compiling + // + // TODO(tfoley): At the very least, `double` should be treated as a distinct type. + sb << "typedef float double;\n"; + sb << "typedef float half;\n"; + + // Declare vector and matrix types + + sb << "__generic<T = float, let N : int = 4> __magic_type(Vector) struct vector\n{\n"; + sb << " __init(T value);\n"; // initialize from single scalar + sb << "};\n"; + sb << "__generic<T = float, let R : int = 4, let C : int = 4> __magic_type(Matrix) struct matrix {};\n"; + + static const struct { + char const* name; + char const* glslPrefix; + } kTypes[] = + { + {"float", ""}, + {"int", "i"}, + {"uint", "u"}, + {"bool", "b"}, + }; + static const int kTypeCount = sizeof(kTypes) / sizeof(kTypes[0]); + + for (int tt = 0; tt < kTypeCount; ++tt) + { + // Declare HLSL vector types + for (int ii = 1; ii <= 4; ++ii) + { + sb << "typedef vector<" << kTypes[tt].name << "," << ii << "> " << kTypes[tt].name << ii << ";\n"; + } + + // Declare HLSL matrix types + for (int rr = 2; rr <= 4; ++rr) + for (int cc = 2; cc <= 4; ++cc) + { + sb << "typedef matrix<" << kTypes[tt].name << "," << rr << "," << cc << "> " << kTypes[tt].name << rr << "x" << cc << ";\n"; + } + } + + static const char* kComponentNames[]{ "x", "y", "z", "w" }; + static const char* kVectorNames[]{ "", "x", "xy", "xyz", "xyzw" }; + + // Need to add constructors to the types above + for (int N = 2; N <= 4; ++N) + { + sb << "__generic<T> __extension vector<T, " << N << ">\n{\n"; + + // initialize from N scalars + sb << "__init("; + for (int ii = 0; ii < N; ++ii) + { + if (ii != 0) sb << ", "; + sb << "T " << kComponentNames[ii]; + } + sb << ");\n"; + + // Initialize from an M-vector and then scalars + for (int M = 2; M < N; ++M) + { + sb << "__init(vector<T," << M << "> " << kVectorNames[M]; + for (int ii = M; ii < N; ++ii) + { + sb << ", T " << kComponentNames[ii]; + } + sb << ");\n"; + } + + // initialize from another vector of the same size + // + // TODO(tfoley): this overlaps with implicit conversions. + // We should look for a way that we can define implicit + // conversions directly in the stdlib instead... + sb << "__generic<U> __init(vector<U," << N << ">);\n"; + + sb << "}\n"; + } + + for( int R = 2; R <= 4; ++R ) + for( int C = 2; C <= 4; ++C ) + { + sb << "__generic<T> __extension matrix<T, " << R << "," << C << ">\n{\n"; + + // initialize from R*C scalars + sb << "__init("; + for( int ii = 0; ii < R; ++ii ) + for( int jj = 0; jj < C; ++jj ) + { + if ((ii+jj) != 0) sb << ", "; + sb << "T m" << ii << jj; + } + sb << ");\n"; + + // Initialize from R C-vectors + sb << "__init("; + for (int ii = 0; ii < R; ++ii) + { + if(ii != 0) sb << ", "; + sb << "vector<T," << C << "> row" << ii; + } + sb << ");\n"; + + + // initialize from another matrix of the same size + // + // TODO(tfoley): See comment about how this overlaps + // with implicit conversion, in the `vector` case above + sb << "__generic<U> __init(matrix<U," << R << ", " << C << ">);\n"; + + sb << "}\n"; + } + + + // Declare built-in texture and sampler types + + sb << "__magic_type(SamplerState," << int(SamplerStateType::Flavor::SamplerState) << ") struct SamplerState {};"; + sb << "__magic_type(SamplerState," << int(SamplerStateType::Flavor::SamplerComparisonState) << ") struct SamplerComparisonState {};"; + + // TODO(tfoley): Need to handle `RW*` variants of texture types as well... + static const struct { + char const* name; + TextureType::Shape baseShape; + int coordCount; + } kBaseTextureTypes[] = { + { "Texture1D", TextureType::Shape1D, 1 }, + { "Texture2D", TextureType::Shape2D, 2 }, + { "Texture3D", TextureType::Shape3D, 3 }, + { "TextureCube", TextureType::ShapeCube, 3 }, + }; + static const int kBaseTextureTypeCount = sizeof(kBaseTextureTypes) / sizeof(kBaseTextureTypes[0]); + + + static const struct { + char const* name; + SlangResourceAccess access; + } kBaseTextureAccessLevels[] = { + { "", SLANG_RESOURCE_ACCESS_READ }, + { "RW", SLANG_RESOURCE_ACCESS_READ_WRITE }, + { "RasterizerOrdered", SLANG_RESOURCE_ACCESS_RASTER_ORDERED }, + }; + static const int kBaseTextureAccessLevelCount = sizeof(kBaseTextureAccessLevels) / sizeof(kBaseTextureAccessLevels[0]); + + for (int tt = 0; tt < kBaseTextureTypeCount; ++tt) + { + char const* name = kBaseTextureTypes[tt].name; + TextureType::Shape baseShape = kBaseTextureTypes[tt].baseShape; + + for (int isArray = 0; isArray < 2; ++isArray) + { + // Arrays of 3D textures aren't allowed + if (isArray && baseShape == TextureType::Shape3D) continue; + + for (int isMultisample = 0; isMultisample < 2; ++isMultisample) + for (int accessLevel = 0; accessLevel < kBaseTextureAccessLevelCount; ++accessLevel) + { + auto access = kBaseTextureAccessLevels[accessLevel].access; + + // TODO: any constraints to enforce on what gets to be multisampled? + + unsigned flavor = baseShape; + if (isArray) flavor |= TextureType::ArrayFlag; + if (isMultisample) flavor |= TextureType::MultisampleFlag; +// if (isShadow) flavor |= TextureType::ShadowFlag; + + flavor |= (access << 8); + + + // emit a generic signature + // TODO: allow for multisample count to come in as well... + sb << "__generic<T = float4> "; + + sb << "__magic_type(Texture," << int(flavor) << ") struct "; + sb << kBaseTextureAccessLevels[accessLevel].name; + sb << name; + if (isMultisample) sb << "MS"; + if (isArray) sb << "Array"; +// if (isShadow) sb << "Shadow"; + sb << "\n{"; + + if( !isMultisample ) + { + sb << "float CalculateLevelOfDetail(SamplerState s, "; + sb << "float" << kBaseTextureTypes[tt].coordCount << " location);\n"; + + sb << "float CalculateLevelOfDetailUnclamped(SamplerState s, "; + sb << "float" << kBaseTextureTypes[tt].coordCount << " location);\n"; + + // TODO: `Gather` operation + // (tricky because it returns a 4-vector of the element type + // of the texture components...) + } + + // TODO: `GetDimensions` operations + + for(int isFloat = 0; isFloat < 2; ++isFloat) + for(int includeMipInfo = 0; includeMipInfo < 2; ++includeMipInfo) + { + char const* t = isFloat ? "out float " : "out UINT "; + + sb << "void GetDimensions("; + if(includeMipInfo) + sb << "UINT mipLevel, "; + + switch(baseShape) + { + case TextureType::Shape1D: + sb << t << "width"; + break; + + case TextureType::Shape2D: + case TextureType::ShapeCube: + sb << t << "width,"; + sb << t << "height"; + break; + + case TextureType::Shape3D: + sb << t << "width,"; + sb << t << "height,"; + sb << t << "depth"; + break; + + default: + assert(!"unexpected"); + break; + } + + if(isArray) + { + sb << ", " << t << "elements"; + } + + if(includeMipInfo) + sb << ", " << t << "numberOfLevels"; + + sb << ");\n"; + } + + // `GetSamplePosition()` + if( isMultisample ) + { + sb << "float2 GetSamplePosition(int s);\n"; + } + + // `Load()` + + if( kBaseTextureTypes[tt].coordCount + isArray < 4 ) + { + sb << "T Load("; + sb << "int" << kBaseTextureTypes[tt].coordCount + isArray + 1 << " location);\n"; + + if( !isMultisample ) + { + sb << "T Load("; + sb << "int" << kBaseTextureTypes[tt].coordCount + isArray + 1 << " location, "; + sb << "int" << kBaseTextureTypes[tt].coordCount << " offset);\n"; + } + else + { + sb << "T Load("; + sb << "int" << kBaseTextureTypes[tt].coordCount + isArray + 1 << " location, "; + sb << "int sampleIndex, "; + sb << "int" << kBaseTextureTypes[tt].coordCount << " offset);\n"; + } + } + + if(baseShape != TextureType::ShapeCube) + { + // subscript operator + sb << "__intrinsic __subscript(uint" << kBaseTextureTypes[tt].coordCount + isArray << " location) -> T;\n"; + } + + if( !isMultisample ) + { + // `Sample()` + + sb << "T Sample(SamplerState s, "; + sb << "float" << kBaseTextureTypes[tt].coordCount + isArray << " location);\n"; + + if( baseShape != TextureType::ShapeCube ) + { + sb << "T Sample(SamplerState s, "; + sb << "float" << kBaseTextureTypes[tt].coordCount + isArray << " location, "; + sb << "int" << kBaseTextureTypes[tt].coordCount << " offset);\n"; + } + + sb << "T Sample(SamplerState s, "; + sb << "float" << kBaseTextureTypes[tt].coordCount + isArray << " location, "; + if( baseShape != TextureType::ShapeCube ) + { + sb << "int" << kBaseTextureTypes[tt].coordCount << " offset, "; + } + sb << "float clamp);\n"; + + sb << "T Sample(SamplerState s, "; + sb << "float" << kBaseTextureTypes[tt].coordCount + isArray << " location, "; + if( baseShape != TextureType::ShapeCube ) + { + sb << "int" << kBaseTextureTypes[tt].coordCount << " offset, "; + } + sb << "float clamp, out uint status);\n"; + + + // `SampleBias()` + sb << "T SampleBias(SamplerState s, "; + sb << "float" << kBaseTextureTypes[tt].coordCount + isArray << " location, float bias);\n"; + + if( baseShape != TextureType::ShapeCube ) + { + sb << "T SampleBias(SamplerState s, "; + sb << "float" << kBaseTextureTypes[tt].coordCount + isArray << " location, float bias, "; + sb << "int" << kBaseTextureTypes[tt].coordCount << " offset);\n"; + } + + // `SampleCmp()` and `SampleCmpLevelZero` + sb << "T SampleCmp(SamplerComparisonState s, "; + sb << "float" << kBaseTextureTypes[tt].coordCount + isArray << " location, "; + sb << "float compareValue"; + sb << ");\n"; + + sb << "T SampleCmpLevelZero(SamplerComparisonState s, "; + sb << "float" << kBaseTextureTypes[tt].coordCount + isArray << " location, "; + sb << "float compareValue"; + sb << ");\n"; + + if( baseShape != TextureType::ShapeCube ) + { + // Note(tfoley): MSDN seems confused, and claims that the `offset` + // parameter for `SampleCmp` is available for everything but 3D + // textures, while `Sample` and `SampleBias` are consistent in + // saying they only exclude `offset` for cube maps (which makes + // sense). I'm going to assume the documentation for `SampleCmp` + // is just wrong. + + sb << "T SampleCmp(SamplerState s, "; + sb << "float" << kBaseTextureTypes[tt].coordCount + isArray << " location, "; + sb << "float compareValue, "; + sb << "int" << kBaseTextureTypes[tt].coordCount << " offset);\n"; + + sb << "T SampleCmpLevelZero(SamplerState s, "; + sb << "float" << kBaseTextureTypes[tt].coordCount + isArray << " location, "; + sb << "float compareValue, "; + sb << "int" << kBaseTextureTypes[tt].coordCount << " offset);\n"; + } + + sb << "T SampleGrad(SamplerState s, "; + sb << "float" << kBaseTextureTypes[tt].coordCount + isArray << " location, "; + sb << "float" << kBaseTextureTypes[tt].coordCount << " gradX, "; + sb << "float" << kBaseTextureTypes[tt].coordCount << " gradY"; + sb << ");\n"; + + if( baseShape != TextureType::ShapeCube ) + { + sb << "T SampleGrad(SamplerState s, "; + sb << "float" << kBaseTextureTypes[tt].coordCount + isArray << " location, "; + sb << "float" << kBaseTextureTypes[tt].coordCount << " gradX, "; + sb << "float" << kBaseTextureTypes[tt].coordCount << " gradY, "; + sb << "int" << kBaseTextureTypes[tt].coordCount << " offset);\n"; + } + + // `SampleLevel` + + sb << "T SampleLevel(SamplerState s, "; + sb << "float" << kBaseTextureTypes[tt].coordCount + isArray << " location, "; + sb << "float level);\n"; + + if( baseShape != TextureType::ShapeCube ) + { + sb << "T SampleLevel(SamplerState s, "; + sb << "float" << kBaseTextureTypes[tt].coordCount + isArray << " location, "; + sb << "float level, "; + sb << "int" << kBaseTextureTypes[tt].coordCount << " offset);\n"; + } + } + + sb << "\n};\n"; + } + } + } + + // Declare additional built-in generic types + + sb << "__generic<T> __magic_type(ConstantBuffer) struct ConstantBuffer {};\n"; + sb << "__generic<T> __magic_type(TextureBuffer) struct TextureBuffer {};\n"; + + sb << "__generic<T> __magic_type(PackedBuffer) struct PackedBuffer {};\n"; + sb << "__generic<T> __magic_type(Uniform) struct Uniform {};\n"; + sb << "__generic<T> __magic_type(Patch) struct Patch {};\n"; + + + // Stale declarations for GLSL inner-product builtins +#if 0 + sb << "__intrinsic vec3 operator * (vec3, mat3);\n"; + sb << "__intrinsic vec3 operator * (mat3, vec3);\n"; + + sb << "__intrinsic vec4 operator * (vec4, mat4);\n"; + sb << "__intrinsic vec4 operator * (mat4, vec4);\n"; + + sb << "__intrinsic mat3 operator * (mat3, mat3);\n"; + sb << "__intrinsic mat4 operator * (mat4, mat4);\n"; +#endif + +#if 0 + sb << "__intrinsic(And) bool operator && (bool, bool);\n"; + sb << "__intrinsic(Or) bool operator || (bool, bool);\n"; + + for (auto type : intTypes) + { + sb << "__intrinsic(And) bool operator && (bool, " << type << ");\n"; + sb << "__intrinsic(Or) bool operator || (bool, " << type << ");\n"; + sb << "__intrinsic(And) bool operator && (" << type << ", bool);\n"; + sb << "__intrinsic(Or) bool operator || (" << type << ", bool);\n"; + } +#endif + + for (auto op : unaryOps) + { + for (auto type : kBaseTypes) + { + if ((type.flags & op.flags) == 0) + continue; + + char const* fixity = (op.flags & POSTFIX) != 0 ? "__postfix " : "__prefix "; + char const* qual = (op.flags & ASSIGNMENT) != 0 ? "in out " : ""; + + // scalar version + sb << fixity; + sb << "__intrinsic(" << int(op.opCode) << ") " << type.name << " operator" << op.opName << "(" << qual << type.name << " value);\n"; + + // vector version + sb << "__generic<let N : int> "; + sb << fixity; + sb << "__intrinsic(" << int(op.opCode) << ") vector<" << type.name << ",N> operator" << op.opName << "(" << qual << "vector<" << type.name << ",N> value);\n"; + + // matrix version + sb << "__generic<let N : int, let M : int> "; + sb << fixity; + sb << "__intrinsic(" << int(op.opCode) << ") matrix<" << type.name << ",N,M> operator" << op.opName << "(" << qual << "matrix<" << type.name << ",N,M> value);\n"; + } + } + + for (auto op : binaryOps) + { + for (auto type : kBaseTypes) + { + if ((type.flags & op.flags) == 0) + continue; + + char const* leftType = type.name; + char const* rightType = leftType; + char const* resultType = leftType; + + if (op.flags & COMPARISON) resultType = "bool"; + + char const* leftQual = ""; + if(op.flags & ASSIGNMENT) leftQual = "in out "; + + // TODO: handle `SHIFT` + + // scalar version + sb << "__intrinsic(" << int(op.opCode) << ") " << resultType << " operator" << op.opName << "(" << leftQual << leftType << " left, " << rightType << " right);\n"; + + // vector version + sb << "__generic<let N : int> "; + sb << "__intrinsic(" << int(op.opCode) << ") vector<" << resultType << ",N> operator" << op.opName << "(" << leftQual << "vector<" << leftType << ",N> left, vector<" << rightType << ",N> right);\n"; + + // matrix version + sb << "__generic<let N : int, let M : int> "; + sb << "__intrinsic(" << int(op.opCode) << ") matrix<" << resultType << ",N,M> operator" << op.opName << "(" << leftQual << "matrix<" << leftType << ",N,M> left, matrix<" << rightType << ",N,M> right);\n"; + } + } + +#if 0 + for (auto op : intUnaryOps) + { + String opName = GetOperatorFunctionName(op); + for (int i = 0; i < 4; i++) + { + auto itype = intTypes[i]; + auto utype = uintTypes[i]; + for (int j = 0; j < 2; j++) + { + auto retType = (op == Operator::Not) ? "bool" : j == 0 ? itype : utype; + sb << "__intrinsic " << retType << " operator " << opName << "(" << (j == 0 ? itype : utype) << ");\n"; + } + } + } + + for (auto op : floatUnaryOps) + { + String opName = GetOperatorFunctionName(op); + for (int i = 0; i < 4; i++) + { + auto type = floatTypes[i]; + auto retType = (op == Operator::Not) ? "bool" : type; + sb << "__intrinsic " << retType << " operator " << opName << "(" << type << ");\n"; + } + } + + for (auto op : floatOps) + { + String opName = GetOperatorFunctionName(op); + for (int i = 0; i < 4; i++) + { + auto type = floatTypes[i]; + auto itype = intTypes[i]; + auto utype = uintTypes[i]; + auto retType = ((op >= Operator::Eql && op <= Operator::Leq) || op == Operator::And || op == Operator::Or) ? "bool" : type; + sb << "__intrinsic " << retType << " operator " << opName << "(" << type << ", " << type << ");\n"; + sb << "__intrinsic " << retType << " operator " << opName << "(" << itype << ", " << type << ");\n"; + sb << "__intrinsic " << retType << " operator " << opName << "(" << utype << ", " << type << ");\n"; + sb << "__intrinsic " << retType << " operator " << opName << "(" << type << ", " << itype << ");\n"; + sb << "__intrinsic " << retType << " operator " << opName << "(" << type << ", " << utype << ");\n"; + if (i > 0) + { + sb << "__intrinsic " << retType << " operator " << opName << "(" << type << ", " << floatTypes[0] << ");\n"; + sb << "__intrinsic " << retType << " operator " << opName << "(" << floatTypes[0] << ", " << type << ");\n"; + + sb << "__intrinsic " << retType << " operator " << opName << "(" << type << ", " << intTypes[0] << ");\n"; + sb << "__intrinsic " << retType << " operator " << opName << "(" << intTypes[0] << ", " << type << ");\n"; + + sb << "__intrinsic " << retType << " operator " << opName << "(" << type << ", " << uintTypes[0] << ");\n"; + sb << "__intrinsic " << retType << " operator " << opName << "(" << uintTypes[0] << ", " << type << ");\n"; + } + } + } + + for (auto op : intOps) + { + String opName = GetOperatorFunctionName(op); + for (int i = 0; i < 4; i++) + { + auto type = intTypes[i]; + auto utype = uintTypes[i]; + auto retType = ((op >= Operator::Eql && op <= Operator::Leq) || op == Operator::And || op == Operator::Or) ? "bool" : type; + sb << "__intrinsic " << retType << " operator " << opName << "(" << type << ", " << type << ");\n"; + sb << "__intrinsic " << retType << " operator " << opName << "(" << utype << ", " << type << ");\n"; + sb << "__intrinsic " << retType << " operator " << opName << "(" << type << ", " << utype << ");\n"; + sb << "__intrinsic " << retType << " operator " << opName << "(" << utype << ", " << utype << ");\n"; + if (i > 0) + { + sb << "__intrinsic " << retType << " operator " << opName << "(" << type << ", " << intTypes[0] << ");\n"; + sb << "__intrinsic " << retType << " operator " << opName << "(" << intTypes[0] << ", " << type << ");\n"; + + sb << "__intrinsic " << retType << " operator " << opName << "(" << type << ", " << uintTypes[0] << ");\n"; + sb << "__intrinsic " << retType << " operator " << opName << "(" << uintTypes[0] << ", " << type << ");\n"; + } + } + } +#endif + + // Output a suitable `#line` directive to point at our raw stdlib code above + sb << "\n#line " << kLibIncludeStringLine << " \"" << path << "\"\n"; + + int chunkCount = sizeof(LibIncludeStringChunks) / sizeof(LibIncludeStringChunks[0]); + for (int cc = 0; cc < chunkCount; ++cc) + { + sb << LibIncludeStringChunks[cc]; + } + + code = sb.ProduceString(); + return code; + } + + + // GLSL-specific library code + + String glslLibraryCode; + + String getGLSLLibraryCode() + { + if(glslLibraryCode.Length() != 0) + return glslLibraryCode; + + String path = getStdlibPath(); + + StringBuilder sb; + +#define RAW(TEXT) \ + EMIT_LINE_DIRECTIVE(); \ + sb << TEXT; + + static const struct { + char const* name; + char const* glslPrefix; + } kTypes[] = + { + {"float", ""}, + {"int", "i"}, + {"uint", "u"}, + {"bool", "b"}, + }; + static const int kTypeCount = sizeof(kTypes) / sizeof(kTypes[0]); + + for( int tt = 0; tt < kTypeCount; ++tt ) + { + // Declare GLSL aliases for HLSL types + for (int vv = 2; vv <= 4; ++vv) + { + sb << "typedef " << kTypes[tt].name << vv << " " << kTypes[tt].glslPrefix << "vec" << vv << ";\n"; + sb << "typedef " << kTypes[tt].name << vv << "x" << vv << " " << kTypes[tt].glslPrefix << "mat" << vv << ";\n"; + } + for (int rr = 2; rr <= 4; ++rr) + for (int cc = 2; cc <= 4; ++cc) + { + sb << "typedef " << kTypes[tt].name << rr << "x" << cc << " " << kTypes[tt].glslPrefix << "mat" << rr << "x" << cc << ";\n"; + } + } + + // TODO(tfoley): Need to handle `RW*` variants of texture types as well... + static const struct { + char const* name; + TextureType::Shape baseShape; + int coordCount; + } kBaseTextureTypes[] = { + { "1D", TextureType::Shape1D, 1 }, + { "2D", TextureType::Shape2D, 2 }, + { "3D", TextureType::Shape3D, 3 }, + { "Cube", TextureType::ShapeCube, 3 }, + }; + static const int kBaseTextureTypeCount = sizeof(kBaseTextureTypes) / sizeof(kBaseTextureTypes[0]); + + + static const struct { + char const* name; + SlangResourceAccess access; + } kBaseTextureAccessLevels[] = { + { "", SLANG_RESOURCE_ACCESS_READ }, + { "RW", SLANG_RESOURCE_ACCESS_READ_WRITE }, + { "RasterizerOrdered", SLANG_RESOURCE_ACCESS_RASTER_ORDERED }, + }; + static const int kBaseTextureAccessLevelCount = sizeof(kBaseTextureAccessLevels) / sizeof(kBaseTextureAccessLevels[0]); + + for (int tt = 0; tt < kBaseTextureTypeCount; ++tt) + { + char const* shapeName = kBaseTextureTypes[tt].name; + TextureType::Shape baseShape = kBaseTextureTypes[tt].baseShape; + + for (int isArray = 0; isArray < 2; ++isArray) + { + // Arrays of 3D textures aren't allowed + if (isArray && baseShape == TextureType::Shape3D) continue; + + for (int isMultisample = 0; isMultisample < 2; ++isMultisample) + { + auto access = SLANG_RESOURCE_ACCESS_READ; + + // TODO: any constraints to enforce on what gets to be multisampled? + + + unsigned flavor = baseShape; + if (isArray) flavor |= TextureType::ArrayFlag; + if (isMultisample) flavor |= TextureType::MultisampleFlag; +// if (isShadow) flavor |= TextureType::ShadowFlag; + + flavor |= (access << 8); + + StringBuilder nameBuilder; + nameBuilder << shapeName; + if (isMultisample) nameBuilder << "MS"; + if (isArray) nameBuilder << "Array"; + auto name = nameBuilder.ProduceString(); + + sb << "__generic<T> "; + sb << "__magic_type(TextureSampler," << int(flavor) << ") struct "; + sb << "__sampler" << name; + sb << " {};\n"; + + sb << "__generic<T> "; + sb << "__magic_type(Texture," << int(flavor) << ") struct "; + sb << "__texture" << name; + sb << " {};\n"; + + sb << "__generic<T> "; + sb << "__magic_type(GLSLImageType," << int(flavor) << ") struct "; + sb << "__image" << name; + sb << " {};\n"; + + // TODO(tfoley): flesh this out for all the available prefixes + static const struct + { + char const* prefix; + char const* elementType; + } kTextureElementTypes[] = { + { "", "vec4" }, + { "i", "ivec4" }, + { "u", "uvec4" }, + { nullptr, nullptr }, + }; + for( auto ee = kTextureElementTypes; ee->prefix; ++ee ) + { + sb << "typedef __sampler" << name << "<" << ee->elementType << "> " << ee->prefix << "sampler" << name << ";\n"; + sb << "typedef __texture" << name << "<" << ee->elementType << "> " << ee->prefix << "texture" << name << ";\n"; + sb << "typedef __image" << name << "<" << ee->elementType << "> " << ee->prefix << "image" << name << ";\n"; + } + } + } + } + + sb << "__generic<T> __magic_type(GLSLInputParameterBlockType) struct __GLSLInputParameterBlock {};\n"; + sb << "__generic<T> __magic_type(GLSLOutputParameterBlockType) struct __GLSLOutputParameterBlock {};\n"; + sb << "__generic<T> __magic_type(GLSLShaderStorageBufferType) struct __GLSLShaderStorageBuffer {};\n"; + + sb << "__magic_type(SamplerState," << int(SamplerStateType::Flavor::SamplerState) << ") struct sampler {};"; + + sb << "__magic_type(GLSLInputAttachmentType) struct subpassInput {};"; + + // Define additional keywords + sb << "__modifier(GLSLBufferModifier) buffer;\n"; + sb << "__modifier(GLSLWriteOnlyModifier) writeonly;\n"; + sb << "__modifier(GLSLReadOnlyModifier) readonly;\n"; + sb << "__modifier(GLSLPatchModifier) patch;\n"; + + sb << "__modifier(SimpleModifier) flat;\n"; + + glslLibraryCode = sb.ProduceString(); + return glslLibraryCode; + } + + + + // + + void SlangStdLib::Finalize() + { + code = nullptr; + stdlibPath = String(); + glslLibraryCode = String(); + } + + } +} + diff --git a/source/slang/slang-stdlib.h b/source/slang/slang-stdlib.h new file mode 100644 index 000000000..65c70ecb5 --- /dev/null +++ b/source/slang/slang-stdlib.h @@ -0,0 +1,23 @@ +#ifndef SHADER_COMPILER_STD_LIB_H +#define SHADER_COMPILER_STD_LIB_H + +#include "../core/basic.h" + +namespace Slang +{ + namespace Compiler + { + class SlangStdLib + { + private: + static CoreLib::String code; + public: + static CoreLib::String GetCode(); + static void Finalize(); + }; + + CoreLib::String getGLSLLibraryCode(); + } +} + +#endif
\ No newline at end of file diff --git a/source/slang/slang.cpp b/source/slang/slang.cpp new file mode 100644 index 000000000..2f37981c5 --- /dev/null +++ b/source/slang/slang.cpp @@ -0,0 +1,699 @@ +#include "../../slang.h" + +#include "../core/slang-io.h" +#include "../slang/slang-stdlib.h" +#include "../slang/parser.h" +#include "../slang/preprocessor.h" +#include "../slang/reflection.h" +#include "../slang/type-layout.h" + +#ifdef _WIN32 +#define WIN32_LEAN_AND_MEAN +#define NOMINMAX +#include <Windows.h> +#undef WIN32_LEAN_AND_MEAN +#undef NOMINMAX +#endif + +using namespace CoreLib::Basic; +using namespace CoreLib::IO; +using namespace Slang::Compiler; + +namespace SlangLib +{ + static void stdlibDiagnosticCallback( + char const* message, + void* userData) + { + fputs(message, stderr); + fflush(stderr); +#ifdef WIN32 + OutputDebugStringA(message); +#endif + } + + class Session + { + public: + bool useCache = false; + CoreLib::String cacheDir; + + RefPtr<ShaderCompiler> compiler; + + RefPtr<Scope> slangLanguageScope; + RefPtr<Scope> hlslLanguageScope; + RefPtr<Scope> glslLanguageScope; + + List<RefPtr<ProgramSyntaxNode>> loadedModuleCode; + + + Session(bool /*pUseCache*/, CoreLib::String /*pCacheDir*/) + { + compiler = CreateShaderCompiler(); + + // Create scopes for various language builtins. + // + // TODO: load these on-demand to avoid parsing + // stdlib code for languages the user won't use. + + slangLanguageScope = new Scope(); + + hlslLanguageScope = new Scope(); + hlslLanguageScope->parent = slangLanguageScope; + + glslLanguageScope = new Scope(); + glslLanguageScope->parent = slangLanguageScope; + + addBuiltinSource(slangLanguageScope, "stdlib", SlangStdLib::GetCode()); + addBuiltinSource(glslLanguageScope, "glsl", getGLSLLibraryCode()); + } + + ~Session() + { + // We need to clean up the strings for the standard library + // code that we might have allocated and loaded into static + // variables (TODO: don't use `static` variables for this stuff) + + SlangStdLib::Finalize(); + + // Ditto for our type represnetation stuff + + ExpressionType::Finalize(); + } + + CompileUnit createPredefUnit() + { + CompileUnit translationUnit; + + + RefPtr<ProgramSyntaxNode> translationUnitSyntax = new ProgramSyntaxNode(); + + TranslationUnitOptions translationUnitOptions; + translationUnit.options = translationUnitOptions; + translationUnit.SyntaxNode = translationUnitSyntax; + + return translationUnit; + } + + void addBuiltinSource( + RefPtr<Scope> const& scope, + String const& path, + String const& source); + }; + + struct CompileRequest + { + // Pointer to parent session + Session* mSession; + + // Input options + CompileOptions Options; + + // Output stuff + DiagnosticSink mSink; + String mDiagnosticOutput; + + RefPtr<CollectionOfTranslationUnits> mCollectionOfTranslationUnits; + + RefPtr<ProgramLayout> mReflectionData; + + CompileResult mResult; + + List<String> mDependencyFilePaths; + + CompileRequest(Session* session) + : mSession(session) + {} + + ~CompileRequest() + {} + + struct IncludeHandlerImpl : IncludeHandler + { + CompileRequest* request; + + List<String> searchDirs; + + virtual bool TryToFindIncludeFile( + CoreLib::String const& pathToInclude, + CoreLib::String const& pathIncludedFrom, + CoreLib::String* outFoundPath, + CoreLib::String* outFoundSource) override + { + String path = Path::Combine(Path::GetDirectoryName(pathIncludedFrom), pathToInclude); + if (File::Exists(path)) + { + *outFoundPath = path; + *outFoundSource = File::ReadAllText(path); + + request->mDependencyFilePaths.Add(path); + + return true; + } + + for (auto & dir : searchDirs) + { + path = Path::Combine(dir, pathToInclude); + if (File::Exists(path)) + { + *outFoundPath = path; + *outFoundSource = File::ReadAllText(path); + + request->mDependencyFilePaths.Add(path); + + return true; + } + } + return false; + } + }; + + + CompileUnit parseTranslationUnit( + TranslationUnitOptions const& translationUnitOptions) + { + auto& options = Options; + + IncludeHandlerImpl includeHandler; + includeHandler.request = this; + + CompileUnit translationUnit; + + RefPtr<Scope> languageScope; + switch( translationUnitOptions.sourceLanguage ) + { + case SourceLanguage::HLSL: + languageScope = mSession->hlslLanguageScope; + break; + + case SourceLanguage::GLSL: + languageScope = mSession->glslLanguageScope; + break; + + case SourceLanguage::Slang: + default: + languageScope = mSession->slangLanguageScope; + break; + } + + + auto& preprocesorDefinitions = options.PreprocessorDefinitions; + + RefPtr<ProgramSyntaxNode> translationUnitSyntax = new ProgramSyntaxNode(); + + for( auto sourceFile : translationUnitOptions.sourceFiles ) + { + auto sourceFilePath = sourceFile->path; + + auto searchDirs = options.SearchDirectories; + searchDirs.Reverse(); + searchDirs.Add(Path::GetDirectoryName(sourceFilePath)); + searchDirs.Reverse(); + includeHandler.searchDirs = searchDirs; + + String source = sourceFile->content; + + auto tokens = preprocessSource( + source, + sourceFilePath, + mResult.GetErrorWriter(), + &includeHandler, + preprocesorDefinitions, + translationUnitSyntax.Ptr()); + + parseSourceFile( + translationUnitSyntax.Ptr(), + options, + tokens, + mResult.GetErrorWriter(), + sourceFilePath, + languageScope); + } + + translationUnit.options = translationUnitOptions; + translationUnit.SyntaxNode = translationUnitSyntax; + + return translationUnit; + } + + int executeCompilerDriverActions() + { + // If we are being asked to do pass-through, then we need to do that here... + if (Options.passThrough != PassThroughMode::None) + { + for( auto& translationUnitOptions : Options.translationUnits ) + { + switch( translationUnitOptions.sourceLanguage ) + { + // We can pass-through code written in a native shading language + case SourceLanguage::GLSL: + case SourceLanguage::HLSL: + break; + + // All other translation units need to be skipped + default: + continue; + } + + auto sourceFile = translationUnitOptions.sourceFiles[0]; + auto sourceFilePath = sourceFile->path; + String source = sourceFile->content; + + mSession->compiler->PassThrough( + source, + sourceFilePath, + Options, + translationUnitOptions); + } + return 0; + } + + // TODO: load the stdlib + + mCollectionOfTranslationUnits = new CollectionOfTranslationUnits(); + + // Parse everything from the input files requested + // + // TODO: this may trigger the loading and/or compilation of additional modules. + for( auto& translationUnitOptions : Options.translationUnits ) + { + auto translationUnit = parseTranslationUnit(translationUnitOptions); + mCollectionOfTranslationUnits->translationUnits.Add(translationUnit); + } + if( mResult.GetErrorCount() != 0 ) + return 1; + + // Now perform semantic checks, emit output, etc. + mSession->compiler->Compile( + mResult, mCollectionOfTranslationUnits.Ptr(), Options); + if(mResult.GetErrorCount() != 0) + return 1; + + mReflectionData = mCollectionOfTranslationUnits->layout; + + return 0; + } + + // Act as expected of the API-based compiler + int executeAPIActions() + { + mResult.mSink = &mSink; + + int err = executeCompilerDriverActions(); + + mDiagnosticOutput = mSink.outputBuffer.ProduceString(); + + if(mSink.GetErrorCount() != 0) + return mSink.GetErrorCount(); + + return err; + } + + int addTranslationUnit(SourceLanguage language, String const& name) + { + int result = Options.translationUnits.Count(); + + TranslationUnitOptions translationUnit; + translationUnit.sourceLanguage = SourceLanguage(language); + + Options.translationUnits.Add(translationUnit); + + return result; + } + + void addTranslationUnitSourceString( + int translationUnitIndex, + String const& path, + String const& source) + { + RefPtr<SourceFile> sourceFile = new SourceFile(); + sourceFile->path = path; + sourceFile->content = source; + + Options.translationUnits[translationUnitIndex].sourceFiles.Add(sourceFile); + } + + void addTranslationUnitSourceFile( + int translationUnitIndex, + String const& path) + { + String source; + try + { + source = File::ReadAllText(path); + } + catch( ... ) + { + // Emit a diagnostic! + mSink.diagnose( + CodePosition(0,0,0,path), + Diagnostics::cannotOpenFile, + path); + return; + } + + addTranslationUnitSourceString( + translationUnitIndex, + path, + source); + + mDependencyFilePaths.Add(path); + } + + int addTranslationUnitEntryPoint( + int translationUnitIndex, + String const& name, + Profile profile) + { + EntryPointOption entryPoint; + entryPoint.name = name; + entryPoint.profile = profile; + + // TODO: realistically want this to be global across all TUs... + int result = Options.translationUnits[translationUnitIndex].entryPoints.Count(); + + Options.translationUnits[translationUnitIndex].entryPoints.Add(entryPoint); + return result; + } + }; + + void Session::addBuiltinSource( + RefPtr<Scope> const& scope, + String const& path, + String const& source) + { + CompileRequest compileRequest(this); + + auto translationUnitIndex = compileRequest.addTranslationUnit(SourceLanguage::Slang, path); + + compileRequest.addTranslationUnitSourceString( + translationUnitIndex, + path, + source); + + int err = compileRequest.executeAPIActions(); + if(err) + { + fprintf(stderr, "%s", compileRequest.mDiagnosticOutput.Buffer()); + +#ifdef _WIN32 + OutputDebugStringA(compileRequest.mDiagnosticOutput.Buffer()); +#endif + + assert(!"error in stdlib"); + } + + // Extract the AST for the code we just parsed + auto syntax = compileRequest.mCollectionOfTranslationUnits->translationUnits[translationUnitIndex].SyntaxNode; + + // HACK(tfoley): mark all declarations in the "stdlib" so + // that we can detect them later (e.g., so we don't emit them) + for (auto m : syntax->Members) + { + auto fromStdLibModifier = new FromStdLibModifier(); + + fromStdLibModifier->next = m->modifiers.first; + m->modifiers.first = fromStdLibModifier; + } + + // Add the resulting code to the appropriate scope + if( !scope->containerDecl ) + { + // We are the first chunk of code to be loaded for this scope + scope->containerDecl = syntax.Ptr(); + } + else + { + // We need to create a new scope to link into the whole thing + auto subScope = new Scope(); + subScope->containerDecl = syntax.Ptr(); + subScope->nextSibling = scope->nextSibling; + scope->nextSibling = subScope; + } + + // We need to retain this AST so that we can use it in other code + // (Note that the `Scope` type does not retain the AST it points to) + loadedModuleCode.Add(syntax); + } +} + +using namespace SlangLib; + +// implementation of C interface + +#define SESSION(x) reinterpret_cast<SlangLib::Session *>(x) +#define REQ(x) reinterpret_cast<SlangLib::CompileRequest*>(x) + +SLANG_API SlangSession* spCreateSession(const char * cacheDir) +{ + return reinterpret_cast<SlangSession *>(new SlangLib::Session((cacheDir ? true : false), cacheDir)); +} + +SLANG_API void spDestroySession( + SlangSession* session) +{ + if(!session) return; + delete SESSION(session); +} + +SLANG_API void spAddBuiltins( + SlangSession* session, + char const* sourcePath, + char const* sourceString) +{ + auto s = SESSION(session); + s->addBuiltinSource( + + // TODO(tfoley): Add ability to directly new builtins to the approriate scope + s->slangLanguageScope, + + sourcePath, + sourceString); +} + + +SLANG_API SlangCompileRequest* spCreateCompileRequest( + SlangSession* session) +{ + auto s = SESSION(session); + auto req = new SlangLib::CompileRequest(s); + return reinterpret_cast<SlangCompileRequest*>(req); +} + +/*! +@brief Destroy a compile request. +*/ +SLANG_API void spDestroyCompileRequest( + SlangCompileRequest* request) +{ + if(!request) return; + auto req = REQ(request); + delete req; +} + +SLANG_API void spSetCompileFlags( + SlangCompileRequest* request, + SlangCompileFlags flags) +{ + REQ(request)->Options.flags = flags; +} + +SLANG_API void spSetCodeGenTarget( + SlangCompileRequest* request, + int target) +{ + REQ(request)->Options.Target = (CodeGenTarget)target; +} + +SLANG_API void spSetPassThrough( + SlangCompileRequest* request, + SlangPassThrough passThrough) +{ + REQ(request)->Options.passThrough = PassThroughMode(passThrough); +} + +SLANG_API void spSetDiagnosticCallback( + SlangCompileRequest* request, + SlangDiagnosticCallback callback, + void const* userData) +{ + if(!request) return; + auto req = REQ(request); + + req->mSink.callback = callback; + req->mSink.callbackUserData = (void*) userData; +} + +SLANG_API void spAddSearchPath( + SlangCompileRequest* request, + const char* searchDir) +{ + REQ(request)->Options.SearchDirectories.Add(searchDir); +} + +SLANG_API void spAddPreprocessorDefine( + SlangCompileRequest* request, + const char* key, + const char* value) +{ + REQ(request)->Options.PreprocessorDefinitions[key] = value; +} + +SLANG_API char const* spGetDiagnosticOutput( + SlangCompileRequest* request) +{ + if(!request) return 0; + auto req = REQ(request); + return req->mDiagnosticOutput.begin(); +} + +// New-fangled compilation API + +SLANG_API int spAddTranslationUnit( + SlangCompileRequest* request, + SlangSourceLanguage language, + char const* name) +{ + auto req = REQ(request); + + return req->addTranslationUnit( + SourceLanguage(language), + name ? name : ""); +} + +SLANG_API void spAddTranslationUnitSourceFile( + SlangCompileRequest* request, + int translationUnitIndex, + char const* path) +{ + if(!request) return; + auto req = REQ(request); + if(!path) return; + if(translationUnitIndex < 0) return; + if(translationUnitIndex >= req->Options.translationUnits.Count()) return; + + req->addTranslationUnitSourceFile( + translationUnitIndex, + path); +} + +// Add a source string to the given translation unit +SLANG_API void spAddTranslationUnitSourceString( + SlangCompileRequest* request, + int translationUnitIndex, + char const* path, + char const* source) +{ + if(!request) return; + auto req = REQ(request); + if(!source) return; + if(translationUnitIndex < 0) return; + if(translationUnitIndex >= req->Options.translationUnits.Count()) return; + + if(!path) path = ""; + + req->addTranslationUnitSourceString( + translationUnitIndex, + path, + source); + +} + +SLANG_API SlangProfileID spFindProfile( + SlangSession* session, + char const* name) +{ + return Profile::LookUp(name).raw; +} + +SLANG_API int spAddTranslationUnitEntryPoint( + SlangCompileRequest* request, + int translationUnitIndex, + char const* name, + SlangProfileID profile) +{ + if(!request) return -1; + auto req = REQ(request); + if(!name) return -1; + if(translationUnitIndex < 0) return -1; + if(translationUnitIndex >= req->Options.translationUnits.Count()) return -1; + + + return req->addTranslationUnitEntryPoint( + translationUnitIndex, + name, + Profile(Profile::RawVal(profile))); +} + + +// Compile in a context that already has its translation units specified +SLANG_API int spCompile( + SlangCompileRequest* request) +{ + auto req = REQ(request); + + int anyErrors = req->executeAPIActions(); + return anyErrors; +} + +SLANG_API int +spGetDependencyFileCount( + SlangCompileRequest* request) +{ + if(!request) return 0; + auto req = REQ(request); + return req->mDependencyFilePaths.Count(); +} + +/** Get the path to a file this compilation dependend on. +*/ +SLANG_API char const* +spGetDependencyFilePath( + SlangCompileRequest* request, + int index) +{ + if(!request) return 0; + auto req = REQ(request); + return req->mDependencyFilePaths[index].begin(); +} + +SLANG_API int +spGetTranslationUnitCount( + SlangCompileRequest* request) +{ + auto req = REQ(request); + return req->mResult.translationUnits.Count(); +} + +// Get the output code associated with a specific translation unit +SLANG_API char const* spGetTranslationUnitSource( + SlangCompileRequest* request, + int translationUnitIndex) +{ + auto req = REQ(request); + return req->mResult.translationUnits[translationUnitIndex].outputSource.Buffer(); +} + +SLANG_API char const* spGetEntryPointSource( + SlangCompileRequest* request, + int translationUnitIndex, + int entryPointIndex) +{ + auto req = REQ(request); + return req->mResult.translationUnits[translationUnitIndex].entryPoints[entryPointIndex].outputSource.Buffer(); + +} + +// Reflection API + +SLANG_API SlangReflection* spGetReflection( + SlangCompileRequest* request) +{ + if( !request ) return 0; + + auto req = REQ(request); + return (SlangReflection*) req->mReflectionData.Ptr(); +} + + +// ... rest of reflection API implementation is in `Reflection.cpp` diff --git a/source/slang/slang.natvis b/source/slang/slang.natvis new file mode 100644 index 000000000..bfc1e7317 --- /dev/null +++ b/source/slang/slang.natvis @@ -0,0 +1,14 @@ +<?xml version="1.0" encoding="utf-8"?> +<AutoVisualizer xmlns="http://schemas.microsoft.com/vstudio/debugger/natvis/2010"> + <Type Name="Slang::Compiler::CFGNode"> + <DisplayString>{{CFG Basic Block}}</DisplayString> + <Expand> + <LinkedListItems> + <Size>kvPairs.FCount</Size> + <HeadPointer>kvPairs.FHead</HeadPointer> + <NextPointer>pNext</NextPointer> + <ValueNode>Value</ValueNode> + </LinkedListItems> + </Expand> + </Type> +</AutoVisualizer>
\ No newline at end of file diff --git a/source/slang/slang.vcxproj b/source/slang/slang.vcxproj new file mode 100644 index 000000000..df34e40dc --- /dev/null +++ b/source/slang/slang.vcxproj @@ -0,0 +1,427 @@ +<?xml version="1.0" encoding="utf-8"?> +<Project DefaultTargets="Build" ToolsVersion="14.0" xmlns="http://schemas.microsoft.com/developer/msbuild/2003"> + <ItemGroup Label="ProjectConfigurations"> + <ProjectConfiguration Include="DebugClang|Win32"> + <Configuration>DebugClang</Configuration> + <Platform>Win32</Platform> + </ProjectConfiguration> + <ProjectConfiguration Include="DebugClang|x64"> + <Configuration>DebugClang</Configuration> + <Platform>x64</Platform> + </ProjectConfiguration> + <ProjectConfiguration Include="Debug_VS2013|Win32"> + <Configuration>Debug_VS2013</Configuration> + <Platform>Win32</Platform> + </ProjectConfiguration> + <ProjectConfiguration Include="Debug_VS2013|x64"> + <Configuration>Debug_VS2013</Configuration> + <Platform>x64</Platform> + </ProjectConfiguration> + <ProjectConfiguration Include="Debug|Win32"> + <Configuration>Debug</Configuration> + <Platform>Win32</Platform> + </ProjectConfiguration> + <ProjectConfiguration Include="Debug|x64"> + <Configuration>Debug</Configuration> + <Platform>x64</Platform> + </ProjectConfiguration> + <ProjectConfiguration Include="Release_VS2013|Win32"> + <Configuration>Release_VS2013</Configuration> + <Platform>Win32</Platform> + </ProjectConfiguration> + <ProjectConfiguration Include="Release_VS2013|x64"> + <Configuration>Release_VS2013</Configuration> + <Platform>x64</Platform> + </ProjectConfiguration> + <ProjectConfiguration Include="Release|Win32"> + <Configuration>Release</Configuration> + <Platform>Win32</Platform> + </ProjectConfiguration> + <ProjectConfiguration Include="Release|x64"> + <Configuration>Release</Configuration> + <Platform>x64</Platform> + </ProjectConfiguration> + </ItemGroup> + <PropertyGroup Label="Globals"> + <ProjectGuid>{DB00DA62-0533-4AFD-B59F-A67D5B3A0808}</ProjectGuid> + <Keyword>Win32Proj</Keyword> + <RootNamespace>SpireCore</RootNamespace> + <ProjectName>slang</ProjectName> + <WindowsTargetPlatformVersion>8.1</WindowsTargetPlatformVersion> + </PropertyGroup> + <Import Project="$(VCTargetsPath)\Microsoft.Cpp.Default.props" /> + <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|Win32'" Label="Configuration"> + <ConfigurationType>StaticLibrary</ConfigurationType> + <UseDebugLibraries>true</UseDebugLibraries> + <PlatformToolset>v140</PlatformToolset> + <CharacterSet>Unicode</CharacterSet> + </PropertyGroup> + <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug_VS2013|Win32'" Label="Configuration"> + <ConfigurationType>StaticLibrary</ConfigurationType> + <UseDebugLibraries>true</UseDebugLibraries> + <PlatformToolset>v120</PlatformToolset> + <CharacterSet>Unicode</CharacterSet> + </PropertyGroup> + <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='DebugClang|Win32'" Label="Configuration"> + <ConfigurationType>StaticLibrary</ConfigurationType> + <UseDebugLibraries>true</UseDebugLibraries> + <PlatformToolset>v140_clang_3_7</PlatformToolset> + <CharacterSet>Unicode</CharacterSet> + </PropertyGroup> + <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'" Label="Configuration"> + <ConfigurationType>StaticLibrary</ConfigurationType> + <UseDebugLibraries>true</UseDebugLibraries> + <PlatformToolset>v140</PlatformToolset> + <CharacterSet>Unicode</CharacterSet> + </PropertyGroup> + <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug_VS2013|x64'" Label="Configuration"> + <ConfigurationType>StaticLibrary</ConfigurationType> + <UseDebugLibraries>true</UseDebugLibraries> + <PlatformToolset>v120</PlatformToolset> + <CharacterSet>Unicode</CharacterSet> + </PropertyGroup> + <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='DebugClang|x64'" Label="Configuration"> + <ConfigurationType>StaticLibrary</ConfigurationType> + <UseDebugLibraries>true</UseDebugLibraries> + <PlatformToolset>v140_Clang_3_7</PlatformToolset> + <CharacterSet>Unicode</CharacterSet> + </PropertyGroup> + <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|Win32'" Label="Configuration"> + <ConfigurationType>StaticLibrary</ConfigurationType> + <UseDebugLibraries>false</UseDebugLibraries> + <PlatformToolset>v140</PlatformToolset> + <WholeProgramOptimization>true</WholeProgramOptimization> + <CharacterSet>Unicode</CharacterSet> + </PropertyGroup> + <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release_VS2013|Win32'" Label="Configuration"> + <ConfigurationType>StaticLibrary</ConfigurationType> + <UseDebugLibraries>false</UseDebugLibraries> + <PlatformToolset>v120</PlatformToolset> + <WholeProgramOptimization>true</WholeProgramOptimization> + <CharacterSet>Unicode</CharacterSet> + </PropertyGroup> + <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'" Label="Configuration"> + <ConfigurationType>StaticLibrary</ConfigurationType> + <UseDebugLibraries>false</UseDebugLibraries> + <PlatformToolset>v140</PlatformToolset> + <WholeProgramOptimization>true</WholeProgramOptimization> + <CharacterSet>Unicode</CharacterSet> + </PropertyGroup> + <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release_VS2013|x64'" Label="Configuration"> + <ConfigurationType>StaticLibrary</ConfigurationType> + <UseDebugLibraries>false</UseDebugLibraries> + <PlatformToolset>v120</PlatformToolset> + <WholeProgramOptimization>true</WholeProgramOptimization> + <CharacterSet>Unicode</CharacterSet> + </PropertyGroup> + <Import Project="$(VCTargetsPath)\Microsoft.Cpp.props" /> + <ImportGroup Label="ExtensionSettings"> + </ImportGroup> + <ImportGroup Label="PropertySheets" Condition="'$(Configuration)|$(Platform)'=='Debug|Win32'"> + <Import Project="$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props" Condition="exists('$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props')" Label="LocalAppDataPlatform" /> + <Import Project="..\..\build\slang-build.props" /> + </ImportGroup> + <ImportGroup Condition="'$(Configuration)|$(Platform)'=='Debug_VS2013|Win32'" Label="PropertySheets"> + <Import Project="$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props" Condition="exists('$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props')" Label="LocalAppDataPlatform" /> + <Import Project="..\..\build\slang-build.props" /> + </ImportGroup> + <ImportGroup Condition="'$(Configuration)|$(Platform)'=='DebugClang|Win32'" Label="PropertySheets"> + <Import Project="$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props" Condition="exists('$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props')" Label="LocalAppDataPlatform" /> + <Import Project="..\..\build\slang-build.props" /> + </ImportGroup> + <ImportGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'" Label="PropertySheets"> + <Import Project="$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props" Condition="exists('$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props')" Label="LocalAppDataPlatform" /> + <Import Project="..\..\build\slang-build.props" /> + </ImportGroup> + <ImportGroup Condition="'$(Configuration)|$(Platform)'=='Debug_VS2013|x64'" Label="PropertySheets"> + <Import Project="$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props" Condition="exists('$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props')" Label="LocalAppDataPlatform" /> + <Import Project="..\..\build\slang-build.props" /> + </ImportGroup> + <ImportGroup Condition="'$(Configuration)|$(Platform)'=='DebugClang|x64'" Label="PropertySheets"> + <Import Project="$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props" Condition="exists('$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props')" Label="LocalAppDataPlatform" /> + <Import Project="..\..\build\slang-build.props" /> + </ImportGroup> + <ImportGroup Label="PropertySheets" Condition="'$(Configuration)|$(Platform)'=='Release|Win32'"> + <Import Project="$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props" Condition="exists('$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props')" Label="LocalAppDataPlatform" /> + <Import Project="..\..\build\slang-build.props" /> + </ImportGroup> + <ImportGroup Condition="'$(Configuration)|$(Platform)'=='Release_VS2013|Win32'" Label="PropertySheets"> + <Import Project="$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props" Condition="exists('$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props')" Label="LocalAppDataPlatform" /> + <Import Project="..\..\build\slang-build.props" /> + </ImportGroup> + <ImportGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'" Label="PropertySheets"> + <Import Project="$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props" Condition="exists('$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props')" Label="LocalAppDataPlatform" /> + <Import Project="..\..\build\slang-build.props" /> + </ImportGroup> + <ImportGroup Condition="'$(Configuration)|$(Platform)'=='Release_VS2013|x64'" Label="PropertySheets"> + <Import Project="$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props" Condition="exists('$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props')" Label="LocalAppDataPlatform" /> + <Import Project="..\..\build\slang-build.props" /> + </ImportGroup> + <PropertyGroup Label="UserMacros" /> + <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|Win32'"> + <LinkIncremental>true</LinkIncremental> + </PropertyGroup> + <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug_VS2013|Win32'"> + <LinkIncremental>true</LinkIncremental> + </PropertyGroup> + <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='DebugClang|Win32'"> + <LinkIncremental>true</LinkIncremental> + </PropertyGroup> + <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'"> + <LinkIncremental>true</LinkIncremental> + </PropertyGroup> + <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug_VS2013|x64'"> + <LinkIncremental>true</LinkIncremental> + </PropertyGroup> + <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='DebugClang|x64'"> + <LinkIncremental>true</LinkIncremental> + </PropertyGroup> + <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|Win32'"> + <LinkIncremental>false</LinkIncremental> + </PropertyGroup> + <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release_VS2013|Win32'"> + <LinkIncremental>false</LinkIncremental> + </PropertyGroup> + <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'"> + <LinkIncremental>false</LinkIncremental> + </PropertyGroup> + <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release_VS2013|x64'"> + <LinkIncremental>false</LinkIncremental> + </PropertyGroup> + <ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Debug|Win32'"> + <ClCompile> + <PrecompiledHeader> + </PrecompiledHeader> + <WarningLevel>Level4</WarningLevel> + <Optimization>Disabled</Optimization> + <PreprocessorDefinitions>WIN32;_DEBUG;_CONSOLE;_LIB;%(PreprocessorDefinitions)</PreprocessorDefinitions> + <AdditionalIncludeDirectories>../</AdditionalIncludeDirectories> + <RuntimeLibrary>MultiThreadedDebug</RuntimeLibrary> + <MultiProcessorCompilation>false</MultiProcessorCompilation> + </ClCompile> + <Link> + <SubSystem>Console</SubSystem> + <GenerateDebugInformation>true</GenerateDebugInformation> + </Link> + </ItemDefinitionGroup> + <ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Debug_VS2013|Win32'"> + <ClCompile> + <PrecompiledHeader> + </PrecompiledHeader> + <WarningLevel>Level4</WarningLevel> + <Optimization>Disabled</Optimization> + <PreprocessorDefinitions>WIN32;_DEBUG;_CONSOLE;_LIB;%(PreprocessorDefinitions)</PreprocessorDefinitions> + <AdditionalIncludeDirectories>../</AdditionalIncludeDirectories> + <RuntimeLibrary>MultiThreadedDebug</RuntimeLibrary> + <MultiProcessorCompilation>false</MultiProcessorCompilation> + </ClCompile> + <Link> + <SubSystem>Console</SubSystem> + <GenerateDebugInformation>true</GenerateDebugInformation> + </Link> + </ItemDefinitionGroup> + <ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='DebugClang|Win32'"> + <ClCompile> + <PrecompiledHeader> + </PrecompiledHeader> + <WarningLevel>EnableAllWarnings</WarningLevel> + <Optimization>Disabled</Optimization> + <PreprocessorDefinitions>WIN32;_DEBUG;_CONSOLE;_LIB;%(PreprocessorDefinitions)</PreprocessorDefinitions> + <AdditionalIncludeDirectories>../</AdditionalIncludeDirectories> + <RuntimeLibrary>MultiThreadedDebug</RuntimeLibrary> + <MultiProcessorCompilation>false</MultiProcessorCompilation> + <RuntimeTypeInfo>true</RuntimeTypeInfo> + </ClCompile> + <Link> + <SubSystem>Console</SubSystem> + <GenerateDebugInformation>true</GenerateDebugInformation> + </Link> + </ItemDefinitionGroup> + <ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'"> + <ClCompile> + <PrecompiledHeader> + </PrecompiledHeader> + <WarningLevel>Level4</WarningLevel> + <Optimization>Disabled</Optimization> + <PreprocessorDefinitions>WIN32;_DEBUG;_CONSOLE;_LIB;%(PreprocessorDefinitions)</PreprocessorDefinitions> + <AdditionalIncludeDirectories>../</AdditionalIncludeDirectories> + <RuntimeLibrary>MultiThreadedDebug</RuntimeLibrary> + <BrowseInformation>true</BrowseInformation> + <MultiProcessorCompilation>false</MultiProcessorCompilation> + </ClCompile> + <Link> + <SubSystem>Console</SubSystem> + <GenerateDebugInformation>true</GenerateDebugInformation> + </Link> + <Bscmake> + <PreserveSbr>true</PreserveSbr> + </Bscmake> + </ItemDefinitionGroup> + <ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Debug_VS2013|x64'"> + <ClCompile> + <PrecompiledHeader> + </PrecompiledHeader> + <WarningLevel>Level4</WarningLevel> + <Optimization>Disabled</Optimization> + <PreprocessorDefinitions>WIN32;_DEBUG;_CONSOLE;_LIB;%(PreprocessorDefinitions)</PreprocessorDefinitions> + <AdditionalIncludeDirectories>../</AdditionalIncludeDirectories> + <RuntimeLibrary>MultiThreadedDebug</RuntimeLibrary> + <BrowseInformation>true</BrowseInformation> + <MultiProcessorCompilation>false</MultiProcessorCompilation> + </ClCompile> + <Link> + <SubSystem>Console</SubSystem> + <GenerateDebugInformation>true</GenerateDebugInformation> + </Link> + <Bscmake> + <PreserveSbr>true</PreserveSbr> + </Bscmake> + </ItemDefinitionGroup> + <ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='DebugClang|x64'"> + <ClCompile> + <PrecompiledHeader> + </PrecompiledHeader> + <WarningLevel>Level4</WarningLevel> + <Optimization>Disabled</Optimization> + <PreprocessorDefinitions>WIN32;_DEBUG;_CONSOLE;_LIB;%(PreprocessorDefinitions)</PreprocessorDefinitions> + <AdditionalIncludeDirectories>../</AdditionalIncludeDirectories> + <RuntimeLibrary>MultiThreadedDebug</RuntimeLibrary> + <BrowseInformation>true</BrowseInformation> + <MultiProcessorCompilation>false</MultiProcessorCompilation> + </ClCompile> + <Link> + <SubSystem>Console</SubSystem> + <GenerateDebugInformation>true</GenerateDebugInformation> + </Link> + <Bscmake> + <PreserveSbr>true</PreserveSbr> + </Bscmake> + </ItemDefinitionGroup> + <ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Release|Win32'"> + <ClCompile> + <WarningLevel>Level3</WarningLevel> + <PrecompiledHeader> + </PrecompiledHeader> + <Optimization>MaxSpeed</Optimization> + <FunctionLevelLinking>true</FunctionLevelLinking> + <IntrinsicFunctions>true</IntrinsicFunctions> + <PreprocessorDefinitions>WIN32;NDEBUG;_CONSOLE;_LIB;%(PreprocessorDefinitions)</PreprocessorDefinitions> + <AdditionalIncludeDirectories>../</AdditionalIncludeDirectories> + <RuntimeLibrary>MultiThreaded</RuntimeLibrary> + <MultiProcessorCompilation>false</MultiProcessorCompilation> + </ClCompile> + <Link> + <SubSystem>Console</SubSystem> + <GenerateDebugInformation>true</GenerateDebugInformation> + <EnableCOMDATFolding>true</EnableCOMDATFolding> + <OptimizeReferences>true</OptimizeReferences> + </Link> + </ItemDefinitionGroup> + <ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Release_VS2013|Win32'"> + <ClCompile> + <WarningLevel>Level3</WarningLevel> + <PrecompiledHeader> + </PrecompiledHeader> + <Optimization>MaxSpeed</Optimization> + <FunctionLevelLinking>true</FunctionLevelLinking> + <IntrinsicFunctions>true</IntrinsicFunctions> + <PreprocessorDefinitions>WIN32;NDEBUG;_CONSOLE;_LIB;%(PreprocessorDefinitions)</PreprocessorDefinitions> + <AdditionalIncludeDirectories>../</AdditionalIncludeDirectories> + <RuntimeLibrary>MultiThreaded</RuntimeLibrary> + <MultiProcessorCompilation>false</MultiProcessorCompilation> + </ClCompile> + <Link> + <SubSystem>Console</SubSystem> + <GenerateDebugInformation>true</GenerateDebugInformation> + <EnableCOMDATFolding>true</EnableCOMDATFolding> + <OptimizeReferences>true</OptimizeReferences> + </Link> + </ItemDefinitionGroup> + <ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'"> + <ClCompile> + <WarningLevel>Level4</WarningLevel> + <PrecompiledHeader> + </PrecompiledHeader> + <Optimization>MaxSpeed</Optimization> + <FunctionLevelLinking>true</FunctionLevelLinking> + <IntrinsicFunctions>true</IntrinsicFunctions> + <PreprocessorDefinitions>WIN32;NDEBUG;_CONSOLE;_LIB;%(PreprocessorDefinitions)</PreprocessorDefinitions> + <AdditionalIncludeDirectories>../</AdditionalIncludeDirectories> + <RuntimeLibrary>MultiThreaded</RuntimeLibrary> + <MultiProcessorCompilation>false</MultiProcessorCompilation> + </ClCompile> + <Link> + <SubSystem>Console</SubSystem> + <GenerateDebugInformation>true</GenerateDebugInformation> + <EnableCOMDATFolding>true</EnableCOMDATFolding> + <OptimizeReferences>true</OptimizeReferences> + </Link> + </ItemDefinitionGroup> + <ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Release_VS2013|x64'"> + <ClCompile> + <WarningLevel>Level4</WarningLevel> + <PrecompiledHeader> + </PrecompiledHeader> + <Optimization>MaxSpeed</Optimization> + <FunctionLevelLinking>true</FunctionLevelLinking> + <IntrinsicFunctions>true</IntrinsicFunctions> + <PreprocessorDefinitions>WIN32;NDEBUG;_CONSOLE;_LIB;%(PreprocessorDefinitions)</PreprocessorDefinitions> + <AdditionalIncludeDirectories>../</AdditionalIncludeDirectories> + <RuntimeLibrary>MultiThreaded</RuntimeLibrary> + <MultiProcessorCompilation>false</MultiProcessorCompilation> + </ClCompile> + <Link> + <SubSystem>Console</SubSystem> + <GenerateDebugInformation>true</GenerateDebugInformation> + <EnableCOMDATFolding>true</EnableCOMDATFolding> + <OptimizeReferences>true</OptimizeReferences> + </Link> + </ItemDefinitionGroup> + <ItemGroup> + <Natvis Include="slang.natvis" /> + </ItemGroup> + <ItemGroup> + <ClInclude Include="compiled-program.h" /> + <ClInclude Include="compiler.h" /> + <ClInclude Include="diagnostic-defs.h" /> + <ClInclude Include="diagnostics.h" /> + <ClInclude Include="emit.h" /> + <ClInclude Include="intrinsic-defs.h" /> + <ClInclude Include="lexer.h" /> + <ClInclude Include="lookup.h" /> + <ClInclude Include="parameter-binding.h" /> + <ClInclude Include="parser.h" /> + <ClInclude Include="preprocessor.h" /> + <ClInclude Include="profile-defs.h" /> + <ClInclude Include="profile.h" /> + <ClInclude Include="reflection.h" /> + <ClInclude Include="slang-stdlib.h" /> + <ClInclude Include="source-loc.h" /> + <ClInclude Include="syntax-visitors.h" /> + <ClInclude Include="syntax.h" /> + <ClInclude Include="token-defs.h" /> + <ClInclude Include="token.h" /> + <ClInclude Include="type-layout.h" /> + </ItemGroup> + <ItemGroup> + <ClCompile Include="check.cpp" /> + <ClCompile Include="compiler.cpp" /> + <ClCompile Include="diagnostics.cpp" /> + <ClCompile Include="emit.cpp" /> + <ClCompile Include="lexer.cpp" /> + <ClCompile Include="lookup.cpp" /> + <ClCompile Include="parameter-binding.cpp" /> + <ClCompile Include="parser.cpp" /> + <ClCompile Include="preprocessor.cpp" /> + <ClCompile Include="profile.cpp" /> + <ClCompile Include="reflection.cpp" /> + <ClCompile Include="slang-stdlib.cpp" /> + <ClCompile Include="slang.cpp" /> + <ClCompile Include="syntax.cpp" /> + <ClCompile Include="token.cpp" /> + <ClCompile Include="type-layout.cpp" /> + </ItemGroup> + <Import Project="$(VCTargetsPath)\Microsoft.Cpp.targets" /> + <ImportGroup Label="ExtensionTargets"> + </ImportGroup> +</Project>
\ No newline at end of file diff --git a/source/slang/slang.vcxproj.filters b/source/slang/slang.vcxproj.filters new file mode 100644 index 000000000..dac58bbf7 --- /dev/null +++ b/source/slang/slang.vcxproj.filters @@ -0,0 +1,48 @@ +<?xml version="1.0" encoding="utf-8"?> +<Project ToolsVersion="4.0" xmlns="http://schemas.microsoft.com/developer/msbuild/2003"> + <ItemGroup> + <Natvis Include="NatvisFile.natvis" /> + <Natvis Include="slang.natvis" /> + </ItemGroup> + <ItemGroup> + <ClInclude Include="compiled-program.h" /> + <ClInclude Include="compiler.h" /> + <ClInclude Include="diagnostic-defs.h" /> + <ClInclude Include="diagnostics.h" /> + <ClInclude Include="emit.h" /> + <ClInclude Include="intrinsic-defs.h" /> + <ClInclude Include="lexer.h" /> + <ClInclude Include="lookup.h" /> + <ClInclude Include="parameter-binding.h" /> + <ClInclude Include="parser.h" /> + <ClInclude Include="preprocessor.h" /> + <ClInclude Include="profile.h" /> + <ClInclude Include="profile-defs.h" /> + <ClInclude Include="reflection.h" /> + <ClInclude Include="slang-stdlib.h" /> + <ClInclude Include="source-loc.h" /> + <ClInclude Include="syntax.h" /> + <ClInclude Include="syntax-visitors.h" /> + <ClInclude Include="token.h" /> + <ClInclude Include="token-defs.h" /> + <ClInclude Include="type-layout.h" /> + </ItemGroup> + <ItemGroup> + <ClCompile Include="check.cpp" /> + <ClCompile Include="compiler.cpp" /> + <ClCompile Include="diagnostics.cpp" /> + <ClCompile Include="emit.cpp" /> + <ClCompile Include="lexer.cpp" /> + <ClCompile Include="lookup.cpp" /> + <ClCompile Include="parameter-binding.cpp" /> + <ClCompile Include="parser.cpp" /> + <ClCompile Include="preprocessor.cpp" /> + <ClCompile Include="profile.cpp" /> + <ClCompile Include="reflection.cpp" /> + <ClCompile Include="slang.cpp" /> + <ClCompile Include="slang-stdlib.cpp" /> + <ClCompile Include="syntax.cpp" /> + <ClCompile Include="token.cpp" /> + <ClCompile Include="type-layout.cpp" /> + </ItemGroup> +</Project>
\ No newline at end of file diff --git a/source/slang/source-loc.h b/source/slang/source-loc.h new file mode 100644 index 000000000..dc353f402 --- /dev/null +++ b/source/slang/source-loc.h @@ -0,0 +1,47 @@ +// source-loc.h +#ifndef SLANG_SOURCE_LOC_H_INCLUDED +#define SLANG_SOURCE_LOC_H_INCLUDED + +#include "../core/basic.h" + +namespace Slang { +namespace Compiler { + +using namespace CoreLib::Basic; + +class CodePosition +{ +public: + int Line = -1, Col = -1, Pos = -1; + String FileName; + String ToString() + { + StringBuilder sb(100); + sb << FileName; + if (Line != -1) + sb << "(" << Line << ")"; + return sb.ProduceString(); + } + CodePosition() = default; + CodePosition(int line, int col, int pos, String fileName) + { + Line = line; + Col = col; + Pos = pos; + this->FileName = fileName; + } + bool operator < (const CodePosition & pos) const + { + return FileName < pos.FileName || (FileName == pos.FileName && Line < pos.Line) || + (FileName == pos.FileName && Line == pos.Line && Col < pos.Col); + } + bool operator == (const CodePosition & pos) const + { + return FileName == pos.FileName && Line == pos.Line && Col == pos.Col; + } +}; + + +}} + +#endif diff --git a/source/slang/syntax-visitors.h b/source/slang/syntax-visitors.h new file mode 100644 index 000000000..565ac3ace --- /dev/null +++ b/source/slang/syntax-visitors.h @@ -0,0 +1,21 @@ +#ifndef RASTER_RENDERER_SYNTAX_PRINTER_H +#define RASTER_RENDERER_SYNTAX_PRINTER_H + +#include "diagnostics.h" +#include "syntax.h" +#include "compiled-program.h" + +namespace Slang +{ + namespace Compiler + { + class CompileOptions; + class ShaderCompiler; + class ShaderLinkInfo; + class ShaderSymbol; + + SyntaxVisitor * CreateSemanticsVisitor(DiagnosticSink * err, CompileOptions const& options); + } +} + +#endif
\ No newline at end of file diff --git a/source/slang/syntax.cpp b/source/slang/syntax.cpp new file mode 100644 index 000000000..3050afff6 --- /dev/null +++ b/source/slang/syntax.cpp @@ -0,0 +1,1484 @@ +#include "syntax.h" +#include "syntax-visitors.h" +#include <typeinfo> +#include <assert.h> + +namespace Slang +{ + namespace Compiler + { + // BasicExpressionType + + bool BasicExpressionType::EqualsImpl(ExpressionType * type) + { + auto basicType = dynamic_cast<const BasicExpressionType*>(type); + if (basicType == nullptr) + return false; + return basicType->BaseType == BaseType; + } + + ExpressionType* BasicExpressionType::CreateCanonicalType() + { + // A basic type is already canonical, in our setup + return this; + } + + CoreLib::Basic::String BasicExpressionType::ToString() + { + CoreLib::Basic::StringBuilder res; + + switch (BaseType) + { + case Compiler::BaseType::Int: + res.Append("int"); + break; + case Compiler::BaseType::UInt: + res.Append("uint"); + break; + case Compiler::BaseType::UInt64: + res.Append("uint64_t"); + break; + case Compiler::BaseType::Bool: + res.Append("bool"); + break; + case Compiler::BaseType::Float: + res.Append("float"); + break; + case Compiler::BaseType::Void: + res.Append("void"); + break; + default: + break; + } + return res.ProduceString(); + } + + RefPtr<SyntaxNode> ProgramSyntaxNode::Accept(SyntaxVisitor * visitor) + { + return visitor->VisitProgram(this); + } + + RefPtr<SyntaxNode> FunctionSyntaxNode::Accept(SyntaxVisitor * visitor) + { + return visitor->VisitFunction(this); + } + + // + + RefPtr<SyntaxNode> ScopeDecl::Accept(SyntaxVisitor * visitor) + { + return visitor->VisitScopeDecl(this); + } + + // + + RefPtr<SyntaxNode> BlockStatementSyntaxNode::Accept(SyntaxVisitor * visitor) + { + return visitor->VisitBlockStatement(this); + } + + RefPtr<SyntaxNode> BreakStatementSyntaxNode::Accept(SyntaxVisitor * visitor) + { + return visitor->VisitBreakStatement(this); + } + + RefPtr<SyntaxNode> ContinueStatementSyntaxNode::Accept(SyntaxVisitor * visitor) + { + return visitor->VisitContinueStatement(this); + } + + RefPtr<SyntaxNode> DoWhileStatementSyntaxNode::Accept(SyntaxVisitor * visitor) + { + return visitor->VisitDoWhileStatement(this); + } + + RefPtr<SyntaxNode> EmptyStatementSyntaxNode::Accept(SyntaxVisitor * visitor) + { + return visitor->VisitEmptyStatement(this); + } + + RefPtr<SyntaxNode> ForStatementSyntaxNode::Accept(SyntaxVisitor * visitor) + { + return visitor->VisitForStatement(this); + } + + RefPtr<SyntaxNode> IfStatementSyntaxNode::Accept(SyntaxVisitor * visitor) + { + return visitor->VisitIfStatement(this); + } + + RefPtr<SyntaxNode> ReturnStatementSyntaxNode::Accept(SyntaxVisitor * visitor) + { + return visitor->VisitReturnStatement(this); + } + + RefPtr<SyntaxNode> VarDeclrStatementSyntaxNode::Accept(SyntaxVisitor * visitor) + { + return visitor->VisitVarDeclrStatement(this); + } + + RefPtr<SyntaxNode> Variable::Accept(SyntaxVisitor * visitor) + { + return visitor->VisitDeclrVariable(this); + } + + RefPtr<SyntaxNode> WhileStatementSyntaxNode::Accept(SyntaxVisitor * visitor) + { + return visitor->VisitWhileStatement(this); + } + + RefPtr<SyntaxNode> ExpressionStatementSyntaxNode::Accept(SyntaxVisitor * visitor) + { + return visitor->VisitExpressionStatement(this); + } + + RefPtr<SyntaxNode> ConstantExpressionSyntaxNode::Accept(SyntaxVisitor * visitor) + { + return visitor->VisitConstantExpression(this); + } + + RefPtr<SyntaxNode> IndexExpressionSyntaxNode::Accept(SyntaxVisitor * visitor) + { + return visitor->VisitIndexExpression(this); + } + RefPtr<SyntaxNode> MemberExpressionSyntaxNode::Accept(SyntaxVisitor * visitor) + { + return visitor->VisitMemberExpression(this); + } + + // SwizzleExpr + + RefPtr<SyntaxNode> SwizzleExpr::Accept(SyntaxVisitor * visitor) + { + return visitor->VisitSwizzleExpression(this); + } + + // DerefExpr + + RefPtr<SyntaxNode> DerefExpr::Accept(SyntaxVisitor * /*visitor*/) + { + // throw "unimplemented"; + return this; + } + + // + + RefPtr<SyntaxNode> InvokeExpressionSyntaxNode::Accept(SyntaxVisitor * visitor) + { + return visitor->VisitInvokeExpression(this); + } + + RefPtr<SyntaxNode> TypeCastExpressionSyntaxNode::Accept(SyntaxVisitor * visitor) + { + return visitor->VisitTypeCastExpression(this); + } + + RefPtr<SyntaxNode> VarExpressionSyntaxNode::Accept(SyntaxVisitor * visitor) + { + return visitor->VisitVarExpression(this); + } + + // OverloadedExpr + + RefPtr<SyntaxNode> OverloadedExpr::Accept(SyntaxVisitor * /*visitor*/) + { +// throw "unimplemented"; + return this; + } + + // + + RefPtr<SyntaxNode> ParameterSyntaxNode::Accept(SyntaxVisitor * visitor) + { + return visitor->VisitParameter(this); + } + + // UsingFileDecl + + RefPtr<SyntaxNode> UsingFileDecl::Accept(SyntaxVisitor * visitor) + { + return visitor->VisitUsingFileDecl(this); + } + + // + + RefPtr<SyntaxNode> StructField::Accept(SyntaxVisitor * visitor) + { + return visitor->VisitStructField(this); + } + RefPtr<SyntaxNode> StructSyntaxNode::Accept(SyntaxVisitor * visitor) + { + return visitor->VisitStruct(this); + } + RefPtr<SyntaxNode> ClassSyntaxNode::Accept(SyntaxVisitor * visitor) + { + return visitor->VisitClass(this); + } + RefPtr<SyntaxNode> TypeDefDecl::Accept(SyntaxVisitor * visitor) + { + return visitor->VisitTypeDefDecl(this); + } + + RefPtr<SyntaxNode> DiscardStatementSyntaxNode::Accept(SyntaxVisitor * visitor) + { + return visitor->VisitDiscardStatement(this); + } + + // BasicExpressionType + + BasicExpressionType* BasicExpressionType::GetScalarType() + { + return this; + } + + // + + bool ExpressionType::Equals(ExpressionType * type) + { + return GetCanonicalType()->EqualsImpl(type->GetCanonicalType()); + } + + bool ExpressionType::Equals(RefPtr<ExpressionType> type) + { + return Equals(type.Ptr()); + } + + bool ExpressionType::EqualsVal(Val* val) + { + if (auto type = dynamic_cast<ExpressionType*>(val)) + return const_cast<ExpressionType*>(this)->Equals(type); + return false; + } + + NamedExpressionType* ExpressionType::AsNamedType() + { + return dynamic_cast<NamedExpressionType*>(this); + } + + RefPtr<Val> ExpressionType::SubstituteImpl(Substitutions* subst, int* ioDiff) + { + int diff = 0; + auto canSubst = GetCanonicalType()->SubstituteImpl(subst, &diff); + + // If nothing changed, then don't drop any sugar that is applied + if (!diff) + return this; + + // If the canonical type changed, then we return a canonical type, + // rather than try to re-construct any amount of sugar + (*ioDiff)++; + return canSubst; + } + + + ExpressionType* ExpressionType::GetCanonicalType() + { + if (!this) return nullptr; + ExpressionType* et = const_cast<ExpressionType*>(this); + if (!et->canonicalType) + { + // TODO(tfoley): worry about thread safety here? + et->canonicalType = et->CreateCanonicalType(); + assert(et->canonicalType); + } + return et->canonicalType; + } + + bool ExpressionType::IsTextureOrSampler() + { + return IsTexture() || IsSampler(); + } + bool ExpressionType::IsStruct() + { + auto declRefType = AsDeclRefType(); + if (!declRefType) return false; + auto structDeclRef = declRefType->declRef.As<StructDeclRef>(); + if (!structDeclRef) return false; + return true; + } + + bool ExpressionType::IsClass() + { + auto declRefType = AsDeclRefType(); + if (!declRefType) return false; + auto classDeclRef = declRefType->declRef.As<ClassDeclRef>(); + if (!classDeclRef) return false; + return true; + } + +#if 0 + RefPtr<ExpressionType> ExpressionType::Bool; + RefPtr<ExpressionType> ExpressionType::UInt; + RefPtr<ExpressionType> ExpressionType::Int; + RefPtr<ExpressionType> ExpressionType::Float; + RefPtr<ExpressionType> ExpressionType::Float2; + RefPtr<ExpressionType> ExpressionType::Void; +#endif + RefPtr<ExpressionType> ExpressionType::Error; + RefPtr<ExpressionType> ExpressionType::initializerListType; + RefPtr<ExpressionType> ExpressionType::Overloaded; + + Dictionary<int, RefPtr<ExpressionType>> ExpressionType::sBuiltinTypes; + Dictionary<String, Decl*> ExpressionType::sMagicDecls; + List<RefPtr<ExpressionType>> ExpressionType::sCanonicalTypes; + + void ExpressionType::Init() + { + Error = new ErrorType(); + initializerListType = new InitializerListType(); + Overloaded = new OverloadGroupType(); + } + void ExpressionType::Finalize() + { + Error = nullptr; + initializerListType = nullptr; + Overloaded = nullptr; + // Note(tfoley): This seems to be just about the only way to clear out a List<T> + sCanonicalTypes = List<RefPtr<ExpressionType>>(); + sBuiltinTypes = Dictionary<int, RefPtr<ExpressionType>>(); + sMagicDecls = Dictionary<String, Decl*>(); + } + bool ArrayExpressionType::EqualsImpl(ExpressionType * type) + { + auto arrType = type->AsArrayType(); + if (!arrType) + return false; + return (ArrayLength == arrType->ArrayLength && BaseType->Equals(arrType->BaseType.Ptr())); + } + ExpressionType* ArrayExpressionType::CreateCanonicalType() + { + auto canonicalBaseType = BaseType->GetCanonicalType(); + auto canonicalArrayType = new ArrayExpressionType(); + sCanonicalTypes.Add(canonicalArrayType); + canonicalArrayType->BaseType = canonicalBaseType; + canonicalArrayType->ArrayLength = ArrayLength; + return canonicalArrayType; + } + int ArrayExpressionType::GetHashCode() + { + if (ArrayLength) + return (BaseType->GetHashCode() * 16777619) ^ ArrayLength->GetHashCode(); + else + return BaseType->GetHashCode(); + } + CoreLib::Basic::String ArrayExpressionType::ToString() + { + if (ArrayLength) + return BaseType->ToString() + "[" + ArrayLength->ToString() + "]"; + else + return BaseType->ToString() + "[]"; + } + RefPtr<SyntaxNode> GenericAppExpr::Accept(SyntaxVisitor * visitor) + { + return visitor->VisitGenericApp(this); + } + + // DeclRefType + + String DeclRefType::ToString() + { + return declRef.GetName(); + } + + int DeclRefType::GetHashCode() + { + return (declRef.GetHashCode() * 16777619) ^ (int)(typeid(this).hash_code()); + } + + bool DeclRefType::EqualsImpl(ExpressionType * type) + { + if (auto declRefType = type->AsDeclRefType()) + { + return declRef.Equals(declRefType->declRef); + } + return false; + } + + ExpressionType* DeclRefType::CreateCanonicalType() + { + // A declaration reference is already canonical + return this; + } + + RefPtr<Val> DeclRefType::SubstituteImpl(Substitutions* subst, int* ioDiff) + { + if (!subst) return this; + + // the case we especially care about is when this type references a declaration + // of a generic parameter, since that is what we might be substituting... + if (auto genericTypeParamDecl = dynamic_cast<GenericTypeParamDecl*>(declRef.GetDecl())) + { + // search for a substitution that might apply to us + for (auto s = subst; s; s = s->outer.Ptr()) + { + // the generic decl associated with the substitution list must be + // the generic decl that declared this parameter + auto genericDecl = s->genericDecl; + if (genericDecl != genericTypeParamDecl->ParentDecl) + continue; + + int index = 0; + for (auto m : genericDecl->Members) + { + if (m.Ptr() == genericTypeParamDecl) + { + // We've found it, so return the corresponding specialization argument + (*ioDiff)++; + return s->args[index]; + } + else if(auto typeParam = m.As<GenericTypeParamDecl>()) + { + index++; + } + else if(auto valParam = m.As<GenericValueParamDecl>()) + { + index++; + } + else + { + } + } + + } + } + + + int diff = 0; + DeclRef substDeclRef = declRef.SubstituteImpl(subst, &diff); + + if (!diff) + return this; + + // Make sure to record the difference! + *ioDiff += diff; + + // Re-construct the type in case we are using a specialized sub-class + return DeclRefType::Create(substDeclRef); + } + + static RefPtr<ExpressionType> ExtractGenericArgType(RefPtr<Val> val) + { + auto type = val.As<ExpressionType>(); + assert(type.Ptr()); + return type; + } + + static RefPtr<IntVal> ExtractGenericArgInteger(RefPtr<Val> val) + { + auto intVal = val.As<IntVal>(); + assert(intVal.Ptr()); + return intVal; + } + + // TODO: need to figure out how to unify this with the logic + // in the generic case... + DeclRefType* DeclRefType::Create(DeclRef declRef) + { + if (auto builtinMod = declRef.GetDecl()->FindModifier<BuiltinTypeModifier>()) + { + auto type = new BasicExpressionType(builtinMod->tag); + type->declRef = declRef; + return type; + } + else if (auto magicMod = declRef.GetDecl()->FindModifier<MagicTypeModifier>()) + { + Substitutions* subst = declRef.substitutions.Ptr(); + + if (magicMod->name == "SamplerState") + { + auto type = new SamplerStateType(); + type->declRef = declRef; + type->flavor = SamplerStateType::Flavor(magicMod->tag); + return type; + } + else if (magicMod->name == "Vector") + { + assert(subst && subst->args.Count() == 2); + auto vecType = new VectorExpressionType(); + vecType->declRef = declRef; + vecType->elementType = ExtractGenericArgType(subst->args[0]); + vecType->elementCount = ExtractGenericArgInteger(subst->args[1]); + return vecType; + } + else if (magicMod->name == "Matrix") + { + assert(subst && subst->args.Count() == 3); + auto matType = new MatrixExpressionType(); + matType->declRef = declRef; + return matType; + } + else if (magicMod->name == "Texture") + { + assert(subst && subst->args.Count() >= 1); + auto textureType = new TextureType( + TextureType::Flavor(magicMod->tag), + ExtractGenericArgType(subst->args[0])); + textureType->declRef = declRef; + return textureType; + } + else if (magicMod->name == "TextureSampler") + { + assert(subst && subst->args.Count() >= 1); + auto textureType = new TextureSamplerType( + TextureType::Flavor(magicMod->tag), + ExtractGenericArgType(subst->args[0])); + textureType->declRef = declRef; + return textureType; + } + else if (magicMod->name == "GLSLImageType") + { + assert(subst && subst->args.Count() >= 1); + auto textureType = new GLSLImageType( + TextureType::Flavor(magicMod->tag), + ExtractGenericArgType(subst->args[0])); + textureType->declRef = declRef; + return textureType; + } + + #define CASE(n,T) \ + else if(magicMod->name == #n) { \ + assert(subst && subst->args.Count() == 1); \ + auto type = new T(); \ + type->elementType = ExtractGenericArgType(subst->args[0]); \ + type->declRef = declRef; \ + return type; \ + } + + CASE(ConstantBuffer, ConstantBufferType) + CASE(TextureBuffer, TextureBufferType) + CASE(GLSLInputParameterBlockType, GLSLInputParameterBlockType) + CASE(GLSLOutputParameterBlockType, GLSLOutputParameterBlockType) + CASE(GLSLShaderStorageBufferType, GLSLShaderStorageBufferType) + + CASE(PackedBuffer, PackedBufferType) + CASE(Uniform, UniformBufferType) + CASE(Patch, PatchType) + + CASE(HLSLBufferType, HLSLBufferType) + CASE(HLSLStructuredBufferType, HLSLStructuredBufferType) + CASE(HLSLRWBufferType, HLSLRWBufferType) + CASE(HLSLRWStructuredBufferType, HLSLRWStructuredBufferType) + CASE(HLSLAppendStructuredBufferType, HLSLAppendStructuredBufferType) + CASE(HLSLConsumeStructuredBufferType, HLSLConsumeStructuredBufferType) + CASE(HLSLInputPatchType, HLSLInputPatchType) + CASE(HLSLOutputPatchType, HLSLOutputPatchType) + + CASE(HLSLPointStreamType, HLSLPointStreamType) + CASE(HLSLLineStreamType, HLSLPointStreamType) + CASE(HLSLTriangleStreamType, HLSLPointStreamType) + + #undef CASE + + // "magic" builtin types which have no generic parameters + #define CASE(n,T) \ + else if(magicMod->name == #n) { \ + auto type = new T(); \ + type->declRef = declRef; \ + return type; \ + } + + CASE(HLSLByteAddressBufferType, HLSLByteAddressBufferType) + CASE(HLSLRWByteAddressBufferType, HLSLRWByteAddressBufferType) + CASE(UntypedBufferResourceType, UntypedBufferResourceType) + + CASE(GLSLInputAttachmentType, GLSLInputAttachmentType) + + #undef CASE + + else + { + throw "unimplemented"; + } + } + else + { + return new DeclRefType(declRef); + } + } + + // OverloadGroupType + + String OverloadGroupType::ToString() + { + return "overload group"; + } + + bool OverloadGroupType::EqualsImpl(ExpressionType * /*type*/) + { + return false; + } + + ExpressionType* OverloadGroupType::CreateCanonicalType() + { + return this; + } + + int OverloadGroupType::GetHashCode() + { + return (int)(int64_t)(void*)this; + } + + // InitializerListType + + String InitializerListType::ToString() + { + return "initializer list"; + } + + bool InitializerListType::EqualsImpl(ExpressionType * /*type*/) + { + return false; + } + + ExpressionType* InitializerListType::CreateCanonicalType() + { + return this; + } + + int InitializerListType::GetHashCode() + { + return (int)(int64_t)(void*)this; + } + + // ErrorType + + String ErrorType::ToString() + { + return "error"; + } + + bool ErrorType::EqualsImpl(ExpressionType* type) + { + if (auto errorType = type->As<ErrorType>()) + return true; + return false; + } + + ExpressionType* ErrorType::CreateCanonicalType() + { + return this; + } + + int ErrorType::GetHashCode() + { + return (int)(int64_t)(void*)this; + } + + + // NamedExpressionType + + String NamedExpressionType::ToString() + { + return declRef.GetName(); + } + + bool NamedExpressionType::EqualsImpl(ExpressionType * /*type*/) + { + assert(!"unreachable"); + return false; + } + + ExpressionType* NamedExpressionType::CreateCanonicalType() + { + return declRef.GetType()->GetCanonicalType(); + } + + int NamedExpressionType::GetHashCode() + { + assert(!"unreachable"); + return 0; + } + + // FuncType + + String FuncType::ToString() + { + // TODO: a better approach than this + if (declRef) + return declRef.GetName(); + else + return "/* unknown FuncType */"; + } + + bool FuncType::EqualsImpl(ExpressionType * type) + { + if (auto funcType = type->As<FuncType>()) + { + return declRef == funcType->declRef; + } + return false; + } + + ExpressionType* FuncType::CreateCanonicalType() + { + return this; + } + + int FuncType::GetHashCode() + { + return declRef.GetHashCode(); + } + + // TypeType + + String TypeType::ToString() + { + StringBuilder sb; + sb << "typeof(" << type->ToString() << ")"; + return sb.ProduceString(); + } + + bool TypeType::EqualsImpl(ExpressionType * t) + { + if (auto typeType = t->As<TypeType>()) + { + return t->Equals(typeType->type); + } + return false; + } + + ExpressionType* TypeType::CreateCanonicalType() + { + auto canType = new TypeType(type->GetCanonicalType()); + sCanonicalTypes.Add(canType); + return canType; + } + + int TypeType::GetHashCode() + { + assert(!"unreachable"); + return 0; + } + + // GenericDeclRefType + + String GenericDeclRefType::ToString() + { + // TODO: what is appropriate here? + return "<GenericDeclRef>"; + } + + bool GenericDeclRefType::EqualsImpl(ExpressionType * type) + { + if (auto genericDeclRefType = type->As<GenericDeclRefType>()) + { + return declRef.Equals(genericDeclRefType->declRef); + } + return false; + } + + int GenericDeclRefType::GetHashCode() + { + return declRef.GetHashCode(); + } + + ExpressionType* GenericDeclRefType::CreateCanonicalType() + { + return this; + } + + // ArithmeticExpressionType + + // VectorExpressionType + + String VectorExpressionType::ToString() + { + StringBuilder sb; + sb << "vector<" << elementType->ToString() << "," << elementCount->ToString() << ">"; + return sb.ProduceString(); + } + + BasicExpressionType* VectorExpressionType::GetScalarType() + { + return elementType->AsBasicType(); + } + + // MatrixExpressionType + + String MatrixExpressionType::ToString() + { + StringBuilder sb; + sb << "matrix<" << getElementType()->ToString() << "," << getRowCount()->ToString() << "," << getColumnCount()->ToString() << ">"; + return sb.ProduceString(); + } + + BasicExpressionType* MatrixExpressionType::GetScalarType() + { + return getElementType()->AsBasicType(); + } + + ExpressionType* MatrixExpressionType::getElementType() + { + return this->declRef.substitutions->args[0].As<ExpressionType>().Ptr(); + } + + IntVal* MatrixExpressionType::getRowCount() + { + return this->declRef.substitutions->args[1].As<IntVal>().Ptr(); + } + + IntVal* MatrixExpressionType::getColumnCount() + { + return this->declRef.substitutions->args[2].As<IntVal>().Ptr(); + } + + // + +#if 0 + String GetOperatorFunctionName(Operator op) + { + switch (op) + { + case Operator::Add: + case Operator::AddAssign: + return "+"; + case Operator::Sub: + case Operator::SubAssign: + return "-"; + case Operator::Neg: + return "-"; + case Operator::Not: + return "!"; + case Operator::BitNot: + return "~"; + case Operator::PreInc: + case Operator::PostInc: + return "++"; + case Operator::PreDec: + case Operator::PostDec: + return "--"; + case Operator::Mul: + case Operator::MulAssign: + return "*"; + case Operator::Div: + case Operator::DivAssign: + return "/"; + case Operator::Mod: + case Operator::ModAssign: + return "%"; + case Operator::Lsh: + case Operator::LshAssign: + return "<<"; + case Operator::Rsh: + case Operator::RshAssign: + return ">>"; + case Operator::Eql: + return "=="; + case Operator::Neq: + return "!="; + case Operator::Greater: + return ">"; + case Operator::Less: + return "<"; + case Operator::Geq: + return ">="; + case Operator::Leq: + return "<="; + case Operator::BitAnd: + case Operator::AndAssign: + return "&"; + case Operator::BitXor: + case Operator::XorAssign: + return "^"; + case Operator::BitOr: + case Operator::OrAssign: + return "|"; + case Operator::And: + return "&&"; + case Operator::Or: + return "||"; + case Operator::Sequence: + return ","; + case Operator::Select: + return "?:"; + case Operator::Assign: + return "="; + default: + return ""; + } + } +#endif + String OperatorToString(Operator op) + { + switch (op) + { + case Slang::Compiler::Operator::Neg: + return "-"; + case Slang::Compiler::Operator::Not: + return "!"; + case Slang::Compiler::Operator::PreInc: + return "++"; + case Slang::Compiler::Operator::PreDec: + return "--"; + case Slang::Compiler::Operator::PostInc: + return "++"; + case Slang::Compiler::Operator::PostDec: + return "--"; + case Slang::Compiler::Operator::Mul: + case Slang::Compiler::Operator::MulAssign: + return "*"; + case Slang::Compiler::Operator::Div: + case Slang::Compiler::Operator::DivAssign: + return "/"; + case Slang::Compiler::Operator::Mod: + case Slang::Compiler::Operator::ModAssign: + return "%"; + case Slang::Compiler::Operator::Add: + case Slang::Compiler::Operator::AddAssign: + return "+"; + case Slang::Compiler::Operator::Sub: + case Slang::Compiler::Operator::SubAssign: + return "-"; + case Slang::Compiler::Operator::Lsh: + case Slang::Compiler::Operator::LshAssign: + return "<<"; + case Slang::Compiler::Operator::Rsh: + case Slang::Compiler::Operator::RshAssign: + return ">>"; + case Slang::Compiler::Operator::Eql: + return "=="; + case Slang::Compiler::Operator::Neq: + return "!="; + case Slang::Compiler::Operator::Greater: + return ">"; + case Slang::Compiler::Operator::Less: + return "<"; + case Slang::Compiler::Operator::Geq: + return ">="; + case Slang::Compiler::Operator::Leq: + return "<="; + case Slang::Compiler::Operator::BitAnd: + case Slang::Compiler::Operator::AndAssign: + return "&"; + case Slang::Compiler::Operator::BitXor: + case Slang::Compiler::Operator::XorAssign: + return "^"; + case Slang::Compiler::Operator::BitOr: + case Slang::Compiler::Operator::OrAssign: + return "|"; + case Slang::Compiler::Operator::And: + return "&&"; + case Slang::Compiler::Operator::Or: + return "||"; + case Slang::Compiler::Operator::Assign: + return "="; + default: + return "ERROR"; + } + } + + // TypeExp + + TypeExp TypeExp::Accept(SyntaxVisitor* visitor) + { + return visitor->VisitTypeExp(*this); + } + + // BuiltinTypeModifier + + // MagicTypeModifier + + // GenericDecl + + RefPtr<SyntaxNode> GenericDecl::Accept(SyntaxVisitor * visitor) + { + return visitor->VisitGenericDecl(this); + } + + // GenericTypeParamDecl + + RefPtr<SyntaxNode> GenericTypeParamDecl::Accept(SyntaxVisitor * /*visitor*/) { + //throw "unimplemented"; + return this; + } + + // GenericTypeConstraintDecl + + RefPtr<SyntaxNode> GenericTypeConstraintDecl::Accept(SyntaxVisitor * visitor) + { + return this; + } + + // GenericValueParamDecl + + RefPtr<SyntaxNode> GenericValueParamDecl::Accept(SyntaxVisitor * /*visitor*/) { + //throw "unimplemented"; + return this; + } + + // GenericParamIntVal + + bool GenericParamIntVal::EqualsVal(Val* val) + { + if (auto genericParamVal = dynamic_cast<GenericParamIntVal*>(val)) + { + return declRef.Equals(genericParamVal->declRef); + } + return false; + } + + String GenericParamIntVal::ToString() + { + return declRef.GetName(); + } + + int GenericParamIntVal::GetHashCode() + { + return declRef.GetHashCode() ^ 0xFFFF; + } + + RefPtr<Val> GenericParamIntVal::SubstituteImpl(Substitutions* subst, int* ioDiff) + { + // search for a substitution that might apply to us + for (auto s = subst; s; s = s->outer.Ptr()) + { + // the generic decl associated with the substitution list must be + // the generic decl that declared this parameter + auto genericDecl = s->genericDecl; + if (genericDecl != declRef.GetDecl()->ParentDecl) + continue; + + int index = 0; + for (auto m : genericDecl->Members) + { + if (m.Ptr() == declRef.GetDecl()) + { + // We've found it, so return the corresponding specialization argument + (*ioDiff)++; + return s->args[index]; + } + else if(auto typeParam = m.As<GenericTypeParamDecl>()) + { + index++; + } + else if(auto valParam = m.As<GenericValueParamDecl>()) + { + index++; + } + else + { + } + } + } + + // Nothing found: don't substittue. + return this; + } + + // ExtensionDecl + + RefPtr<SyntaxNode> ExtensionDecl::Accept(SyntaxVisitor * visitor) + { + visitor->VisitExtensionDecl(this); + return this; + } + + // ConstructorDecl + + RefPtr<SyntaxNode> ConstructorDecl::Accept(SyntaxVisitor * visitor) + { + visitor->VisitConstructorDecl(this); + return this; + } + + // SubscriptDecl + + RefPtr<SyntaxNode> SubscriptDecl::Accept(SyntaxVisitor * visitor) + { + visitor->visitSubscriptDecl(this); + return this; + } + + // AccessorDecl + + RefPtr<SyntaxNode> AccessorDecl::Accept(SyntaxVisitor * visitor) + { + visitor->visitAccessorDecl(this); + return this; + } + + // Substitutions + + RefPtr<Substitutions> Substitutions::SubstituteImpl(Substitutions* subst, int* ioDiff) + { + if (!this) return nullptr; + + int diff = 0; + auto outerSubst = outer->SubstituteImpl(subst, &diff); + + List<RefPtr<Val>> substArgs; + for (auto a : args) + { + substArgs.Add(a->SubstituteImpl(subst, &diff)); + } + + if (!diff) return this; + + (*ioDiff)++; + auto substSubst = new Substitutions(); + substSubst->genericDecl = genericDecl; + substSubst->args = substArgs; + return substSubst; + } + + bool Substitutions::Equals(Substitutions* subst) + { + // both must be NULL, or non-NULL + if (!this || !subst) + return !this && !subst; + + if (genericDecl != subst->genericDecl) + return false; + + int argCount = args.Count(); + assert(args.Count() == subst->args.Count()); + for (int aa = 0; aa < argCount; ++aa) + { + if (!args[aa]->EqualsVal(subst->args[aa].Ptr())) + return false; + } + + if (!outer->Equals(subst->outer.Ptr())) + return false; + + return true; + } + + + // DeclRef + + RefPtr<ExpressionType> DeclRef::Substitute(RefPtr<ExpressionType> type) const + { + // No substitutions? Easy. + if (!substitutions) + return type; + + // Otherwise we need to recurse on the type structure + // and apply substitutions where it makes sense + + return type->Substitute(substitutions.Ptr()).As<ExpressionType>(); + } + + DeclRef DeclRef::Substitute(DeclRef declRef) const + { + if(!substitutions) + return declRef; + + int diff = 0; + return declRef.SubstituteImpl(substitutions.Ptr(), &diff); + } + + RefPtr<ExpressionSyntaxNode> DeclRef::Substitute(RefPtr<ExpressionSyntaxNode> expr) const + { + // No substitutions? Easy. + if (!substitutions) + return expr; + + assert(!"unimplemented"); + + return expr; + } + + + DeclRef DeclRef::SubstituteImpl(Substitutions* subst, int* ioDiff) + { + if (!substitutions) return *this; + + int diff = 0; + RefPtr<Substitutions> substSubst = substitutions->SubstituteImpl(subst, &diff); + + if (!diff) + return *this; + + *ioDiff += diff; + + DeclRef substDeclRef; + substDeclRef.decl = decl; + substDeclRef.substitutions = substSubst; + return substDeclRef; + } + + + // Check if this is an equivalent declaration reference to another + bool DeclRef::Equals(DeclRef const& declRef) const + { + if (decl != declRef.decl) + return false; + + if (!substitutions->Equals(declRef.substitutions.Ptr())) + return false; + + return true; + } + + // Convenience accessors for common properties of declarations + String const& DeclRef::GetName() const + { + return decl->Name.Content; + } + + DeclRef DeclRef::GetParent() const + { + auto parentDecl = decl->ParentDecl; + if (auto parentGeneric = dynamic_cast<GenericDecl*>(parentDecl)) + { + // We need to strip away one layer of specialization + assert(substitutions); + return DeclRef(parentGeneric, substitutions->outer); + } + else + { + // If the parent isn't a generic, then it must + // use the same specializations as this declaration + return DeclRef(parentDecl, substitutions); + } + + } + + int DeclRef::GetHashCode() const + { + auto rs = PointerHash<1>::GetHashCode(decl); + if (substitutions) + { + rs *= 16777619; + rs ^= substitutions->GetHashCode(); + } + return rs; + } + + // Val + + RefPtr<Val> Val::Substitute(Substitutions* subst) + { + if (!this) return nullptr; + if (!subst) return this; + int diff = 0; + return SubstituteImpl(subst, &diff); + } + + RefPtr<Val> Val::SubstituteImpl(Substitutions* /*subst*/, int* /*ioDiff*/) + { + // Default behavior is to not substitute at all + return this; + } + + // IntVal + + int GetIntVal(RefPtr<IntVal> val) + { + if (auto constantVal = val.As<ConstantIntVal>()) + { + return constantVal->value; + } + assert(!"unexpected"); + return 0; + } + + // ConstantIntVal + + bool ConstantIntVal::EqualsVal(Val* val) + { + if (auto intVal = dynamic_cast<ConstantIntVal*>(val)) + return value == intVal->value; + return false; + } + + String ConstantIntVal::ToString() + { + return String(value); + } + + int ConstantIntVal::GetHashCode() + { + return value; + } + + // SwitchStmt + + RefPtr<SyntaxNode> SwitchStmt::Accept(SyntaxVisitor * visitor) + { + return visitor->VisitSwitchStmt(this); + } + + RefPtr<SyntaxNode> CaseStmt::Accept(SyntaxVisitor * visitor) + { + return visitor->VisitCaseStmt(this); + } + + RefPtr<SyntaxNode> DefaultStmt::Accept(SyntaxVisitor * visitor) + { + return visitor->VisitDefaultStmt(this); + } + + // TraitDecl + + RefPtr<SyntaxNode> TraitDecl::Accept(SyntaxVisitor * visitor) + { + visitor->VisitTraitDecl(this); + return this; + } + + // TraitConformanceDecl + + RefPtr<SyntaxNode> TraitConformanceDecl::Accept(SyntaxVisitor * visitor) + { + visitor->VisitTraitConformanceDecl(this); + return this; + } + + // SharedTypeExpr + + RefPtr<SyntaxNode> SharedTypeExpr::Accept(SyntaxVisitor * visitor) + { + return visitor->VisitSharedTypeExpr(this); + } + + // OperatorExpressionSyntaxNode + +#if 0 + void OperatorExpressionSyntaxNode::SetOperator(RefPtr<Scope> scope, Slang::Compiler::Operator op) + { + this->Operator = op; + auto opExpr = new VarExpressionSyntaxNode(); + opExpr->Variable = GetOperatorFunctionName(Operator); + opExpr->scope = scope; + opExpr->Position = this->Position; + this->FunctionExpr = opExpr; + } +#endif + + RefPtr<SyntaxNode> OperatorExpressionSyntaxNode::Accept(SyntaxVisitor * visitor) + { + return visitor->VisitOperatorExpression(this); + } + + // DeclGroup + + RefPtr<SyntaxNode> DeclGroup::Accept(SyntaxVisitor * visitor) + { + visitor->VisitDeclGroup(this); + return this; + } + + // + + void RegisterBuiltinDecl( + RefPtr<Decl> decl, + RefPtr<BuiltinTypeModifier> modifier) + { + auto type = DeclRefType::Create(DeclRef(decl.Ptr(), nullptr)); + ExpressionType::sBuiltinTypes[(int)modifier->tag] = type; + } + + void RegisterMagicDecl( + RefPtr<Decl> decl, + RefPtr<MagicTypeModifier> modifier) + { + ExpressionType::sMagicDecls[modifier->name] = decl.Ptr(); + } + + RefPtr<Decl> findMagicDecl( + String const& name) + { + return ExpressionType::sMagicDecls[name].GetValue(); + } + + ExpressionType* ExpressionType::GetBool() + { + return sBuiltinTypes[(int)BaseType::Bool].GetValue().Ptr(); + } + + ExpressionType* ExpressionType::GetFloat() + { + return sBuiltinTypes[(int)BaseType::Float].GetValue().Ptr(); + } + + ExpressionType* ExpressionType::GetInt() + { + return sBuiltinTypes[(int)BaseType::Int].GetValue().Ptr(); + } + + ExpressionType* ExpressionType::GetUInt() + { + return sBuiltinTypes[(int)BaseType::UInt].GetValue().Ptr(); + } + + ExpressionType* ExpressionType::GetVoid() + { + return sBuiltinTypes[(int)BaseType::Void].GetValue().Ptr(); + } + + ExpressionType* ExpressionType::getInitializerListType() + { + return initializerListType.Ptr(); + } + + ExpressionType* ExpressionType::GetError() + { + return ExpressionType::Error.Ptr(); + } + + // + + RefPtr<SyntaxNode> UnparsedStmt::Accept(SyntaxVisitor * visitor) + { + return this; + } + + // + + RefPtr<SyntaxNode> InitializerListExpr::Accept(SyntaxVisitor * visitor) + { + return visitor->visitInitializerListExpr(this); + } + + // + + RefPtr<SyntaxNode> ModifierDecl::Accept(SyntaxVisitor * visitor) + { + return this; + } + + // + + RefPtr<SyntaxNode> EmptyDecl::Accept(SyntaxVisitor * visitor) + { + return this; + } + + // + + SyntaxNodeBase* createInstanceOfSyntaxClassByName( + String const& name) + { + if(0) {} + #define CASE(NAME) \ + else if(name == #NAME) return new NAME() + + CASE(GLSLBufferModifier); + CASE(GLSLWriteOnlyModifier); + CASE(GLSLReadOnlyModifier); + CASE(GLSLPatchModifier); + CASE(SimpleModifier); + + #undef CASE + else + { + assert(!"unexpected"); + return nullptr; + } + } + + IntrinsicOp findIntrinsicOp(char const* name) + { + // TODO: need to make this faster by using a dictionary... + + if (0) {} +#define INTRINSIC(NAME) else if(strcmp(name, #NAME) == 0) return IntrinsicOp::NAME; +#include "intrinsic-defs.h" + + return IntrinsicOp::Unknown; + } + + } +}
\ No newline at end of file diff --git a/source/slang/syntax.h b/source/slang/syntax.h new file mode 100644 index 000000000..9e4486d4e --- /dev/null +++ b/source/slang/syntax.h @@ -0,0 +1,2771 @@ +#ifndef RASTER_RENDERER_SYNTAX_H +#define RASTER_RENDERER_SYNTAX_H + +#include "../core/basic.h" +#include "Lexer.h" +#include "Profile.h" + +#include "../../Slang.h" + +#include <assert.h> + +namespace Slang +{ + namespace Compiler + { + using namespace CoreLib::Basic; + class SyntaxVisitor; + class FunctionSyntaxNode; + + class SyntaxNodeBase : public RefObject + { + public: + CodePosition Position; + }; + + + + + // + // Other modifiers may have more elaborate data, and so + // are represented as heap-allocated objects, in a linked + // list. + // + class Modifier : public SyntaxNodeBase + { + public: + // Next modifier in linked list of modifiers on same piece of syntax + RefPtr<Modifier> next; + + // The token that was used to name this modifier. + Token nameToken; + }; + +#define SIMPLE_MODIFIER(NAME) \ + class NAME##Modifier : public Modifier {} + + SIMPLE_MODIFIER(Uniform); + SIMPLE_MODIFIER(In); + SIMPLE_MODIFIER(Out); + SIMPLE_MODIFIER(Const); + SIMPLE_MODIFIER(Instance); + SIMPLE_MODIFIER(Builtin); + SIMPLE_MODIFIER(Inline); + SIMPLE_MODIFIER(Public); + SIMPLE_MODIFIER(Require); + SIMPLE_MODIFIER(Param); + SIMPLE_MODIFIER(Extern); + SIMPLE_MODIFIER(Input); + SIMPLE_MODIFIER(Transparent); + SIMPLE_MODIFIER(FromStdLib); + SIMPLE_MODIFIER(Prefix); + SIMPLE_MODIFIER(Postfix); + +#undef SIMPLE_MODIFIER + + enum class IntrinsicOp + { + Unknown = 0, +#define INTRINSIC(NAME) NAME, +#include "intrinsic-defs.h" + }; + + IntrinsicOp findIntrinsicOp(char const* name); + + class IntrinsicModifier : public Modifier + { + public: + // token that names the intrinsic op + Token opToken; + + // The opcode for the intrinsic operation + IntrinsicOp op = IntrinsicOp::Unknown; + }; + + + class InOutModifier : public OutModifier {}; + + // This is a special sentinel modifier that gets added + // to the list when we have multiple variable declarations + // all sharing the same modifiers: + // + // static uniform int a : FOO, *b : register(x0); + // + // In this case both `a` and `b` share the syntax + // for part of their modifier list, but then have + // their own modifiers as well: + // + // a: SemanticModifier("FOO") --> SharedModifiers --> StaticModifier --> UniformModifier + // / + // b: RegisterModifier("x0") / + // + class SharedModifiers : public Modifier {}; + + // A GLSL `layout` modifier + // + // We use a distinct modifier for each key that + // appears within the `layout(...)` construct, + // and each key might have an optional value token. + // + // TODO: We probably want a notion of "modifier groups" + // so that we can recover good source location info + // for modifiers that were part of the same vs. + // different constructs. + class GLSLLayoutModifier : public Modifier + { + public: + // THe token used to introduce the modifier is stored + // as the `nameToken` field. + + // TODO: may want to accept a full expression here + Token valToken; + }; + + // We divide GLSL `layout` modifiers into those we have parsed + // (in the sense of having some notion of their semantics), and + // those we have not. + class GLSLParsedLayoutModifier : public GLSLLayoutModifier {}; + class GLSLUnparsedLayoutModifier : public GLSLLayoutModifier {}; + + // Specific cases for known GLSL `layout` modifiers that we need to work with + class GLSLConstantIDLayoutModifier : public GLSLParsedLayoutModifier {}; + class GLSLBindingLayoutModifier : public GLSLParsedLayoutModifier {}; + class GLSLSetLayoutModifier : public GLSLParsedLayoutModifier {}; + class GLSLLocationLayoutModifier : public GLSLParsedLayoutModifier {}; + + // A catch-all for single-keyword modifiers + class SimpleModifier : public Modifier {}; + + // Some GLSL-specific modifiers + class GLSLBufferModifier : public SimpleModifier {}; + class GLSLWriteOnlyModifier : public SimpleModifier {}; + class GLSLReadOnlyModifier : public SimpleModifier {}; + class GLSLPatchModifier : public SimpleModifier {}; + + // Indicates that this is a variable declaration that corresponds to + // a parameter block declaration in the source program. + class ImplicitParameterBlockVariableModifier : public Modifier {}; + + // Indicates that this is a type that corresponds to the element + // type of a parameter block declaration in the source program. + class ImplicitParameterBlockElementTypeModifier : public Modifier {}; + + // An HLSL semantic + class HLSLSemantic : public Modifier + { + public: + Token name; + }; + + + // An HLSL semantic that affects layout + class HLSLLayoutSemantic : public HLSLSemantic + { + public: + Token registerName; + Token componentMask; + }; + + // An HLSL `register` semantic + class HLSLRegisterSemantic : public HLSLLayoutSemantic + { + }; + + // TODO(tfoley): `packoffset` + class HLSLPackOffsetSemantic : public HLSLLayoutSemantic + { + }; + + // An HLSL semantic that just associated a declaration with a semantic name + class HLSLSimpleSemantic : public HLSLSemantic + { + }; + + // GLSL + + // Directives that came in via the preprocessor, but + // that we need to keep around for later steps + class GLSLPreprocessorDirective : public Modifier + { + }; + + // A GLSL `#version` directive + class GLSLVersionDirective : public GLSLPreprocessorDirective + { + public: + // Token giving the version number to use + Token versionNumberToken; + + // Optional token giving the sub-profile to be used + Token glslProfileToken; + }; + + // A GLSL `#extension` directive + class GLSLExtensionDirective : public GLSLPreprocessorDirective + { + public: + // Token giving the version number to use + Token extensionNameToken; + + // Optional token giving the sub-profile to be used + Token dispositionToken; + }; + + class ParameterBlockReflectionName : public Modifier + { + public: + Token nameToken; + }; + + // Helper class for iterating over a list of heap-allocated modifiers + struct ModifierList + { + struct Iterator + { + Modifier* current; + + Modifier* operator*() + { + return current; + } + + void operator++() + { + current = current->next.Ptr(); + } + + bool operator!=(Iterator other) + { + return current != other.current; + }; + + Iterator() + : current(nullptr) + {} + + Iterator(Modifier* modifier) + : current(modifier) + {} + }; + + ModifierList() + : modifiers(nullptr) + {} + + ModifierList(Modifier* modifiers) + : modifiers(modifiers) + {} + + Iterator begin() { return Iterator(modifiers); } + Iterator end() { return Iterator(nullptr); } + + Modifier* modifiers; + }; + + // Helper class for iterating over heap-allocated modifiers + // of a specific type. + template<typename T> + struct FilteredModifierList + { + struct Iterator + { + Modifier* current; + + T* operator*() + { + return (T*)current; + } + + void operator++() + { + current = Adjust(current->next.Ptr()); + } + + bool operator!=(Iterator other) + { + return current != other.current; + }; + + Iterator() + : current(nullptr) + {} + + Iterator(Modifier* modifier) + : current(modifier) + {} + }; + + FilteredModifierList() + : modifiers(nullptr) + {} + + FilteredModifierList(Modifier* modifiers) + : modifiers(Adjust(modifiers)) + {} + + Iterator begin() { return Iterator(modifiers); } + Iterator end() { return Iterator(nullptr); } + + static Modifier* Adjust(Modifier* modifier) + { + Modifier* m = modifier; + for (;;) + { + if (!m) return m; + if (dynamic_cast<T*>(m)) return m; + m = m->next.Ptr(); + } + } + + Modifier* modifiers; + }; + + // A set of modifiers attached to a syntax node + struct Modifiers + { + // The first modifier in the linked list of heap-allocated modifiers + RefPtr<Modifier> first; + + template<typename T> + FilteredModifierList<T> getModifiersOfType() { return FilteredModifierList<T>(first.Ptr()); } + + // Find the first modifier of a given type, or return `nullptr` if none is found. + template<typename T> + T* findModifier() + { + return *getModifiersOfType<T>().begin(); + } + + template<typename T> + bool hasModifier() { return findModifier<T>() != nullptr; } + + FilteredModifierList<Modifier>::Iterator begin() { return FilteredModifierList<Modifier>::Iterator(first.Ptr()); } + FilteredModifierList<Modifier>::Iterator end() { return FilteredModifierList<Modifier>::Iterator(nullptr); } + }; + + + enum class BaseType + { + // Note(tfoley): These are ordered in terms of promotion rank, so be vareful when messing with this + + Void = 0, + Bool, + Int, + UInt, + UInt64, + Float, +#if 0 + Texture2D = 48, + TextureCube = 49, + Texture2DArray = 50, + Texture2DShadow = 51, + TextureCubeShadow = 52, + Texture2DArrayShadow = 53, + Texture3D = 54, + SamplerState = 4096, SamplerComparisonState = 4097, + Error = 16384, +#endif + }; + + class Decl; + class StructSyntaxNode; + class BasicExpressionType; + class ArrayExpressionType; + class TypeDefDecl; + class DeclRefType; + class NamedExpressionType; + class TypeType; + class GenericDeclRefType; + class VectorExpressionType; + class MatrixExpressionType; + class ArithmeticExpressionType; + class GenericDecl; + class Substitutions; + class TextureType; + class SamplerStateType; + + // A compile-time constant value (usually a type) + class Val : public RefObject + { + public: + // construct a new value by applying a set of parameter + // substitutions to this one + RefPtr<Val> Substitute(Substitutions* subst); + + // Lower-level interface for substition. Like the basic + // `Substitute` above, but also takes a by-reference + // integer parameter that should be incremented when + // returning a modified value (this can help the caller + // decide whether they need to do anything). + virtual RefPtr<Val> SubstituteImpl(Substitutions* subst, int* ioDiff); + + virtual bool EqualsVal(Val* val) = 0; + virtual String ToString() = 0; + virtual int GetHashCode() = 0; + bool operator == (const Val & v) + { + return EqualsVal(const_cast<Val*>(&v)); + } + }; + + // A compile-time integer (may not have a specific concrete value) + class IntVal : public Val + { + }; + + // Try to extract a simple integer value from an `IntVal`. + // This fill assert-fail if the object doesn't represent a literal value. + int GetIntVal(RefPtr<IntVal> val); + + // Trivial case of a value that is just a constant integer + class ConstantIntVal : public IntVal + { + public: + int value; + + ConstantIntVal(int value) + : value(value) + {} + + virtual bool EqualsVal(Val* val) override; + virtual String ToString() override; + virtual int GetHashCode() override; + }; + + // TODO(tfoley): classes for more general compile-time integers, + // including references to template parameters + + // A type, representing a classifier for some term in the AST. + // + // Types can include "sugar" in that they may refer to a + // `typedef` which gives them a good name when printed as + // part of diagnostic messages. + // + // In order to operation on types, though, we often want + // to look past any sugar, and operate on an underlying + // "canonical" type. The reprsentation caches a pointer to + // a canonical type on every type, so we can easily + // operate on the raw representation when needed. + class ExpressionType : public Val + { + public: + static RefPtr<ExpressionType> Error; + static RefPtr<ExpressionType> initializerListType; + static RefPtr<ExpressionType> Overloaded; + + static Dictionary<int, RefPtr<ExpressionType>> sBuiltinTypes; + static Dictionary<String, Decl*> sMagicDecls; + + // Note: just exists to make sure we can clean up + // canonical types we create along the way + static List<RefPtr<ExpressionType>> sCanonicalTypes; + + + + static ExpressionType* GetBool(); + static ExpressionType* GetFloat(); + static ExpressionType* GetInt(); + static ExpressionType* GetUInt(); + static ExpressionType* GetVoid(); + static ExpressionType* getInitializerListType(); + static ExpressionType* GetError(); + + public: + virtual String ToString() = 0; + + bool Equals(ExpressionType * type); + bool Equals(RefPtr<ExpressionType> type); + + bool IsVectorType() { return As<VectorExpressionType>() != nullptr; } + bool IsArray() { return As<ArrayExpressionType>() != nullptr; } + + template<typename T> + T* As() + { + return dynamic_cast<T*>(GetCanonicalType()); + } + + // Convenience/legacy wrappers for `As<>` + ArithmeticExpressionType * AsArithmeticType() { return As<ArithmeticExpressionType>(); } + BasicExpressionType * AsBasicType() { return As<BasicExpressionType>(); } + VectorExpressionType * AsVectorType() { return As<VectorExpressionType>(); } + MatrixExpressionType * AsMatrixType() { return As<MatrixExpressionType>(); } + ArrayExpressionType * AsArrayType() { return As<ArrayExpressionType>(); } + + DeclRefType* AsDeclRefType() { return As<DeclRefType>(); } + + NamedExpressionType* AsNamedType(); + + bool IsTextureOrSampler(); + bool IsTexture() { return As<TextureType>() != nullptr; } + bool IsSampler() { return As<SamplerStateType>() != nullptr; } + bool IsStruct(); + bool IsClass(); + static void Init(); + static void Finalize(); + ExpressionType* GetCanonicalType(); + + virtual RefPtr<Val> SubstituteImpl(Substitutions* subst, int* ioDiff) override; + + virtual bool EqualsVal(Val* val) override; + protected: + virtual bool EqualsImpl(ExpressionType * type) = 0; + + virtual ExpressionType* CreateCanonicalType() = 0; + ExpressionType* canonicalType = nullptr; + }; + + // A substitution represents a binding of certain + // type-level variables to concrete argument values + class Substitutions : public RefObject + { + public: + // The generic declaration that defines the + // parametesr we are binding to arguments + GenericDecl* genericDecl; + + // The actual values of the arguments + List<RefPtr<Val>> args; + + // Any further substitutions, relating to outer generic declarations + RefPtr<Substitutions> outer; + + // Apply a set of substitutions to the bindings in this substitution + RefPtr<Substitutions> SubstituteImpl(Substitutions* subst, int* ioDiff); + + // Check if these are equivalent substitutiosn to another set + bool Equals(Substitutions* subst); + bool operator == (const Substitutions & subst) + { + return Equals(const_cast<Substitutions*>(&subst)); + } + int GetHashCode() const + { + int rs = 0; + for (auto && v : args) + { + rs ^= v->GetHashCode(); + rs *= 16777619; + } + return rs; + } + }; + + class SyntaxNode : public SyntaxNodeBase + { + public: + virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) = 0; + }; + + class ContainerDecl; + class SpecializeModifier; + + // Represents how much checking has been applied to a declaration. + enum class DeclCheckState : uint8_t + { + // The declaration has been parsed, but not checked + Unchecked, + + // We are in the process of checking the declaration "header" + // (those parts of the declaration needed in order to + // reference it) + CheckingHeader, + + // We are done checking the declaration header. + CheckedHeader, + + // We have checked the declaration fully. + Checked, + }; + + // A syntax node which can have modifiers appled + class ModifiableSyntaxNode : public SyntaxNode + { + public: + Modifiers modifiers; + + template<typename T> + FilteredModifierList<T> GetModifiersOfType() { return FilteredModifierList<T>(modifiers.first.Ptr()); } + + // Find the first modifier of a given type, or return `nullptr` if none is found. + template<typename T> + T* FindModifier() + { + return *GetModifiersOfType<T>().begin(); + } + + template<typename T> + bool HasModifier() { return FindModifier<T>() != nullptr; } + }; + + void addModifier( + RefPtr<ModifiableSyntaxNode> syntax, + RefPtr<Modifier> modifier); + + + // An intermediate type to represent either a single declaration, or a group of declarations + class DeclBase : public ModifiableSyntaxNode + { + public: + }; + + class Decl : public DeclBase + { + public: + ContainerDecl* ParentDecl; + + Token Name; + String const& getName() { return Name.Content; } + Token const& getNameToken() { return Name; } + + + DeclCheckState checkState = DeclCheckState::Unchecked; + + // The next declaration defined in the same container with the same name + Decl* nextInContainerWithSameName = nullptr; + + bool IsChecked(DeclCheckState state) { return checkState >= state; } + void SetCheckState(DeclCheckState state) + { + assert(state >= checkState); + checkState = state; + } + }; + + struct QualType + { + RefPtr<ExpressionType> type; + bool IsLeftValue; + + QualType() + : IsLeftValue(false) + {} + + QualType(RefPtr<ExpressionType> type) + : type(type) + , IsLeftValue(false) + {} + + QualType(ExpressionType* type) + : type(type) + , IsLeftValue(false) + {} + + void operator=(RefPtr<ExpressionType> t) + { + *this = QualType(t); + } + + void operator=(ExpressionType* t) + { + *this = QualType(t); + } + + ExpressionType* Ptr() { return type.Ptr(); } + + operator RefPtr<ExpressionType>() { return type; } + RefPtr<ExpressionType> operator->() { return type; } + }; + + class ExpressionSyntaxNode : public SyntaxNode + { + public: + QualType Type; + ExpressionSyntaxNode() + {} + }; + + + + + // A reference to a declaration, which may include + // substitutions for generic parameters. + struct DeclRef + { + typedef Decl DeclType; + + // The underlying declaration + Decl* decl = nullptr; + Decl* GetDecl() const { return decl; } + + // Optionally, a chain of substititions to perform + RefPtr<Substitutions> substitutions; + + DeclRef() + {} + + DeclRef(Decl* decl, RefPtr<Substitutions> substitutions) + : decl(decl) + , substitutions(substitutions) + {} + + // Apply substitutions to a type or ddeclaration + RefPtr<ExpressionType> Substitute(RefPtr<ExpressionType> type) const; + DeclRef Substitute(DeclRef declRef) const; + + // Apply substitutions to an expression + RefPtr<ExpressionSyntaxNode> Substitute(RefPtr<ExpressionSyntaxNode> expr) const; + + // Apply substitutions to this declaration reference + DeclRef SubstituteImpl(Substitutions* subst, int* ioDiff); + + // Check if this is an equivalent declaration reference to another + bool Equals(DeclRef const& declRef) const; + bool operator == (const DeclRef& other) const + { + return Equals(other); + } + + // Convenience accessors for common properties of declarations + String const& GetName() const; + DeclRef GetParent() const; + + // "dynamic cast" to a more specific declaration reference type + template<typename T> + T As() const + { + T result; + result.decl = dynamic_cast<T::DeclType*>(decl); + result.substitutions = substitutions; + return result; + } + + // Implicit conversion mostly so we can use a `DeclRef` + // in a conditional context + operator Decl*() const + { + return decl; + } + + int GetHashCode() const; + }; + + // Helper macro for defining `DeclRef` subtypes + #define SLANG_DECLARE_DECL_REF(D) \ + typedef D DeclType; \ + D* GetDecl() const { return (D*) decl; } \ + /* */ + + + + // The type of a reference to an overloaded name + class OverloadGroupType : public ExpressionType + { + public: + virtual String ToString() override; + + protected: + virtual bool EqualsImpl(ExpressionType * type) override; + virtual ExpressionType* CreateCanonicalType() override; + virtual int GetHashCode() override; + }; + + // The type of an initializer-list expression (before it has + // been coerced to some other type) + class InitializerListType : public ExpressionType + { + public: + virtual String ToString() override; + + protected: + virtual bool EqualsImpl(ExpressionType * type) override; + virtual ExpressionType* CreateCanonicalType() override; + virtual int GetHashCode() override; + }; + + // The type of an expression that was erroneous + class ErrorType : public ExpressionType + { + public: + virtual String ToString() override; + + protected: + virtual bool EqualsImpl(ExpressionType * type) override; + virtual ExpressionType* CreateCanonicalType() override; + virtual int GetHashCode() override; + }; + + // A type that takes the form of a reference to some declaration + class DeclRefType : public ExpressionType + { + public: + DeclRef declRef; + + virtual String ToString() override; + virtual RefPtr<Val> SubstituteImpl(Substitutions* subst, int* ioDiff) override; + + static DeclRefType* Create(DeclRef declRef); + + protected: + DeclRefType() + {} + DeclRefType(DeclRef declRef) + : declRef(declRef) + {} + virtual int GetHashCode() override; + virtual bool EqualsImpl(ExpressionType * type) override; + virtual ExpressionType* CreateCanonicalType() override; + }; + + // Base class for types that can be used in arithmetic expressions + class ArithmeticExpressionType : public DeclRefType + { + public: + virtual BasicExpressionType* GetScalarType() = 0; + }; + + class FunctionDeclBase; + + class BasicExpressionType : public ArithmeticExpressionType + { + public: + BaseType BaseType; + + BasicExpressionType() + { + BaseType = Compiler::BaseType::Int; + } + BasicExpressionType(Compiler::BaseType baseType) + { + BaseType = baseType; + } + virtual CoreLib::Basic::String ToString() override; + protected: + virtual BasicExpressionType* GetScalarType() override; + virtual bool EqualsImpl(ExpressionType * type) override; + virtual ExpressionType* CreateCanonicalType() override; + }; + + + class TextureTypeBase : public DeclRefType + { + public: + // The type that results from fetching an element from this texture + RefPtr<ExpressionType> elementType; + + // Bits representing the kind of texture type we are looking at + // (e.g., `Texture2DMS` vs. `TextureCubeArray`) + typedef uint16_t Flavor; + Flavor flavor; + + enum + { + // Mask for the overall "shape" of the texture + ShapeMask = SLANG_RESOURCE_BASE_SHAPE_MASK, + + // Flag for whether the shape has "array-ness" + ArrayFlag = SLANG_TEXTURE_ARRAY_FLAG, + + // Whether or not the texture stores multiple samples per pixel + MultisampleFlag = SLANG_TEXTURE_MULTISAMPLE_FLAG, + + // Whether or not this is a shadow texture + // + // TODO(tfoley): is this even meaningful/used? + // ShadowFlag = 0x80, + }; + + enum Shape : uint8_t + { + Shape1D = SLANG_TEXTURE_1D, + Shape2D = SLANG_TEXTURE_2D, + Shape3D = SLANG_TEXTURE_3D, + ShapeCube = SLANG_TEXTURE_CUBE, + + Shape1DArray = Shape1D | ArrayFlag, + Shape2DArray = Shape2D | ArrayFlag, + // No Shape3DArray + ShapeCubeArray = ShapeCube | ArrayFlag, + }; + + + Shape GetBaseShape() const { return Shape(flavor & ShapeMask); } + bool isArray() const { return (flavor & ArrayFlag) != 0; } + bool isMultisample() const { return (flavor & MultisampleFlag) != 0; } +// bool isShadow() const { return (flavor & ShadowFlag) != 0; } + + SlangResourceShape getShape() const { return flavor & 0xFF; } + SlangResourceAccess getAccess() const { return (flavor >> 8) & 0xFF; } + + TextureTypeBase( + Flavor flavor, + RefPtr<ExpressionType> elementType) + : elementType(elementType) + , flavor(flavor) + {} + }; + + class TextureType : public TextureTypeBase + { + public: + TextureType( + Flavor flavor, + RefPtr<ExpressionType> elementType) + : TextureTypeBase(flavor, elementType) + {} + }; + + // This is a base type for texture/sampler pairs, + // as they exist in, e.g., GLSL + class TextureSamplerType : public TextureTypeBase + { + public: + TextureSamplerType( + Flavor flavor, + RefPtr<ExpressionType> elementType) + : TextureTypeBase(flavor, elementType) + {} + }; + + // This is a base type for `image*` types, as they exist in GLSL + class GLSLImageType : public TextureTypeBase + { + public: + GLSLImageType( + Flavor flavor, + RefPtr<ExpressionType> elementType) + : TextureTypeBase(flavor, elementType) + {} + }; + + class SamplerStateType : public DeclRefType + { + public: + // What flavor of sampler state is this + enum class Flavor : uint8_t + { + SamplerState, + SamplerComparisonState, + }; + Flavor flavor; + }; + + // Other cases of generic types known to the compiler + class BuiltinGenericType : public DeclRefType + { + public: + RefPtr<ExpressionType> elementType; + }; + + // Types that behave like pointers, in that they can be + // dereferenced (implicitly) to access members defined + // in the element type. + class PointerLikeType : public BuiltinGenericType + {}; + + // Generic types used in existing Slang code + // TODO(tfoley): check that these are actually working right... + class PatchType : public PointerLikeType {}; + class StorageBufferType : public BuiltinGenericType {}; + class UniformBufferType : public PointerLikeType {}; + class PackedBufferType : public BuiltinGenericType {}; + + // HLSL buffer-type resources + + class HLSLBufferType : public BuiltinGenericType {}; + class HLSLRWBufferType : public BuiltinGenericType {}; + class HLSLStructuredBufferType : public BuiltinGenericType {}; + class HLSLRWStructuredBufferType : public BuiltinGenericType {}; + + class UntypedBufferResourceType : public DeclRefType {}; + class HLSLByteAddressBufferType : public UntypedBufferResourceType {}; + class HLSLRWByteAddressBufferType : public UntypedBufferResourceType {}; + + class HLSLAppendStructuredBufferType : public BuiltinGenericType {}; + class HLSLConsumeStructuredBufferType : public BuiltinGenericType {}; + + class HLSLInputPatchType : public BuiltinGenericType {}; + class HLSLOutputPatchType : public BuiltinGenericType {}; + + // HLSL geometry shader output stream types + + class HLSLStreamOutputType : public BuiltinGenericType {}; + class HLSLPointStreamType : public HLSLStreamOutputType {}; + class HLSLLineStreamType : public HLSLStreamOutputType {}; + class HLSLTriangleStreamType : public HLSLStreamOutputType {}; + + // + class GLSLInputAttachmentType : public DeclRefType {}; + + // Base class for types used when desugaring parameter block + // declarations, includeing HLSL `cbuffer` or GLSL `uniform` blocks. + class ParameterBlockType : public PointerLikeType {}; + + class UniformParameterBlockType : public ParameterBlockType {}; + class VaryingParameterBlockType : public ParameterBlockType {}; + + // Type for HLSL `cbuffer` declarations, and `ConstantBuffer<T>` + // ALso used for GLSL `uniform` blocks. + class ConstantBufferType : public UniformParameterBlockType {}; + + // Type for HLSL `tbuffer` declarations, and `TextureBuffer<T>` + class TextureBufferType : public UniformParameterBlockType {}; + + // Type for GLSL `in` and `out` blocks + class GLSLInputParameterBlockType : public VaryingParameterBlockType {}; + class GLSLOutputParameterBlockType : public VaryingParameterBlockType {}; + + // Type for GLLSL `buffer` blocks + class GLSLShaderStorageBufferType : public UniformParameterBlockType {}; + + class ArrayExpressionType : public ExpressionType + { + public: + RefPtr<ExpressionType> BaseType; + RefPtr<IntVal> ArrayLength; + virtual CoreLib::Basic::String ToString() override; + protected: + virtual bool EqualsImpl(ExpressionType * type) override; + virtual ExpressionType* CreateCanonicalType() override; + virtual int GetHashCode() override; + }; + + // The "type" of an expression that resolves to a type. + // For example, in the expression `float(2)` the sub-expression, + // `float` would have the type `TypeType(float)`. + class TypeType : public ExpressionType + { + public: + TypeType(RefPtr<ExpressionType> type) + : type(type) + {} + + // The type that this is the type of... + RefPtr<ExpressionType> type; + + + virtual String ToString() override; + + protected: + virtual bool EqualsImpl(ExpressionType * type) override; + virtual ExpressionType* CreateCanonicalType() override; + virtual int GetHashCode() override; + }; + + class GenericDecl; + + // A vector type, e.g., `vector<T,N>` + class VectorExpressionType : public ArithmeticExpressionType + { + public: +#if 0 + VectorExpressionType( + RefPtr<ExpressionType> elementType, + RefPtr<IntVal> elementCount) + : elementType(elementType) + , elementCount(elementCount) + {} +#endif + + // The type of vector elements. + // As an invariant, this should be a basic type or an alias. + RefPtr<ExpressionType> elementType; + + // The number of elements + RefPtr<IntVal> elementCount; + + virtual String ToString() override; + + protected: + virtual BasicExpressionType* GetScalarType() override; + }; + + // A matrix type, e.g., `matrix<T,R,C>` + class MatrixExpressionType : public ArithmeticExpressionType + { + public: + // TODO: consider adding these back for convenience, + // with a way to initialize them on-demand from the + // real storage (which is in the `DeclRefType` +#if 0 + // The type of vector elements. + // As an invariant, this should be a basic type or an alias. + RefPtr<ExpressionType> elementType; + + // The type of the matrix rows + RefPtr<VectorExpressionType> rowType; + + // The number of rows and columns + RefPtr<IntVal> rowCount; + RefPtr<IntVal> colCount; +#endif + ExpressionType* getElementType(); + IntVal* getRowCount(); + IntVal* getColumnCount(); + + + virtual String ToString() override; + + protected: + virtual BasicExpressionType* GetScalarType() override; + }; + + inline BaseType GetVectorBaseType(VectorExpressionType* vecType) { + return vecType->elementType->AsBasicType()->BaseType; + } + + inline int GetVectorSize(VectorExpressionType* vecType) + { + auto constantVal = vecType->elementCount.As<ConstantIntVal>(); + if (constantVal) + return constantVal->value; + // TODO: what to do in this case? + return 0; + } + + class Type + { + public: + RefPtr<ExpressionType> DataType; + // ContrainedWorlds: Implementation must be defined at at least one of of these worlds in order to satisfy global dependency + // FeasibleWorlds: The component can be computed at any of these worlds + EnumerableHashSet<String> ConstrainedWorlds, FeasibleWorlds; + EnumerableHashSet<String> PinnedWorlds; + }; + + class ContainerDecl; + + + // A group of declarations that should be treated as a unit + class DeclGroup : public DeclBase + { + public: + List<RefPtr<Decl>> decls; + + virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override; + }; + + template<typename T> + struct FilteredMemberList + { + typedef RefPtr<Decl> Element; + + FilteredMemberList() + : mBegin(NULL) + , mEnd(NULL) + {} + + explicit FilteredMemberList( + List<Element> const& list) + : mBegin(Adjust(list.begin(), list.end())) + , mEnd(list.end()) + {} + + struct Iterator + { + Element* mCursor; + Element* mEnd; + + bool operator!=(Iterator const& other) + { + return mCursor != other.mCursor; + } + + void operator++() + { + mCursor = Adjust(mCursor + 1, mEnd); + } + + RefPtr<T>& operator*() + { + return *(RefPtr<T>*)mCursor; + } + }; + + Iterator begin() + { + Iterator iter = { mBegin, mEnd }; + return iter; + } + + Iterator end() + { + Iterator iter = { mEnd, mEnd }; + return iter; + } + + static Element* Adjust(Element* cursor, Element* end) + { + while (cursor != end) + { + if ((*cursor).As<T>()) + return cursor; + cursor++; + } + return cursor; + } + + // TODO(tfoley): It is ugly to have these. + // We should probably fix the call sites instead. + RefPtr<T>& First() { return *begin(); } + int Count() + { + int count = 0; + for (auto iter : (*this)) + { + (void)iter; + count++; + } + return count; + } + + List<RefPtr<T>> ToArray() + { + List<RefPtr<T>> result; + for (auto element : (*this)) + { + result.Add(element); + } + return result; + } + + Element* mBegin; + Element* mEnd; + }; + + struct TransparentMemberInfo + { + // The declaration of the transparent member + Decl* decl; + }; + + // A "container" decl is a parent to other declarations + class ContainerDecl : public Decl + { + public: + List<RefPtr<Decl>> Members; + + template<typename T> + FilteredMemberList<T> GetMembersOfType() + { + return FilteredMemberList<T>(Members); + } + + + // Dictionary for looking up members by name. + // This is built on demand before performing lookup. + Dictionary<String, Decl*> memberDictionary; + + // Whether the `memberDictionary` is valid. + // Should be set to `false` if any members get added/remoed. + bool memberDictionaryIsValid = false; + + // A list of transparent members, to be used in lookup + // Note: this is only valid if `memberDictionaryIsValid` is true + List<TransparentMemberInfo> transparentMembers; + }; + + template<typename T> + struct FilteredMemberRefList + { + List<RefPtr<Decl>> const& decls; + RefPtr<Substitutions> substitutions; + + FilteredMemberRefList( + List<RefPtr<Decl>> const& decls, + RefPtr<Substitutions> substitutions) + : decls(decls) + , substitutions(substitutions) + {} + + int Count() const + { + int count = 0; + for (auto d : *this) + count++; + return count; + } + + List<T> ToArray() const + { + List<T> result; + for (auto d : *this) + result.Add(d); + return result; + } + + struct Iterator + { + FilteredMemberRefList const* list; + RefPtr<Decl>* ptr; + RefPtr<Decl>* end; + + Iterator() : list(nullptr), ptr(nullptr) {} + Iterator( + FilteredMemberRefList const* list, + RefPtr<Decl>* ptr, + RefPtr<Decl>* end) + : list(list) + , ptr(ptr) + , end(end) + {} + + bool operator!=(Iterator other) + { + return ptr != other.ptr; + } + + void operator++() + { + ptr = list->Adjust(ptr + 1, end); + } + + T operator*() + { + return DeclRef(ptr->Ptr(), list->substitutions).As<T>(); + } + }; + + Iterator begin() const { return Iterator(this, Adjust(decls.begin(), decls.end()), decls.end()); } + Iterator end() const { return Iterator(this, decls.end(), decls.end()); } + + RefPtr<Decl>* Adjust(RefPtr<Decl>* ptr, RefPtr<Decl>* end) const + { + while (ptr != end) + { + DeclRef declRef(ptr->Ptr(), substitutions); + if (declRef.As<T>()) + return ptr; + ptr++; + } + return end; + } + }; + + struct ContainerDeclRef : DeclRef + { + SLANG_DECLARE_DECL_REF(ContainerDecl); + + FilteredMemberRefList<DeclRef> GetMembers() const + { + return FilteredMemberRefList<DeclRef>(GetDecl()->Members, substitutions); + } + + template<typename T> + FilteredMemberRefList<T> GetMembersOfType() const + { + return FilteredMemberRefList<T>(GetDecl()->Members, substitutions); + } + + }; + + // + // Type Expressions + // + + // A "type expression" is a term that we expect to resolve to a type during checking. + // We store both the original syntax and the resolved type here. + struct TypeExp + { + TypeExp() {} + TypeExp(TypeExp const& other) + : exp(other.exp) + , type(other.type) + {} + explicit TypeExp(RefPtr<ExpressionSyntaxNode> exp) + : exp(exp) + {} + TypeExp(RefPtr<ExpressionSyntaxNode> exp, RefPtr<ExpressionType> type) + : exp(exp) + , type(type) + {} + + RefPtr<ExpressionSyntaxNode> exp; + RefPtr<ExpressionType> type; + + bool Equals(ExpressionType* other) { + return type->Equals(other); + } + bool Equals(RefPtr<ExpressionType> other) { + return type->Equals(other.Ptr()); + } + ExpressionType* Ptr() { return type.Ptr(); } + operator RefPtr<ExpressionType>() + { + return type; + } + ExpressionType* operator->() { return Ptr(); } + + TypeExp Accept(SyntaxVisitor* visitor); + }; + + + // + // Declarations + // + + // Base class for all variable-like declarations + class VarDeclBase : public Decl + { + public: + // Type of the variable + TypeExp Type; + + ExpressionType* getType() { return Type.type.Ptr(); } + + // Initializer expression (optional) + RefPtr<ExpressionSyntaxNode> Expr; + }; + + struct VarDeclBaseRef : DeclRef + { + SLANG_DECLARE_DECL_REF(VarDeclBase); + + RefPtr<ExpressionType> GetType() const { return Substitute(GetDecl()->Type); } + + RefPtr<ExpressionSyntaxNode> getInitExpr() const { return Substitute(GetDecl()->Expr); } + }; + + // A field of a `struct` type + class StructField : public VarDeclBase + { + public: + StructField() + {} + virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override; + }; + + struct FieldDeclRef : VarDeclBaseRef + { + SLANG_DECLARE_DECL_REF(StructField) + }; + + // An extension to apply to an existing type + class ExtensionDecl : public ContainerDecl + { + public: + TypeExp targetType; + + // next extension attached to the same nominal type + ExtensionDecl* nextCandidateExtension = nullptr; + + + virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override; + }; + + struct ExtensionDeclRef : ContainerDeclRef + { + SLANG_DECLARE_DECL_REF(ExtensionDecl); + + RefPtr<ExpressionType> GetTargetType() const { return Substitute(GetDecl()->targetType); } + }; + + // Declaration of a type that represents some sort of aggregate + class AggTypeDecl : public ContainerDecl + { + public: + // extensions that might apply to this declaration + ExtensionDecl* candidateExtensions = nullptr; + FilteredMemberList<StructField> GetFields() + { + return GetMembersOfType<StructField>(); + } + StructField* FindField(String name) + { + for (auto field : GetFields()) + { + if (field->Name.Content == name) + return field.Ptr(); + } + return nullptr; + } + int FindFieldIndex(String name) + { + int index = 0; + for (auto field : GetFields()) + { + if (field->Name.Content == name) + return index; + index++; + } + return -1; + } + }; + + struct AggTypeDeclRef : public ContainerDeclRef + { + SLANG_DECLARE_DECL_REF(AggTypeDecl); + + ExtensionDecl* GetCandidateExtensions() const { return GetDecl()->candidateExtensions; } + }; + + class StructSyntaxNode : public AggTypeDecl + { + public: + virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override; + }; + + struct StructDeclRef : public AggTypeDeclRef + { + SLANG_DECLARE_DECL_REF(StructSyntaxNode); + + FilteredMemberRefList<FieldDeclRef> GetFields() const { return GetMembersOfType<FieldDeclRef>(); } + }; + + class ClassSyntaxNode : public AggTypeDecl + { + public: + virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override; + }; + + struct ClassDeclRef : public AggTypeDeclRef + { + SLANG_DECLARE_DECL_REF(ClassSyntaxNode); + + FilteredMemberRefList<FieldDeclRef> GetFields() const { return GetMembersOfType<FieldDeclRef>(); } + }; + + // A trait which other types can conform to + class TraitDecl : public AggTypeDecl + { + public: + List<TypeExp> bases; + + virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override; + }; + + struct TraitDeclRef : public AggTypeDeclRef + { + SLANG_DECLARE_DECL_REF(TraitDecl); + }; + + + // A declaration that states that the enclosing type supports a given trait + // + // TODO: this same construct might be used for represent other inheritance-like cases + class TraitConformanceDecl : public Decl + { + public: + // The type expression as written + TypeExp base; + + // The trait that we found we conform to... + TraitDeclRef traitDeclRef; + + virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override; + }; + + struct TraitConformanceDeclRef : public DeclRef + { + SLANG_DECLARE_DECL_REF(TraitConformanceDecl); + + TraitDeclRef GetTraitDeclRef() { return Substitute(GetDecl()->traitDeclRef).As<TraitDeclRef>(); } + }; + + // A declaration that represents a simple (non-aggregate) type + class SimpleTypeDecl : public Decl + { + }; + + struct SimpleTypeDeclRef : DeclRef + { + SLANG_DECLARE_DECL_REF(SimpleTypeDecl) + }; + + // A `typedef` declaration + class TypeDefDecl : public SimpleTypeDecl + { + public: + TypeExp Type; + + virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override; + }; + + struct TypeDefDeclRef : SimpleTypeDeclRef + { + SLANG_DECLARE_DECL_REF(TypeDefDecl); + + RefPtr<ExpressionType> GetType() const { return Substitute(GetDecl()->Type); } + }; + + // A type alias of some kind (e.g., via `typedef`) + class NamedExpressionType : public ExpressionType + { + public: + NamedExpressionType(TypeDefDeclRef declRef) + : declRef(declRef) + {} + + TypeDefDeclRef declRef; + + virtual String ToString() override; + + protected: + virtual bool EqualsImpl(ExpressionType * type) override; + virtual ExpressionType* CreateCanonicalType() override; + virtual int GetHashCode() override; + }; + + + class StatementSyntaxNode : public ModifiableSyntaxNode + { + public: + }; + + // A scope for local declarations (e.g., as part of a statement) + class ScopeDecl : public ContainerDecl + { + public: + virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override; + }; + + class ScopeStmt : public StatementSyntaxNode + { + public: + RefPtr<ScopeDecl> scopeDecl; + }; + + class BlockStatementSyntaxNode : public ScopeStmt + { + public: + List<RefPtr<StatementSyntaxNode>> Statements; + virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override; + }; + + class UnparsedStmt : public StatementSyntaxNode + { + public: + // The tokens that were contained between `{` and `}` + List<Token> tokens; + + virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override; + }; + + class ParameterSyntaxNode : public VarDeclBase + { + public: + virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override; + }; + + struct ParamDeclRef : VarDeclBaseRef + { + SLANG_DECLARE_DECL_REF(ParameterSyntaxNode); + }; + + // Base class for things that have parameter lists and can thus be applied to arguments ("called") + class CallableDecl : public ContainerDecl + { + public: + FilteredMemberList<ParameterSyntaxNode> GetParameters() + { + return GetMembersOfType<ParameterSyntaxNode>(); + } + TypeExp ReturnType; + }; + + struct CallableDeclRef : ContainerDeclRef + { + SLANG_DECLARE_DECL_REF(CallableDecl); + + RefPtr<ExpressionType> GetResultType() const + { + return Substitute(GetDecl()->ReturnType.type.Ptr()); + } + + FilteredMemberRefList<ParamDeclRef> GetParameters() + { + return GetMembersOfType<ParamDeclRef>(); + } + }; + + // Base class for callable things that may also have a body that is evaluated to produce their result + class FunctionDeclBase : public CallableDecl + { + public: + RefPtr<StatementSyntaxNode> Body; + }; + + struct FuncDeclBaseRef : CallableDeclRef + { + SLANG_DECLARE_DECL_REF(FunctionDeclBase); + }; + + // Function types are currently used for references to symbols that name + // either ordinary functions, or "component functions." + // We do not directly store a representation of the type, and instead + // use a reference to the symbol to stand in for its logical type + class FuncType : public ExpressionType + { + public: + CallableDeclRef declRef; + + virtual String ToString() override; + protected: + virtual bool EqualsImpl(ExpressionType * type) override; + virtual ExpressionType* CreateCanonicalType() override; + virtual int GetHashCode() override; + }; + + // A constructor/initializer to create instances of a type + class ConstructorDecl : public FunctionDeclBase + { + public: + virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override; + }; + + struct ConstructorDeclRef : FuncDeclBaseRef + { + SLANG_DECLARE_DECL_REF(ConstructorDecl); + }; + + // A subscript operation used to index instances of a type + class SubscriptDecl : public CallableDecl + { + public: + virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override; + }; + + struct SubscriptDeclRef : CallableDeclRef + { + SLANG_DECLARE_DECL_REF(SubscriptDecl); + }; + + // An "accessor" for a subscript or property + class AccessorDecl : public FunctionDeclBase + { + public: + virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override; + }; + + class GetterDecl : public AccessorDecl + { + }; + + class SetterDecl : public AccessorDecl + { + }; + + // + + class FunctionSyntaxNode : public FunctionDeclBase + { + public: + String InternalName; + bool IsInline() { return HasModifier<InlineModifier>(); } + bool IsExtern() { return HasModifier<ExternModifier>(); } + virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override; + FunctionSyntaxNode() + { + } + }; + + struct FuncDeclRef : FuncDeclBaseRef + { + SLANG_DECLARE_DECL_REF(FunctionSyntaxNode); + }; + + + struct Scope : public RefObject + { + // The parent of this scope (where lookup should go if nothing is found locally) + RefPtr<Scope> parent; + + // The next sibling of this scope (a peer for lookup) + RefPtr<Scope> nextSibling; + + // The container to use for lookup + // + // Note(tfoley): This is kept as an unowned pointer + // so that a scope can't keep parts of the AST alive, + // but the opposite it allowed. + ContainerDecl* containerDecl; + }; + + // Base class for expressions that will reference declarations + class DeclRefExpr : public ExpressionSyntaxNode + { + public: + // The scope in which to perform lookup + RefPtr<Scope> scope; + + // The declaration of the symbol being referenced + DeclRef declRef; + }; + + class VarExpressionSyntaxNode : public DeclRefExpr + { + public: + // The name of the symbol being referenced + String Variable; + + virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override; + }; + + // Masks to be applied when lookup up declarations + enum class LookupMask : uint8_t + { + Type = 0x1, + Function = 0x2, + Value = 0x4, + + All = Type | Function | Value, + }; + + // Represents one item found during lookup + struct LookupResultItem + { + // Sometimes lookup finds an item, but there were additional + // "hops" taken to reach it. We need to remember these steps + // so that if/when we consturct a full expression we generate + // appropriate AST nodes for all the steps. + // + // We build up a list of these "breadcrumbs" while doing + // lookup, and store them alongside each item found. + class Breadcrumb : public RefObject + { + public: + enum class Kind + { + Member, // A member was references + Deref, // A value with pointer(-like) type was dereferenced + }; + + Kind kind; + DeclRef declRef; + RefPtr<Breadcrumb> next; + + Breadcrumb(Kind kind, DeclRef declRef, RefPtr<Breadcrumb> next) + : kind(kind) + , declRef(declRef) + , next(next) + {} + }; + + // A properly-specialized reference to the declaration that was found. + DeclRef declRef; + + // Any breadcrumbs needed in order to turn that declaration + // reference into a well-formed expression. + // + // This is unused in the simple case where a declaration + // is being referenced directly (rather than through + // transparent members). + RefPtr<Breadcrumb> breadcrumbs; + + LookupResultItem() = default; + explicit LookupResultItem(DeclRef declRef) + : declRef(declRef) + {} + LookupResultItem(DeclRef declRef, RefPtr<Breadcrumb> breadcrumbs) + : declRef(declRef) + , breadcrumbs(breadcrumbs) + {} + }; + + + // Result of looking up a name in some lexical/semantic environment. + // Can be used to enumerate all the declarations matching that name, + // in the case where the result is overloaded. + struct LookupResult + { + // The one item that was found, in the smple case + LookupResultItem item; + + // All of the items that were found, in the complex case. + // Note: if there was no overloading, then this list isn't + // used at all, to avoid allocation. + List<LookupResultItem> items; + + // Was at least one result found? + bool isValid() const { return item.declRef.GetDecl() != nullptr; } + + bool isOverloaded() const { return items.Count() > 1; } + }; + + struct LookupRequest + { + RefPtr<Scope> scope = nullptr; + RefPtr<Scope> endScope = nullptr; + + LookupMask mask = LookupMask::All; + }; + + // An expression that references an overloaded set of declarations + // having the same name. + class OverloadedExpr : public ExpressionSyntaxNode + { + public: + // Optional: the base expression is this overloaded result + // arose from a member-reference expression. + RefPtr<ExpressionSyntaxNode> base; + + // The lookup result that was ambiguous + LookupResult lookupResult2; + + virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override; + }; + + typedef double FloatingPointLiteralValue; + + class ConstantExpressionSyntaxNode : public ExpressionSyntaxNode + { + public: + enum class ConstantType + { + Int, Bool, Float + }; + ConstantType ConstType; + union + { + int IntValue; + FloatingPointLiteralValue FloatValue; + }; + virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override; + }; + + enum class Operator + { + Neg, Not, BitNot, PreInc, PreDec, PostInc, PostDec, + Mul, Div, Mod, + Add, Sub, + Lsh, Rsh, + Eql, Neq, Greater, Less, Geq, Leq, + BitAnd, BitXor, BitOr, + And, + Or, + Sequence, + Select, + Assign = 200, AddAssign, SubAssign, MulAssign, DivAssign, ModAssign, + LshAssign, RshAssign, OrAssign, AndAssign, XorAssign, + }; + String GetOperatorFunctionName(Operator op); + String OperatorToString(Operator op); + + // An initializer list, e.g. `{ 1, 2, 3 }` + class InitializerListExpr : public ExpressionSyntaxNode + { + public: + List<RefPtr<ExpressionSyntaxNode>> args; + + virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override; + }; + + // A base expression being applied to arguments: covers + // both ordinary `()` function calls and `<>` generic application + class AppExprBase : public ExpressionSyntaxNode + { + public: + RefPtr<ExpressionSyntaxNode> FunctionExpr; + List<RefPtr<ExpressionSyntaxNode>> Arguments; + }; + + + class InvokeExpressionSyntaxNode : public AppExprBase + { + public: + virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override; + }; + + class OperatorExpressionSyntaxNode : public InvokeExpressionSyntaxNode + { + public: +// Operator Operator; +// void SetOperator(RefPtr<Scope> scope, Slang::Compiler::Operator op); + virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override; + }; + + class InfixExpr : public OperatorExpressionSyntaxNode {}; + class PrefixExpr : public OperatorExpressionSyntaxNode {}; + class PostfixExpr : public OperatorExpressionSyntaxNode {}; + + class IndexExpressionSyntaxNode : public ExpressionSyntaxNode + { + public: + RefPtr<ExpressionSyntaxNode> BaseExpression; + RefPtr<ExpressionSyntaxNode> IndexExpression; + virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override; + }; + + class MemberExpressionSyntaxNode : public DeclRefExpr + { + public: + RefPtr<ExpressionSyntaxNode> BaseExpression; + String MemberName; + virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override; + }; + + class SwizzleExpr : public ExpressionSyntaxNode + { + public: + RefPtr<ExpressionSyntaxNode> base; + int elementCount; + int elementIndices[4]; + + virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override; + }; + + // A dereference of a pointer or pointer-like type + class DerefExpr : public ExpressionSyntaxNode + { + public: + RefPtr<ExpressionSyntaxNode> base; + + virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override; + }; + + class TypeCastExpressionSyntaxNode : public ExpressionSyntaxNode + { + public: + TypeExp TargetType; + RefPtr<ExpressionSyntaxNode> Expression; + virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override; + }; + + class SelectExpressionSyntaxNode : public OperatorExpressionSyntaxNode + { + public: + }; + + + class EmptyStatementSyntaxNode : public StatementSyntaxNode + { + public: + virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override; + }; + + class DiscardStatementSyntaxNode : public StatementSyntaxNode + { + public: + virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override; + }; + + struct Variable : public VarDeclBase + { + virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override; + }; + + class VarDeclrStatementSyntaxNode : public StatementSyntaxNode + { + public: + RefPtr<DeclBase> decl; + virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override; + }; + + class UsingFileDecl : public Decl + { + public: + Token fileName; + + virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override; + }; + + class ProgramSyntaxNode : public ContainerDecl + { + public: + // Access members of specific types + FilteredMemberList<UsingFileDecl> GetUsings() + { + return GetMembersOfType<UsingFileDecl>(); + } + FilteredMemberList<FunctionSyntaxNode> GetFunctions() + { + return GetMembersOfType<FunctionSyntaxNode>(); + } + + FilteredMemberList<ClassSyntaxNode> GetClasses() + { + return GetMembersOfType<ClassSyntaxNode>(); + } + FilteredMemberList<StructSyntaxNode> GetStructs() + { + return GetMembersOfType<StructSyntaxNode>(); + } + FilteredMemberList<TypeDefDecl> GetTypeDefs() + { + return GetMembersOfType<TypeDefDecl>(); + } +#if 0 + void Include(ProgramSyntaxNode * other) + { + Members.AddRange(other->Members); + } +#endif + virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override; + }; + + class IfStatementSyntaxNode : public StatementSyntaxNode + { + public: + RefPtr<ExpressionSyntaxNode> Predicate; + RefPtr<StatementSyntaxNode> PositiveStatement; + RefPtr<StatementSyntaxNode> NegativeStatement; + virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override; + }; + + // A statement that can be escaped with a `break` + class BreakableStmt : public ScopeStmt + {}; + + class SwitchStmt : public BreakableStmt + { + public: + RefPtr<ExpressionSyntaxNode> condition; + RefPtr<StatementSyntaxNode> body; + + virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override; + }; + + // A statement that is expected to appear lexically nested inside + // some other construct, and thus needs to keep track of the + // outer statement that it is associated with... + class ChildStmt : public StatementSyntaxNode + { + public: + StatementSyntaxNode* parentStmt = nullptr; + }; + + // a `case` or `default` statement inside a `switch` + // + // Note(tfoley): A correct AST for a C-like language would treat + // these as a labelled statement, and so they would contain a + // sub-statement. I'm leaving that out for now for simplicity. + class CaseStmtBase : public ChildStmt + { + public: + }; + + // a `case` statement inside a `switch` + class CaseStmt : public CaseStmtBase + { + public: + RefPtr<ExpressionSyntaxNode> expr; + + virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override; + }; + + // a `default` statement inside a `switch` + class DefaultStmt : public CaseStmtBase + { + public: + virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override; + }; + + // A statement that represents a loop, and can thus be escaped with a `continue` + class LoopStmt : public BreakableStmt + {}; + + class ForStatementSyntaxNode : public LoopStmt + { + public: + RefPtr<StatementSyntaxNode> InitialStatement; + RefPtr<ExpressionSyntaxNode> SideEffectExpression, PredicateExpression; + RefPtr<StatementSyntaxNode> Statement; + virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override; + }; + + class WhileStatementSyntaxNode : public LoopStmt + { + public: + RefPtr<ExpressionSyntaxNode> Predicate; + RefPtr<StatementSyntaxNode> Statement; + virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override; + }; + + class DoWhileStatementSyntaxNode : public LoopStmt + { + public: + RefPtr<StatementSyntaxNode> Statement; + RefPtr<ExpressionSyntaxNode> Predicate; + virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override; + }; + + // The case of child statements that do control flow relative + // to their parent statement. + class JumpStmt : public ChildStmt + { + public: + StatementSyntaxNode* parentStmt = nullptr; + }; + + class BreakStatementSyntaxNode : public JumpStmt + { + public: + virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override; + }; + + class ContinueStatementSyntaxNode : public JumpStmt + { + public: + virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override; + }; + + class ReturnStatementSyntaxNode : public StatementSyntaxNode + { + public: + RefPtr<ExpressionSyntaxNode> Expression; + virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override; + }; + + class ExpressionStatementSyntaxNode : public StatementSyntaxNode + { + public: + RefPtr<ExpressionSyntaxNode> Expression; + virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override; + }; + + // Note(tfoley): Moved this further down in the file because it depends on + // `ExpressionSyntaxNode` and a forward reference just isn't good enough + // for `RefPtr`. + // + class GenericAppExpr : public AppExprBase + { + public: + virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override; + }; + + // An expression representing re-use of the syntax for a type in more + // than once conceptually-distinct declaration + class SharedTypeExpr : public ExpressionSyntaxNode + { + public: + // The underlying type expression that we want to share + TypeExp base; + + virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override; + }; + + + // A modifier that indicates a built-in base type (e.g., `float`) + class BuiltinTypeModifier : public Modifier + { + public: + BaseType tag; + }; + + // A modifier that indicates a built-in type that isn't a base type (e.g., `vector`) + // + // TODO(tfoley): This deserves a better name than "magic" + class MagicTypeModifier : public Modifier + { + public: + String name; + uint32_t tag; + }; + + // Modifiers that affect the storage layout for matrices + class MatrixLayoutModifier : public Modifier {}; + + // Modifiers that specify row- and column-major layout, respectively + class RowMajorLayoutModifier : public MatrixLayoutModifier {}; + class ColumnMajorLayoutModifier : public MatrixLayoutModifier {}; + + // The HLSL flavor of those modifiers + class HLSLRowMajorLayoutModifier : public RowMajorLayoutModifier {}; + class HLSLColumnMajorLayoutModifier : public ColumnMajorLayoutModifier {}; + + // The GLSL flavor of those modifiers + // + // Note(tfoley): The GLSL versions of these modifiers are "backwards" + // in the sense that when a GLSL programmer requests row-major layout, + // we actually interpret that as requesting column-major. This makes + // sense because we interpret matrix conventions backwards from how + // GLSL specifies them. + class GLSLRowMajorLayoutModifier : public ColumnMajorLayoutModifier {}; + class GLSLColumnMajorLayoutModifier : public RowMajorLayoutModifier {}; + + // More HLSL Keyword + + // HLSL `nointerpolation` modifier + class HLSLNoInterpolationModifier : public Modifier {}; + + // HLSL `linear` modifier + class HLSLLinearModifier : public Modifier {}; + + // HLSL `sample` modifier + class HLSLSampleModifier : public Modifier {}; + + // HLSL `centroid` modifier + class HLSLCentroidModifier : public Modifier {}; + + // HLSL `precise` modifier + class HLSLPreciseModifier : public Modifier {}; + + // HLSL `shared` modifier (which is used by the effect system, + // and shouldn't be confused with `groupshared`) + class HLSLEffectSharedModifier : public Modifier {}; + + // HLSL `groupshared` modifier + class HLSLGroupSharedModifier : public Modifier {}; + + // HLSL `static` modifier (probably doesn't need to be + // treated as HLSL-specific) + class HLSLStaticModifier : public Modifier {}; + + // HLSL `uniform` modifier (distinct meaning from GLSL + // use of the keyword) + class HLSLUniformModifier : public Modifier {}; + + // HLSL `volatile` modifier (ignored) + class HLSLVolatileModifier : public Modifier {}; + + // An HLSL `[name(arg0, ...)]` style attribute. + class HLSLAttribute : public Modifier + { + public: + Token nameToken; + List<RefPtr<ExpressionSyntaxNode>> args; + }; + + // An HLSL `[name(...)]` attribute that hasn't undergone + // any semantic analysis. + // After analysis, this might be transformed into a more specific case. + class HLSLUncheckedAttribute : public HLSLAttribute + { + public: + }; + + // An HLSL `[numthreads(x,y,z)]` attribute + class HLSLNumThreadsAttribute : public HLSLAttribute + { + public: + // The number of threads to use along each axis + int32_t x; + int32_t y; + int32_t z; + }; + + // HLSL modifiers for geometry shader input topology + class HLSLGeometryShaderInputPrimitiveTypeModifier : public Modifier {}; + class HLSLPointModifier : public HLSLGeometryShaderInputPrimitiveTypeModifier {}; + class HLSLLineModifier : public HLSLGeometryShaderInputPrimitiveTypeModifier {}; + class HLSLTriangleModifier : public HLSLGeometryShaderInputPrimitiveTypeModifier {}; + class HLSLLineAdjModifier : public HLSLGeometryShaderInputPrimitiveTypeModifier {}; + class HLSLTriangleAdjModifier : public HLSLGeometryShaderInputPrimitiveTypeModifier {}; + + // + + // A generic declaration, parameterized on types/values + class GenericDecl : public ContainerDecl + { + public: + // The decl that is genericized... + RefPtr<Decl> inner; + + virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override; + }; + + struct GenericDeclRef : ContainerDeclRef + { + SLANG_DECLARE_DECL_REF(GenericDecl); + + Decl* GetInner() const { return GetDecl()->inner.Ptr(); } + }; + + // The "type" of an expression that names a generic declaration. + class GenericDeclRefType : public ExpressionType + { + public: + GenericDeclRefType(GenericDeclRef declRef) + : declRef(declRef) + {} + + GenericDeclRef declRef; + GenericDeclRef const& GetDeclRef() const { return declRef; } + + virtual String ToString() override; + + protected: + virtual bool EqualsImpl(ExpressionType * type) override; + virtual int GetHashCode() override; + virtual ExpressionType* CreateCanonicalType() override; + }; + + + + class GenericTypeParamDecl : public SimpleTypeDecl + { + public: + // The bound for the type parameter represents a trait that any + // type used as this parameter must conform to +// TypeExp bound; + + // The "initializer" for the parameter represents a default value + TypeExp initType; + + virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override; + }; + + struct GenericTypeParamDeclRef : SimpleTypeDeclRef + { + SLANG_DECLARE_DECL_REF(GenericTypeParamDecl); + }; + + // A constraint placed as part of a generic declaration + class GenericTypeConstraintDecl : public Decl + { + public: + // A type constraint like `T : U` is constraining `T` to be "below" `U` + // on a lattice of types. This may not be a subtyping relationship + // per se, but it makes sense to use that terminology here, so we + // think of these fields as the sub-type and sup-ertype, respectively. + TypeExp sub; + TypeExp sup; + + virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override; + }; + + struct GenericTypeConstraintDeclRef : DeclRef + { + SLANG_DECLARE_DECL_REF(GenericTypeConstraintDecl); + + RefPtr<ExpressionType> GetSub() { return Substitute(GetDecl()->sub); } + RefPtr<ExpressionType> GetSup() { return Substitute(GetDecl()->sup); } + }; + + + class GenericValueParamDecl : public VarDeclBase + { + public: + virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override; + }; + + struct GenericValueParamDeclRef : VarDeclBaseRef + { + SLANG_DECLARE_DECL_REF(GenericValueParamDecl); + }; + + // The logical "value" of a rererence to a generic value parameter + class GenericParamIntVal : public IntVal + { + public: + VarDeclBaseRef declRef; + + GenericParamIntVal(VarDeclBaseRef declRef) + : declRef(declRef) + {} + + virtual bool EqualsVal(Val* val) override; + virtual String ToString() override; + virtual int GetHashCode() override; + virtual RefPtr<Val> SubstituteImpl(Substitutions* subst, int* ioDiff) override; + }; + + // Declaration of a user-defined modifier + class ModifierDecl : public Decl + { + public: + // The name of the C++ class to instantiate + // (this is a reference to a class in the compiler source code, + // and not the user's source code) + Token classNameToken; + + virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override; + }; + + // An empty declaration (which might still have modifiers attached). + // + // An empty declaration is uncommon in HLSL, but + // in GLSL it is often used at the global scope + // to declare metadata that logically belongs + // to the entry point, e.g.: + // + // layout(local_size_x = 16) in; + // + class EmptyDecl : public Decl + { + public: + virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override; + }; + + // + + class SyntaxVisitor : public Object + { + protected: + DiagnosticSink * sink = nullptr; + DiagnosticSink* getSink() { return sink; } + + SourceLanguage sourceLanguage = SourceLanguage::Unknown; + public: + void setSourceLanguage(SourceLanguage language) + { + sourceLanguage = language; + } + + SyntaxVisitor(DiagnosticSink * sink) + : sink(sink) + {} + virtual RefPtr<ProgramSyntaxNode> VisitProgram(ProgramSyntaxNode* program) + { + for (auto & m : program->Members) + m = m->Accept(this).As<Decl>(); + return program; + } + + virtual RefPtr<UsingFileDecl> VisitUsingFileDecl(UsingFileDecl * decl) + { + return decl; + } + + virtual RefPtr<FunctionSyntaxNode> VisitFunction(FunctionSyntaxNode* func) + { + func->ReturnType = func->ReturnType.Accept(this); + for (auto & member : func->Members) + member = member->Accept(this).As<Decl>(); + if (func->Body) + func->Body = func->Body->Accept(this).As<BlockStatementSyntaxNode>(); + return func; + } + virtual RefPtr<ScopeDecl> VisitScopeDecl(ScopeDecl* decl) + { + // By default don't visit children, because they will always + // be encountered in the ordinary flow of the corresponding statement. + return decl; + } + virtual RefPtr<StructSyntaxNode> VisitStruct(StructSyntaxNode * s) + { + for (auto & f : s->Members) + f = f->Accept(this).As<Decl>(); + return s; + } + virtual RefPtr<ClassSyntaxNode> VisitClass(ClassSyntaxNode * s) + { + for (auto & f : s->Members) + f = f->Accept(this).As<Decl>(); + return s; + } + virtual RefPtr<GenericDecl> VisitGenericDecl(GenericDecl * decl) + { + for (auto & m : decl->Members) + m = m->Accept(this).As<Decl>(); + decl->inner = decl->inner->Accept(this).As<Decl>(); + return decl; + } + virtual RefPtr<TypeDefDecl> VisitTypeDefDecl(TypeDefDecl* decl) + { + decl->Type = decl->Type.Accept(this); + return decl; + } + virtual RefPtr<StatementSyntaxNode> VisitDiscardStatement(DiscardStatementSyntaxNode * stmt) + { + return stmt; + } + virtual RefPtr<StructField> VisitStructField(StructField * f) + { + f->Type = f->Type.Accept(this); + return f; + } + virtual RefPtr<StatementSyntaxNode> VisitBlockStatement(BlockStatementSyntaxNode* stmt) + { + for (auto & s : stmt->Statements) + s = s->Accept(this).As<StatementSyntaxNode>(); + return stmt; + } + virtual RefPtr<StatementSyntaxNode> VisitBreakStatement(BreakStatementSyntaxNode* stmt) + { + return stmt; + } + virtual RefPtr<StatementSyntaxNode> VisitContinueStatement(ContinueStatementSyntaxNode* stmt) + { + return stmt; + } + + virtual RefPtr<StatementSyntaxNode> VisitDoWhileStatement(DoWhileStatementSyntaxNode* stmt) + { + if (stmt->Predicate) + stmt->Predicate = stmt->Predicate->Accept(this).As<ExpressionSyntaxNode>(); + if (stmt->Statement) + stmt->Statement = stmt->Statement->Accept(this).As<StatementSyntaxNode>(); + return stmt; + } + virtual RefPtr<StatementSyntaxNode> VisitEmptyStatement(EmptyStatementSyntaxNode* stmt) + { + return stmt; + } + virtual RefPtr<StatementSyntaxNode> VisitForStatement(ForStatementSyntaxNode* stmt) + { + if (stmt->InitialStatement) + stmt->InitialStatement = stmt->InitialStatement->Accept(this).As<StatementSyntaxNode>(); + if (stmt->PredicateExpression) + stmt->PredicateExpression = stmt->PredicateExpression->Accept(this).As<ExpressionSyntaxNode>(); + if (stmt->SideEffectExpression) + stmt->SideEffectExpression = stmt->SideEffectExpression->Accept(this).As<ExpressionSyntaxNode>(); + if (stmt->Statement) + stmt->Statement = stmt->Statement->Accept(this).As<StatementSyntaxNode>(); + return stmt; + } + virtual RefPtr<StatementSyntaxNode> VisitIfStatement(IfStatementSyntaxNode* stmt) + { + if (stmt->Predicate) + stmt->Predicate = stmt->Predicate->Accept(this).As<ExpressionSyntaxNode>(); + if (stmt->PositiveStatement) + stmt->PositiveStatement = stmt->PositiveStatement->Accept(this).As<StatementSyntaxNode>(); + if (stmt->NegativeStatement) + stmt->NegativeStatement = stmt->NegativeStatement->Accept(this).As<StatementSyntaxNode>(); + return stmt; + } + virtual RefPtr<SwitchStmt> VisitSwitchStmt(SwitchStmt* stmt) + { + if (stmt->condition) + stmt->condition = stmt->condition->Accept(this).As<ExpressionSyntaxNode>(); + if (stmt->body) + stmt->body = stmt->body->Accept(this).As<BlockStatementSyntaxNode>(); + return stmt; + } + virtual RefPtr<CaseStmt> VisitCaseStmt(CaseStmt* stmt) + { + if (stmt->expr) + stmt->expr = stmt->expr->Accept(this).As<ExpressionSyntaxNode>(); + return stmt; + } + virtual RefPtr<DefaultStmt> VisitDefaultStmt(DefaultStmt* stmt) + { + return stmt; + } + virtual RefPtr<StatementSyntaxNode> VisitReturnStatement(ReturnStatementSyntaxNode* stmt) + { + if (stmt->Expression) + stmt->Expression = stmt->Expression->Accept(this).As<ExpressionSyntaxNode>(); + return stmt; + } + virtual RefPtr<StatementSyntaxNode> VisitVarDeclrStatement(VarDeclrStatementSyntaxNode* stmt) + { + stmt->decl = stmt->decl->Accept(this).As<DeclBase>(); + return stmt; + } + virtual RefPtr<StatementSyntaxNode> VisitWhileStatement(WhileStatementSyntaxNode* stmt) + { + if (stmt->Predicate) + stmt->Predicate = stmt->Predicate->Accept(this).As<ExpressionSyntaxNode>(); + if (stmt->Statement) + stmt->Statement = stmt->Statement->Accept(this).As<StatementSyntaxNode>(); + return stmt; + } + virtual RefPtr<StatementSyntaxNode> VisitExpressionStatement(ExpressionStatementSyntaxNode* stmt) + { + if (stmt->Expression) + stmt->Expression = stmt->Expression->Accept(this).As<ExpressionSyntaxNode>(); + return stmt; + } + + virtual RefPtr<ExpressionSyntaxNode> VisitOperatorExpression(OperatorExpressionSyntaxNode* expr) + { + for (auto && child : expr->Arguments) + child->Accept(this); + return expr; + } + virtual RefPtr<ExpressionSyntaxNode> VisitConstantExpression(ConstantExpressionSyntaxNode* expr) + { + return expr; + } + virtual RefPtr<ExpressionSyntaxNode> VisitIndexExpression(IndexExpressionSyntaxNode* expr) + { + if (expr->BaseExpression) + expr->BaseExpression = expr->BaseExpression->Accept(this).As<ExpressionSyntaxNode>(); + if (expr->IndexExpression) + expr->IndexExpression = expr->IndexExpression->Accept(this).As<ExpressionSyntaxNode>(); + return expr; + } + virtual RefPtr<ExpressionSyntaxNode> VisitMemberExpression(MemberExpressionSyntaxNode * stmt) + { + if (stmt->BaseExpression) + stmt->BaseExpression = stmt->BaseExpression->Accept(this).As<ExpressionSyntaxNode>(); + return stmt; + } + virtual RefPtr<ExpressionSyntaxNode> VisitSwizzleExpression(SwizzleExpr * expr) + { + if (expr->base) + expr->base->Accept(this); + return expr; + } + virtual RefPtr<ExpressionSyntaxNode> VisitInvokeExpression(InvokeExpressionSyntaxNode* stmt) + { + stmt->FunctionExpr->Accept(this); + for (auto & arg : stmt->Arguments) + arg = arg->Accept(this).As<ExpressionSyntaxNode>(); + return stmt; + } + virtual RefPtr<ExpressionSyntaxNode> VisitTypeCastExpression(TypeCastExpressionSyntaxNode * stmt) + { + if (stmt->Expression) + stmt->Expression = stmt->Expression->Accept(this).As<ExpressionSyntaxNode>(); + return stmt->Expression; + } + virtual RefPtr<ExpressionSyntaxNode> VisitVarExpression(VarExpressionSyntaxNode* expr) + { + return expr; + } + + virtual RefPtr<ParameterSyntaxNode> VisitParameter(ParameterSyntaxNode* param) + { + return param; + } + virtual RefPtr<ExpressionSyntaxNode> VisitGenericApp(GenericAppExpr* type) + { + return type; + } + + virtual RefPtr<Variable> VisitDeclrVariable(Variable* dclr) + { + if (dclr->Expr) + dclr->Expr = dclr->Expr->Accept(this).As<ExpressionSyntaxNode>(); + return dclr; + } + + virtual TypeExp VisitTypeExp(TypeExp const& typeExp) + { + TypeExp result = typeExp; + result.exp = typeExp.exp->Accept(this).As<ExpressionSyntaxNode>(); + if (auto typeType = result.exp->Type.type.As<TypeType>()) + { + result.type = typeType->type; + } + return result; + } + + virtual void VisitExtensionDecl(ExtensionDecl* /*decl*/) + {} + + virtual void VisitConstructorDecl(ConstructorDecl* /*decl*/) + {} + + virtual void visitSubscriptDecl(SubscriptDecl* decl) = 0; + + virtual void visitAccessorDecl(AccessorDecl* decl) = 0; + + virtual void VisitTraitDecl(TraitDecl* /*decl*/) + {} + + virtual void VisitTraitConformanceDecl(TraitConformanceDecl* /*decl*/) + {} + + virtual RefPtr<ExpressionSyntaxNode> VisitSharedTypeExpr(SharedTypeExpr* typeExpr) + { + return typeExpr; + } + + virtual void VisitDeclGroup(DeclGroup* declGroup) + { + for (auto decl : declGroup->decls) + { + decl->Accept(this); + } + } + + virtual RefPtr<ExpressionSyntaxNode> visitInitializerListExpr(InitializerListExpr* expr) = 0; + }; + + // Note(tfoley): These logically belong to `ExpressionType`, + // but order-of-declaration stuff makes that tricky + // + // TODO(tfoley): These should really belong to the compilation context! + // + void RegisterBuiltinDecl( + RefPtr<Decl> decl, + RefPtr<BuiltinTypeModifier> modifier); + void RegisterMagicDecl( + RefPtr<Decl> decl, + RefPtr<MagicTypeModifier> modifier); + + // Look up a magic declaration by its name + RefPtr<Decl> findMagicDecl( + String const& name); + + // Create an instance of a syntax class by name + SyntaxNodeBase* createInstanceOfSyntaxClassByName( + String const& name); + + } +} + +#endif
\ No newline at end of file diff --git a/source/slang/token-defs.h b/source/slang/token-defs.h new file mode 100644 index 000000000..f29574bbb --- /dev/null +++ b/source/slang/token-defs.h @@ -0,0 +1,93 @@ +// token-defs.h + +// This file is meant to be included multiple times, to produce different +// pieces of code related to tokens +// +// Each token is declared here with: +// +// TOKEN(id, desc) +// +// where `id` is the identifier that will be used for the token in +// ordinary code, while `desc` is name we should print when +// referring to this token in diagnostic messages. + + +#ifndef TOKEN +#error Need to define TOKEN(ID, DESC) before including "token-defs.h" +#endif + +TOKEN(Unknown, "<unknown>") +TOKEN(EndOfFile, "end of file") +TOKEN(EndOfDirective, "end of line") +TOKEN(Invalid, "invalid character") +TOKEN(Identifier, "identifier") +TOKEN(IntLiterial, "integer literal") +TOKEN(DoubleLiterial, "floating-point literal") +TOKEN(StringLiterial, "string literal") +TOKEN(CharLiterial, "character literal") +TOKEN(WhiteSpace, "whitespace") +TOKEN(NewLine, "newline") +TOKEN(LineComment, "line comment") +TOKEN(BlockComment, "block comment") + +#define PUNCTUATION(id, text) \ + TOKEN(id, "'" text "'") + +PUNCTUATION(Semicolon, ";") +PUNCTUATION(Comma, ",") +PUNCTUATION(Dot, ".") + +PUNCTUATION(LBrace, "{") +PUNCTUATION(RBrace, "}") +PUNCTUATION(LBracket, "[") +PUNCTUATION(RBracket, "]") +PUNCTUATION(LParent, "(") +PUNCTUATION(RParent, ")") + +PUNCTUATION(OpAssign, "=") +PUNCTUATION(OpAdd, "+") +PUNCTUATION(OpSub, "-") +PUNCTUATION(OpMul, "*") +PUNCTUATION(OpDiv, "/") +PUNCTUATION(OpMod, "%") +PUNCTUATION(OpNot, "!") +PUNCTUATION(OpBitNot, "~") +PUNCTUATION(OpLsh, "<<") +PUNCTUATION(OpRsh, ">>") +PUNCTUATION(OpEql, "==") +PUNCTUATION(OpNeq, "!=") +PUNCTUATION(OpGreater, ">") +PUNCTUATION(OpLess, "<") +PUNCTUATION(OpGeq, ">=") +PUNCTUATION(OpLeq, "<=") +PUNCTUATION(OpAnd, "&&") +PUNCTUATION(OpOr, "||") +PUNCTUATION(OpBitAnd, "&") +PUNCTUATION(OpBitOr, "|") +PUNCTUATION(OpBitXor, "^") +PUNCTUATION(OpInc, "++") +PUNCTUATION(OpDec, "--") + +PUNCTUATION(OpAddAssign, "+=") +PUNCTUATION(OpSubAssign, "-=") +PUNCTUATION(OpMulAssign, "*=") +PUNCTUATION(OpDivAssign, "/=") +PUNCTUATION(OpModAssign, "%=") +PUNCTUATION(OpShlAssign, "<<=") +PUNCTUATION(OpShrAssign, ">>=") +PUNCTUATION(OpAndAssign, "&=") +PUNCTUATION(OpOrAssign, "|=") +PUNCTUATION(OpXorAssign, "^=") + +PUNCTUATION(QuestionMark, "?") +PUNCTUATION(Colon, ":") +PUNCTUATION(RightArrow, "->") +PUNCTUATION(At, "@") +PUNCTUATION(Dollar, "$") +PUNCTUATION(Pound, "#") +PUNCTUATION(PoundPound, "##") + +#undef PUNCTUATION + +// Un-define the `TOKEN` macro so that client doesn't have to +#undef TOKEN diff --git a/source/slang/token.cpp b/source/slang/token.cpp new file mode 100644 index 000000000..436a3a740 --- /dev/null +++ b/source/slang/token.cpp @@ -0,0 +1,22 @@ +// token.cpp +#include "token.h" + +#include <assert.h> + +namespace Slang { +namespace Compiler { + +char const* TokenTypeToString(TokenType type) +{ + switch( type ) + { + default: + assert(!"unexpected"); + return "<uknown>"; + +#define TOKEN(NAME, DESC) case TokenType::NAME: return DESC; +#include "token-defs.h" + } +} + +}} diff --git a/source/slang/token.h b/source/slang/token.h new file mode 100644 index 000000000..00a55feb1 --- /dev/null +++ b/source/slang/token.h @@ -0,0 +1,50 @@ +// token.h +#ifndef SLANG_TOKEN_H_INCLUDED +#define SLANG_TOKEN_H_INCLUDED + +#include "../core/basic.h" + +#include "source-loc.h" + +namespace Slang { +namespace Compiler { + +using namespace CoreLib::Basic; + +enum class TokenType +{ +#define TOKEN(NAME, DESC) NAME, +#include "token-defs.h" +}; + +char const* TokenTypeToString(TokenType type); + +enum TokenFlag : unsigned int +{ + AtStartOfLine = 1 << 0, + AfterWhitespace = 1 << 1, +}; +typedef unsigned int TokenFlags; + +class Token +{ +public: + TokenType Type = TokenType::Unknown; + String Content; + CodePosition Position; + TokenFlags flags = 0; + Token() = default; + Token(TokenType type, const String & content, int line, int col, int pos, String fileName, TokenFlags flags = 0) + : flags(flags) + { + Type = type; + Content = content; + Position = CodePosition(line, col, pos, fileName); + } +}; + + + +}} + +#endif diff --git a/source/slang/type-layout.cpp b/source/slang/type-layout.cpp new file mode 100644 index 000000000..4e5c98ed5 --- /dev/null +++ b/source/slang/type-layout.cpp @@ -0,0 +1,1125 @@ +// TypeLayout.cpp +#include "type-layout.h" + +#include "syntax.h" + +#include <assert.h> + +namespace Slang { +namespace Compiler { + +size_t RoundToAlignment(size_t offset, size_t alignment) +{ + size_t remainder = offset % alignment; + if (remainder == 0) + return offset; + else + return offset + (alignment - remainder); +} + +static size_t RoundUpToPowerOfTwo( size_t value ) +{ + // TODO(tfoley): I know this isn't a fast approach + size_t result = 1; + while (result < value) + result *= 2; + return result; +} + +struct DefaultLayoutRulesImpl : SimpleLayoutRulesImpl +{ + // Get size and alignment for a single value of base type. + SimpleLayoutInfo GetScalarLayout(BaseType baseType) override + { + switch (baseType) + { + case BaseType::Int: + case BaseType::UInt: + case BaseType::Float: + case BaseType::Bool: + return SimpleLayoutInfo( LayoutResourceKind::Uniform, 4, 4 ); + + default: + assert(!"unimplemented"); + return SimpleLayoutInfo( LayoutResourceKind::Uniform, 0, 1 ); + } + } + + virtual SimpleLayoutInfo GetScalarLayout(slang::TypeReflection::ScalarType scalarType) + { + switch( scalarType ) + { + case slang::TypeReflection::ScalarType::Void: return SimpleLayoutInfo(); + case slang::TypeReflection::ScalarType::None: return SimpleLayoutInfo(); + + // TODO(tfoley): At some point we don't want to lay out `bool` as 4 bytes by default... + case slang::TypeReflection::ScalarType::Bool: return SimpleLayoutInfo( LayoutResourceKind::Uniform, 4,4); + case slang::TypeReflection::ScalarType::Int32: return SimpleLayoutInfo( LayoutResourceKind::Uniform, 4,4); + case slang::TypeReflection::ScalarType::UInt32: return SimpleLayoutInfo( LayoutResourceKind::Uniform, 4,4); + case slang::TypeReflection::ScalarType::Int64: return SimpleLayoutInfo( LayoutResourceKind::Uniform, 8,8); + case slang::TypeReflection::ScalarType::UInt64: return SimpleLayoutInfo( LayoutResourceKind::Uniform, 8,8); + + // TODO(tfoley): What actually happens if you use `half` in a constant buffer? + case slang::TypeReflection::ScalarType::Float16: return SimpleLayoutInfo( LayoutResourceKind::Uniform, 2,2); + case slang::TypeReflection::ScalarType::Float32: return SimpleLayoutInfo( LayoutResourceKind::Uniform, 4,4); + case slang::TypeReflection::ScalarType::Float64: return SimpleLayoutInfo( LayoutResourceKind::Uniform, 8,8); + + default: + assert(!"unimplemented"); + return SimpleLayoutInfo(); + } + } + + SimpleArrayLayoutInfo GetArrayLayout( SimpleLayoutInfo elementInfo, size_t elementCount) override + { + size_t stride = elementInfo.size; + + SimpleArrayLayoutInfo arrayInfo; + arrayInfo.kind = elementInfo.kind; + arrayInfo.size = stride * elementCount; + arrayInfo.alignment = elementInfo.alignment; + arrayInfo.elementStride = stride; + return arrayInfo; + } + + SimpleLayoutInfo GetVectorLayout(SimpleLayoutInfo elementInfo, size_t elementCount) override + { + SimpleLayoutInfo vectorInfo; + vectorInfo.kind = elementInfo.kind; + vectorInfo.size = elementInfo.size * elementCount; + vectorInfo.alignment = elementInfo.alignment; + return vectorInfo; + } + + SimpleLayoutInfo GetMatrixLayout(SimpleLayoutInfo elementInfo, size_t rowCount, size_t columnCount) override + { + return GetArrayLayout( + GetVectorLayout(elementInfo, columnCount), + rowCount); + } + + UniformLayoutInfo BeginStructLayout() override + { + UniformLayoutInfo structInfo(0, 1); + return structInfo; + } + + size_t AddStructField(UniformLayoutInfo* ioStructInfo, UniformLayoutInfo fieldInfo) override + { + // Skip zero-size fields + if(fieldInfo.size == 0) + return ioStructInfo->size; + + ioStructInfo->alignment = std::max(ioStructInfo->alignment, fieldInfo.alignment); + ioStructInfo->size = RoundToAlignment(ioStructInfo->size, fieldInfo.alignment); + size_t fieldOffset = ioStructInfo->size; + ioStructInfo->size += fieldInfo.size; + return fieldOffset; + } + + + void EndStructLayout(UniformLayoutInfo* ioStructInfo) override + { + ioStructInfo->size = RoundToAlignment(ioStructInfo->size, ioStructInfo->alignment); + } +}; + +// Capture common behavior betwen HLSL and GLSL (`std140`) constnat buffer rules +struct DefaultConstantBufferLayoutRulesImpl : DefaultLayoutRulesImpl +{ + // The `std140` rules require that all array elements + // be a multiple of 16 bytes. + // + // HLSL agrees. + SimpleArrayLayoutInfo GetArrayLayout(SimpleLayoutInfo elementInfo, size_t elementCount) override + { + if(elementInfo.kind == LayoutResourceKind::Uniform) + { + if (elementInfo.alignment < 16) + elementInfo.alignment = 16; + elementInfo.size = RoundToAlignment(elementInfo.size, elementInfo.alignment); + } + return DefaultLayoutRulesImpl::GetArrayLayout(elementInfo, elementCount); + } + + // The `std140` rules require that a `struct` type be + // aligned to at least 16. + // + // HLSL agrees. + UniformLayoutInfo BeginStructLayout() override + { + return UniformLayoutInfo(0, 16); + } +}; + +struct GLSLConstantBufferLayoutRulesImpl : DefaultConstantBufferLayoutRulesImpl +{ +}; + +struct Std140LayoutRulesImpl : GLSLConstantBufferLayoutRulesImpl +{ + // The `std140` rules require vectors to be aligned to the next power of two + // up from their size (so a `float2` is 8-byte aligned, and a `float3` is + // 16-byte aligned). + SimpleLayoutInfo GetVectorLayout(SimpleLayoutInfo elementInfo, size_t elementCount) override + { + assert(elementInfo.kind == LayoutResourceKind::Uniform); + SimpleLayoutInfo vectorInfo( + LayoutResourceKind::Uniform, + elementInfo.size * elementCount, + RoundUpToPowerOfTwo(elementInfo.size * elementInfo.alignment)); + return vectorInfo; + } +}; + +struct HLSLConstantBufferLayoutRulesImpl : DefaultConstantBufferLayoutRulesImpl +{ + // Can't let a `struct` field straddle a register (16-byte) boundary + size_t AddStructField(UniformLayoutInfo* ioStructInfo, UniformLayoutInfo fieldInfo) override + { + // Skip zero-size fields + if(fieldInfo.size == 0) + return ioStructInfo->size; + + ioStructInfo->alignment = std::max(ioStructInfo->alignment, fieldInfo.alignment); + ioStructInfo->size = RoundToAlignment(ioStructInfo->size, fieldInfo.alignment); + + size_t fieldOffset = ioStructInfo->size; + size_t fieldSize = fieldInfo.size; + + // Would this field cross a 16-byte boundary? + auto registerSize = 16; + auto startRegister = fieldOffset / registerSize; + auto endRegister = (fieldOffset + fieldSize - 1) / registerSize; + if (startRegister != endRegister) + { + ioStructInfo->size = RoundToAlignment(ioStructInfo->size, size_t(registerSize)); + fieldOffset = ioStructInfo->size; + } + + ioStructInfo->size += fieldInfo.size; + return fieldOffset; + } +}; + +struct HLSLStructuredBufferLayoutRulesImpl : DefaultLayoutRulesImpl +{ + // TODO: customize these to be correct... +}; + +struct Std430LayoutRulesImpl : GLSLConstantBufferLayoutRulesImpl +{ +}; + +struct DefaultVaryingLayoutRulesImpl : DefaultLayoutRulesImpl +{ + LayoutResourceKind kind; + + DefaultVaryingLayoutRulesImpl(LayoutResourceKind kind) + : kind(kind) + {} + + + // hook to allow differentiating for input/output + virtual LayoutResourceKind getKind() + { + return kind; + } + + SimpleLayoutInfo GetScalarLayout(BaseType baseType) override + { + // Assume that all scalars take up one "slot" + return SimpleLayoutInfo( + getKind(), + 1); + } + + virtual SimpleLayoutInfo GetScalarLayout(slang::TypeReflection::ScalarType scalarType) + { + // Assume that all scalars take up one "slot" + return SimpleLayoutInfo( + getKind(), + 1); + } + + SimpleLayoutInfo GetVectorLayout(SimpleLayoutInfo elementInfo, size_t elementCount) override + { + // Vectors take up one slot by default + // + // TODO: some platforms may decide that vectors of `double` need + // special handling + return SimpleLayoutInfo( + getKind(), + 1); + } +}; + +struct GLSLVaryingLayoutRulesImpl : DefaultVaryingLayoutRulesImpl +{ + GLSLVaryingLayoutRulesImpl(LayoutResourceKind kind) + : DefaultVaryingLayoutRulesImpl(kind) + {} +}; + +struct HLSLVaryingLayoutRulesImpl : DefaultVaryingLayoutRulesImpl +{ + HLSLVaryingLayoutRulesImpl(LayoutResourceKind kind) + : DefaultVaryingLayoutRulesImpl(kind) + {} +}; + +// + +struct GLSLSpecializationConstantLayoutRulesImpl : DefaultLayoutRulesImpl +{ + LayoutResourceKind getKind() + { + return LayoutResourceKind::SpecializationConstant; + } + + SimpleLayoutInfo GetScalarLayout(BaseType baseType) override + { + // Assume that all scalars take up one "slot" + return SimpleLayoutInfo( + getKind(), + 1); + } + + virtual SimpleLayoutInfo GetScalarLayout(slang::TypeReflection::ScalarType scalarType) + { + // Assume that all scalars take up one "slot" + return SimpleLayoutInfo( + getKind(), + 1); + } + + SimpleLayoutInfo GetVectorLayout(SimpleLayoutInfo elementInfo, size_t elementCount) override + { + // GLSL doesn't support vectors of specialization constants, + // but we will assume that, if supported, they would use one slot per element. + return SimpleLayoutInfo( + getKind(), + elementCount); + } +}; + +GLSLSpecializationConstantLayoutRulesImpl kGLSLSpecializationConstantLayoutRulesImpl; + +// + +struct GLSLObjectLayoutRulesImpl : ObjectLayoutRulesImpl +{ + virtual SimpleLayoutInfo GetObjectLayout(ShaderParameterKind kind) override + { + // In Vulkan GLSL, pretty much every object is just a descriptor-table slot. + // We can refine this method once we support a case where this isn't true. + return SimpleLayoutInfo(LayoutResourceKind::DescriptorTableSlot, 1); + } +}; +GLSLObjectLayoutRulesImpl kGLSLObjectLayoutRulesImpl; + +struct HLSLObjectLayoutRulesImpl : ObjectLayoutRulesImpl +{ + virtual SimpleLayoutInfo GetObjectLayout(ShaderParameterKind kind) override + { + switch( kind ) + { + case ShaderParameterKind::ConstantBuffer: + return SimpleLayoutInfo(LayoutResourceKind::ConstantBuffer, 1); + + case ShaderParameterKind::TextureUniformBuffer: + case ShaderParameterKind::StructuredBuffer: + case ShaderParameterKind::SampledBuffer: + case ShaderParameterKind::RawBuffer: + case ShaderParameterKind::Buffer: + case ShaderParameterKind::Texture: + return SimpleLayoutInfo(LayoutResourceKind::ShaderResource, 1); + + case ShaderParameterKind::MutableStructuredBuffer: + case ShaderParameterKind::MutableSampledBuffer: + case ShaderParameterKind::MutableRawBuffer: + case ShaderParameterKind::MutableBuffer: + case ShaderParameterKind::MutableTexture: + return SimpleLayoutInfo(LayoutResourceKind::UnorderedAccess, 1); + + case ShaderParameterKind::SamplerState: + return SimpleLayoutInfo(LayoutResourceKind::SamplerState, 1); + + case ShaderParameterKind::TextureSampler: + case ShaderParameterKind::MutableTextureSampler: + case ShaderParameterKind::InputRenderTarget: + // TODO: how to handle these? + default: + assert(!"unimplemented"); + return SimpleLayoutInfo(); + } + } +}; +HLSLObjectLayoutRulesImpl kHLSLObjectLayoutRulesImpl; + +Std140LayoutRulesImpl kStd140LayoutRulesImpl; +Std430LayoutRulesImpl kStd430LayoutRulesImpl; +HLSLConstantBufferLayoutRulesImpl kHLSLConstantBufferLayoutRulesImpl; +HLSLStructuredBufferLayoutRulesImpl kHLSLStructuredBufferLayoutRulesImpl; + +GLSLVaryingLayoutRulesImpl kGLSLVaryingInputLayoutRulesImpl(LayoutResourceKind::VertexInput); +GLSLVaryingLayoutRulesImpl kGLSLVaryingOutputLayoutRulesImpl(LayoutResourceKind::FragmentOutput); + +HLSLVaryingLayoutRulesImpl kHLSLVaryingInputLayoutRulesImpl(LayoutResourceKind::VertexInput); +HLSLVaryingLayoutRulesImpl kHLSLVaryingOutputLayoutRulesImpl(LayoutResourceKind::FragmentOutput); + +// + +struct GLSLLayoutRulesFamilyImpl : LayoutRulesFamilyImpl +{ + virtual LayoutRulesImpl* getConstantBufferRules() override; + virtual LayoutRulesImpl* getTextureBufferRules() override; + virtual LayoutRulesImpl* getVaryingInputRules() override; + virtual LayoutRulesImpl* getVaryingOutputRules() override; + virtual LayoutRulesImpl* getSpecializationConstantRules() override; + virtual LayoutRulesImpl* getShaderStorageBufferRules() override; +}; + +struct HLSLLayoutRulesFamilyImpl : LayoutRulesFamilyImpl +{ + virtual LayoutRulesImpl* getConstantBufferRules() override; + virtual LayoutRulesImpl* getTextureBufferRules() override; + virtual LayoutRulesImpl* getVaryingInputRules() override; + virtual LayoutRulesImpl* getVaryingOutputRules() override; + virtual LayoutRulesImpl* getSpecializationConstantRules() override; + virtual LayoutRulesImpl* getShaderStorageBufferRules() override; +}; + +GLSLLayoutRulesFamilyImpl kGLSLLayoutRulesFamilyImpl; +HLSLLayoutRulesFamilyImpl kHLSLLayoutRulesFamilyImpl; + + +// GLSL cases + +LayoutRulesImpl kStd140LayoutRulesImpl_ = { + &kGLSLLayoutRulesFamilyImpl, &kStd140LayoutRulesImpl, &kGLSLObjectLayoutRulesImpl, +}; + +LayoutRulesImpl kStd430LayoutRulesImpl_ = { + &kGLSLLayoutRulesFamilyImpl, &kStd430LayoutRulesImpl, &kGLSLObjectLayoutRulesImpl, +}; + +LayoutRulesImpl kGLSLVaryingInputLayoutRulesImpl_ = { + &kGLSLLayoutRulesFamilyImpl, &kGLSLVaryingInputLayoutRulesImpl, &kGLSLObjectLayoutRulesImpl, +}; + +LayoutRulesImpl kGLSLVaryingOutputLayoutRulesImpl_ = { + &kGLSLLayoutRulesFamilyImpl, &kGLSLVaryingOutputLayoutRulesImpl, &kGLSLObjectLayoutRulesImpl, +}; + +LayoutRulesImpl kGLSLSpecializationConstantLayoutRulesImpl_ = { + &kGLSLLayoutRulesFamilyImpl, &kGLSLSpecializationConstantLayoutRulesImpl, &kGLSLObjectLayoutRulesImpl, +}; + +// HLSL cases + +LayoutRulesImpl kHLSLConstantBufferLayoutRulesImpl_ = { + &kHLSLLayoutRulesFamilyImpl, &kHLSLConstantBufferLayoutRulesImpl, &kHLSLObjectLayoutRulesImpl, +}; + +LayoutRulesImpl kHLSLStructuredBufferLayoutRulesImpl_ = { + &kHLSLLayoutRulesFamilyImpl, &kHLSLStructuredBufferLayoutRulesImpl, &kHLSLObjectLayoutRulesImpl, +}; + +LayoutRulesImpl kHLSLVaryingInputLayoutRulesImpl_ = { + &kHLSLLayoutRulesFamilyImpl, &kHLSLVaryingInputLayoutRulesImpl, &kHLSLObjectLayoutRulesImpl, +}; + +LayoutRulesImpl kHLSLVaryingOutputLayoutRulesImpl_ = { + &kHLSLLayoutRulesFamilyImpl, &kHLSLVaryingOutputLayoutRulesImpl, &kHLSLObjectLayoutRulesImpl, +}; + +// + +LayoutRulesImpl* GLSLLayoutRulesFamilyImpl::getConstantBufferRules() +{ + return &kStd140LayoutRulesImpl_; +} + +LayoutRulesImpl* GLSLLayoutRulesFamilyImpl::getTextureBufferRules() +{ + return nullptr; +} + +LayoutRulesImpl* GLSLLayoutRulesFamilyImpl::getVaryingInputRules() +{ + return &kGLSLVaryingInputLayoutRulesImpl_; +} + +LayoutRulesImpl* GLSLLayoutRulesFamilyImpl::getVaryingOutputRules() +{ + return &kGLSLVaryingOutputLayoutRulesImpl_; +} + +LayoutRulesImpl* GLSLLayoutRulesFamilyImpl::getSpecializationConstantRules() +{ + return &kGLSLSpecializationConstantLayoutRulesImpl_; +} + +LayoutRulesImpl* GLSLLayoutRulesFamilyImpl::getShaderStorageBufferRules() +{ + return &kStd430LayoutRulesImpl_; +} + +// + +LayoutRulesImpl* HLSLLayoutRulesFamilyImpl::getConstantBufferRules() +{ + return &kHLSLConstantBufferLayoutRulesImpl_; +} + +LayoutRulesImpl* HLSLLayoutRulesFamilyImpl::getTextureBufferRules() +{ + return nullptr; +} + +LayoutRulesImpl* HLSLLayoutRulesFamilyImpl::getVaryingInputRules() +{ + return &kHLSLVaryingInputLayoutRulesImpl_; +} + +LayoutRulesImpl* HLSLLayoutRulesFamilyImpl::getVaryingOutputRules() +{ + return &kHLSLVaryingOutputLayoutRulesImpl_; +} + +LayoutRulesImpl* HLSLLayoutRulesFamilyImpl::getSpecializationConstantRules() +{ + return nullptr; +} + +LayoutRulesImpl* HLSLLayoutRulesFamilyImpl::getShaderStorageBufferRules() +{ + return nullptr; +} + +// + +LayoutRulesImpl* GetLayoutRulesImpl(LayoutRule rule) +{ + switch (rule) + { + case LayoutRule::Std140: return &kStd140LayoutRulesImpl_; + case LayoutRule::Std430: return &kStd430LayoutRulesImpl_; + case LayoutRule::HLSLConstantBuffer: return &kHLSLConstantBufferLayoutRulesImpl_; + case LayoutRule::HLSLStructuredBuffer: return &kHLSLStructuredBufferLayoutRulesImpl_; + default: + return nullptr; + } +} + +LayoutRulesFamilyImpl* GetLayoutRulesFamilyImpl(LayoutRulesFamily rule) +{ + switch (rule) + { + case LayoutRulesFamily::HLSL: return &kHLSLLayoutRulesFamilyImpl; + case LayoutRulesFamily::GLSL: return &kGLSLLayoutRulesFamilyImpl; + default: + return nullptr; + } +} + +LayoutRulesFamilyImpl* GetLayoutRulesFamilyImpl(SourceLanguage language) +{ + switch (language) + { + case SourceLanguage::Slang: + case SourceLanguage::HLSL: + return &kHLSLLayoutRulesFamilyImpl; + + case SourceLanguage::GLSL: + return &kGLSLLayoutRulesFamilyImpl; + + default: + return nullptr; + } +} + + +static int GetElementCount(RefPtr<IntVal> val) +{ + if (auto constantVal = val.As<ConstantIntVal>()) + { + return constantVal->value; + } + else if( auto varRefVal = val.As<GenericParamIntVal>() ) + { + // TODO(tfoley): do something sensible in this case + return 0; + } + assert(!"unexpected"); + return 0; +} + +bool IsResourceKind(LayoutResourceKind kind) +{ + switch (kind) + { + case LayoutResourceKind::None: + case LayoutResourceKind::Uniform: + return false; + + default: + return true; + } + +} + +SimpleLayoutInfo GetSimpleLayoutImpl( + SimpleLayoutInfo info, + RefPtr<ExpressionType> type, + LayoutRulesImpl* rules, + RefPtr<TypeLayout>* outTypeLayout) +{ + if (outTypeLayout) + { + RefPtr<TypeLayout> typeLayout = new TypeLayout(); + *outTypeLayout = typeLayout; + + typeLayout->type = type; + typeLayout->rules = rules; + + typeLayout->uniformAlignment = info.alignment; + + typeLayout->addResourceUsage(info.kind, info.size); + } + + return info; +} + +static SimpleLayoutInfo getParameterBlockLayoutInfo( + RefPtr<ParameterBlockType> type, + LayoutRulesImpl* rules) +{ + if( type->As<ConstantBufferType>() ) + { + return rules->GetObjectLayout(ShaderParameterKind::ConstantBuffer); + } + else if( type->As<TextureBufferType>() ) + { + return rules->GetObjectLayout(ShaderParameterKind::TextureUniformBuffer); + } + else if( type->As<GLSLShaderStorageBufferType>() ) + { + return rules->GetObjectLayout(ShaderParameterKind::ShaderStorageBuffer); + } + // TODO: the vertex-input and fragment-output cases should + // only actually apply when we are at the appropriate stage in + // the pipeline... + else if( type->As<GLSLInputParameterBlockType>() ) + { + return SimpleLayoutInfo(LayoutResourceKind::VertexInput, 0); + } + else if( type->As<GLSLOutputParameterBlockType>() ) + { + return SimpleLayoutInfo(LayoutResourceKind::FragmentOutput, 0); + } + else + { + assert(!"unexpected"); + return SimpleLayoutInfo(); + } +} + + +RefPtr<ParameterBlockTypeLayout> +createParameterBlockTypeLayout( + RefPtr<ParameterBlockType> parameterBlockType, + RefPtr<TypeLayout> elementTypeLayout, + LayoutRulesImpl* rules) +{ + auto info = getParameterBlockLayoutInfo( + parameterBlockType, + rules); + + auto typeLayout = new ParameterBlockTypeLayout(); + + typeLayout->type = parameterBlockType; + typeLayout->rules = rules; + + typeLayout->elementTypeLayout = elementTypeLayout; + + // The layout of the constant buffer if it gets stored + // in another constant buffer is just what we computed + // originally (which should be a single binding "slot" + // and hence no uniform data). + // + typeLayout->uniformAlignment = info.alignment; + assert(!typeLayout->FindResourceInfo(LayoutResourceKind::Uniform)); + assert(typeLayout->uniformAlignment == 1); + + // TODO(tfoley): There is a subtle question here of whether + // a constant buffer declaration that then contains zero + // bytes of uniform data should actually allocate a CB + // binding slot. For now I'm going to try to ignore it, + // but handling this robustly could let other code + // simply handle the "global scope" as a giant outer + // CB declaration... + + // Make sure that we allocate resource usage for the + // parameter block itself. + if( info.size ) + { + typeLayout->addResourceUsage( + info.kind, + info.size); + } + + // Now, if the element type itself had any resources, then + // we need to make these part of the layout for our block + // + // TODO: re-consider this decision, since it creates + // complications... + for( auto elementResourceInfo : elementTypeLayout->resourceInfos ) + { + // Skip uniform data, since that is encapsualted behind the constant buffer + if(elementResourceInfo.kind == LayoutResourceKind::Uniform) + break; + + typeLayout->addResourceUsage(elementResourceInfo); + } + + return typeLayout; +} + +LayoutRulesImpl* getParameterBufferElementTypeLayoutRules( + RefPtr<ParameterBlockType> parameterBlockType, + LayoutRulesImpl* rules) +{ + if( parameterBlockType->As<ConstantBufferType>() ) + { + return rules->getLayoutRulesFamily()->getConstantBufferRules(); + } + else if( parameterBlockType->As<TextureBufferType>() ) + { + return rules->getLayoutRulesFamily()->getTextureBufferRules(); + } + else if( parameterBlockType->As<GLSLInputParameterBlockType>() ) + { + return rules->getLayoutRulesFamily()->getVaryingInputRules(); + } + else if( parameterBlockType->As<GLSLOutputParameterBlockType>() ) + { + return rules->getLayoutRulesFamily()->getVaryingOutputRules(); + } + else if( parameterBlockType->As<GLSLShaderStorageBufferType>() ) + { + return rules->getLayoutRulesFamily()->getShaderStorageBufferRules(); + } + else + { + assert(!"unexpected"); + return nullptr; + } +} + +RefPtr<ParameterBlockTypeLayout> +createParameterBlockTypeLayout( + RefPtr<ParameterBlockType> parameterBlockType, + LayoutRulesImpl* rules) +{ + // Determine the layout rules to use for the contents of the block + auto parameterBlockLayoutRules = getParameterBufferElementTypeLayoutRules( + parameterBlockType, + rules); + + // Create and save type layout for the buffer contents. + auto elementTypeLayout = CreateTypeLayout( + parameterBlockType->elementType.Ptr(), + parameterBlockLayoutRules); + + return createParameterBlockTypeLayout( + parameterBlockType, + elementTypeLayout, + rules); +} + +// Create a type layout for a structured buffer type. +RefPtr<StructuredBufferTypeLayout> +createStructuredBufferTypeLayout( + ShaderParameterKind kind, + RefPtr<ExpressionType> structuredBufferType, + RefPtr<TypeLayout> elementTypeLayout, + LayoutRulesImpl* rules) +{ + auto info = rules->GetObjectLayout(kind); + + auto typeLayout = new StructuredBufferTypeLayout(); + + typeLayout->type = structuredBufferType; + typeLayout->rules = rules; + + typeLayout->elementTypeLayout = elementTypeLayout; + + typeLayout->uniformAlignment = info.alignment; + assert(!typeLayout->FindResourceInfo(LayoutResourceKind::Uniform)); + assert(typeLayout->uniformAlignment == 1); + + if( info.size != 0 ) + { + typeLayout->addResourceUsage(info.kind, info.size); + } + + // Note: for now we don't deal with the case of a structured + // buffer that might contain anything other than "uniform" data, + // because there really isn't a way to implement that. + + return typeLayout; +} + +// Create a type layout for a structured buffer type. +RefPtr<StructuredBufferTypeLayout> +createStructuredBufferTypeLayout( + ShaderParameterKind kind, + RefPtr<ExpressionType> structuredBufferType, + RefPtr<ExpressionType> elementType, + LayoutRulesImpl* rules) +{ + // TODO(tfoley): need to compute the layout for the constant + // buffer's contents... + auto structuredBufferLayoutRules = GetLayoutRulesImpl( + LayoutRule::HLSLStructuredBuffer); + + // Create and save type layout for the buffer contents. + auto elementTypeLayout = CreateTypeLayout( + elementType.Ptr(), + structuredBufferLayoutRules); + + return createStructuredBufferTypeLayout( + kind, + structuredBufferType, + elementTypeLayout, + rules); + +} + +SimpleLayoutInfo GetLayoutImpl( + ExpressionType* type, + LayoutRulesImpl* rules, + RefPtr<TypeLayout>* outTypeLayout) +{ + if (auto parameterBlockType = type->As<ParameterBlockType>()) + { + // If the user is just interested in uniform layout info, + // then this is easy: a `ConstantBuffer<T>` is really no + // different from a `Texture2D<U>` in terms of how it + // should be handled as a member of a container. + // + auto info = getParameterBlockLayoutInfo(parameterBlockType, rules); + + // The more interesting case, though, is when the user + // is requesting us to actually create a `TypeLayout`, + // since in that case we need to: + // + // 1. Compute a layout for the data inside the constant + // buffer, including offsets, etc. + // + // 2. Compute information about any object types inside + // the constant buffer, which need to be surfaces out + // to the top level. + // + if (outTypeLayout) + { + *outTypeLayout = createParameterBlockTypeLayout( + parameterBlockType, + rules); + } + + return info; + } + else if (auto samplerStateType = type->As<SamplerStateType>()) + { + return GetSimpleLayoutImpl( + rules->GetObjectLayout(ShaderParameterKind::SamplerState), + type, + rules, + outTypeLayout); + } + else if (auto textureType = type->As<TextureType>()) + { + // TODO: the logic here should really be defined by the rules, + // and not at this top level... + ShaderParameterKind kind; + switch( textureType->getAccess() ) + { + default: + kind = ShaderParameterKind::MutableTexture; + break; + + case SLANG_RESOURCE_ACCESS_READ: + kind = ShaderParameterKind::Texture; + break; + } + + return GetSimpleLayoutImpl( + rules->GetObjectLayout(kind), + type, + rules, + outTypeLayout); + } + else if (auto textureSamplerType = type->As<TextureSamplerType>()) + { + // TODO: the logic here should really be defined by the rules, + // and not at this top level... + ShaderParameterKind kind; + switch( textureSamplerType->getAccess() ) + { + default: + kind = ShaderParameterKind::MutableTextureSampler; + break; + + case SLANG_RESOURCE_ACCESS_READ: + kind = ShaderParameterKind::TextureSampler; + break; + } + + return GetSimpleLayoutImpl( + rules->GetObjectLayout(kind), + type, + rules, + outTypeLayout); + } + + // TODO: need a better way to handle this stuff... +#define CASE(TYPE, KIND) \ + else if(auto type_##TYPE = type->As<TYPE>()) do { \ + auto info = rules->GetObjectLayout(ShaderParameterKind::KIND); \ + if (outTypeLayout) \ + { \ + *outTypeLayout = createStructuredBufferTypeLayout( \ + ShaderParameterKind::KIND, \ + type_##TYPE, \ + type_##TYPE->elementType.Ptr(), \ + rules); \ + } \ + return info; \ + } while(0) + + CASE(HLSLStructuredBufferType, StructuredBuffer); + CASE(HLSLRWStructuredBufferType, MutableStructuredBuffer); + CASE(HLSLAppendStructuredBufferType, MutableStructuredBuffer); + CASE(HLSLConsumeStructuredBufferType, MutableStructuredBuffer); + +#undef CASE + + + // TODO: need a better way to handle this stuff... +#define CASE(TYPE, KIND) \ + else if(type->As<TYPE>()) do { \ + return GetSimpleLayoutImpl( \ + rules->GetObjectLayout(ShaderParameterKind::KIND), \ + type, rules, outTypeLayout); \ + } while(0) + + CASE(HLSLBufferType, SampledBuffer); + CASE(HLSLRWBufferType, MutableSampledBuffer); + CASE(HLSLByteAddressBufferType, RawBuffer); + CASE(HLSLRWByteAddressBufferType, MutableRawBuffer); + + CASE(GLSLInputAttachmentType, InputRenderTarget); + + // This case is mostly to allow users to add new resource types... + CASE(UntypedBufferResourceType, RawBuffer); + +#undef CASE + + // + // TODO(tfoley): Need to recognize any UAV types here + // + else if(auto basicType = type->As<BasicExpressionType>()) + { + return GetSimpleLayoutImpl( + rules->GetScalarLayout(basicType->BaseType), + type, + rules, + outTypeLayout); + } + else if(auto vecType = type->As<VectorExpressionType>()) + { + return GetSimpleLayoutImpl( + rules->GetVectorLayout( + GetLayout(vecType->elementType.Ptr(), rules), + GetIntVal(vecType->elementCount)), + type, + rules, + outTypeLayout); + } + else if(auto matType = type->As<MatrixExpressionType>()) + { + return GetSimpleLayoutImpl( + rules->GetMatrixLayout( + GetLayout(matType->getElementType(), rules), + GetIntVal(matType->getRowCount()), + GetIntVal(matType->getColumnCount())), + type, + rules, + outTypeLayout); + } + else if (auto arrayType = type->As<ArrayExpressionType>()) + { + RefPtr<TypeLayout> elementTypeLayout; + auto elementInfo = GetLayoutImpl( + arrayType->BaseType.Ptr(), + rules, + outTypeLayout ? &elementTypeLayout : nullptr); + + // For layout purposes, we treat an unsized array as an array of zero elements. + // + // TODO: Longer term we are going to need to be careful to include some indication + // that a type has logically "infinite" size in some resource kind. In particular + // this affects how we would allocate space for parameter binding purposes. + auto elementCount = arrayType->ArrayLength ? GetElementCount(arrayType->ArrayLength) : 0; + auto arrayUniformInfo = rules->GetArrayLayout( + elementInfo, + elementCount).getUniformLayout(); + + if (outTypeLayout) + { + RefPtr<ArrayTypeLayout> typeLayout = new ArrayTypeLayout(); + *outTypeLayout = typeLayout; + + typeLayout->type = type; + typeLayout->elementTypeLayout = elementTypeLayout; + typeLayout->rules = rules; + + typeLayout->uniformAlignment = arrayUniformInfo.alignment; + typeLayout->uniformStride = arrayUniformInfo.elementStride; + + typeLayout->addResourceUsage(LayoutResourceKind::Uniform, arrayUniformInfo.size); + + // translate element-type resources into array-type resources + for( auto elementResourceInfo : elementTypeLayout->resourceInfos ) + { + // The uniform case was already handled above + if( elementResourceInfo.kind == LayoutResourceKind::Uniform ) + continue; + + typeLayout->addResourceUsage( + elementResourceInfo.kind, + elementResourceInfo.count * elementCount); + } + } + return arrayUniformInfo; + } + else if (auto declRefType = type->As<DeclRefType>()) + { + auto declRef = declRefType->declRef; + + if (auto structDeclRef = declRef.As<StructDeclRef>()) + { + RefPtr<StructTypeLayout> typeLayout; + if (outTypeLayout) + { + typeLayout = new StructTypeLayout(); + typeLayout->type = type; + typeLayout->rules = rules; + *outTypeLayout = typeLayout; + } + + UniformLayoutInfo info = rules->BeginStructLayout(); + + for (auto field : structDeclRef.GetFields()) + { + RefPtr<TypeLayout> fieldTypeLayout; + UniformLayoutInfo fieldInfo = GetLayoutImpl( + field.GetType().Ptr(), + rules, + outTypeLayout ? &fieldTypeLayout : nullptr).getUniformLayout(); + + // Note: we don't add any zero-size fields + // when computing structure layout, just + // to avoid having a resource type impact + // the final layout. + // + // This means that the code to generate final + // declarations needs to *also* eliminate zero-size + // fields to be safe... + size_t uniformOffset = info.size; + if(fieldInfo.size != 0) + { + uniformOffset = rules->AddStructField(&info, fieldInfo); + } + + if (outTypeLayout) + { + // If we are computing a complete layout, + // then we need to create variable layouts + // for each field of the structure. + RefPtr<VarLayout> fieldLayout = new VarLayout(); + fieldLayout->varDecl = field; + fieldLayout->typeLayout = fieldTypeLayout; + typeLayout->fields.Add(fieldLayout); + typeLayout->mapVarToLayout.Add(field.GetDecl(), fieldLayout); + + // Set up uniform offset information, if there is any uniform data in the field + if( fieldTypeLayout->FindResourceInfo(LayoutResourceKind::Uniform) ) + { + fieldLayout->AddResourceInfo(LayoutResourceKind::Uniform)->index = uniformOffset; + } + + // Add offset information for any other resource kinds + for( auto fieldTypeResourceInfo : fieldTypeLayout->resourceInfos ) + { + // Uniforms were dealt with above + if(fieldTypeResourceInfo.kind == LayoutResourceKind::Uniform) + continue; + + // We should not have already processed this resource type + assert(!fieldLayout->FindResourceInfo(fieldTypeResourceInfo.kind)); + + // The field will need offset information for this kind + auto fieldResourceInfo = fieldLayout->AddResourceInfo(fieldTypeResourceInfo.kind); + + // Check how many slots of the given kind have already been added to the type + auto structTypeResourceInfo = typeLayout->findOrAddResourceInfo(fieldTypeResourceInfo.kind); + fieldResourceInfo->index = structTypeResourceInfo->count; + structTypeResourceInfo->count += fieldTypeResourceInfo.count; + } + } + } + + rules->EndStructLayout(&info); + if (outTypeLayout) + { + typeLayout->uniformAlignment = info.alignment; + typeLayout->addResourceUsage(LayoutResourceKind::Uniform, info.size); + } + + return info; + } + } + + // catch-all case in case nothing matched + assert(!"unimplemented"); + SimpleLayoutInfo info; + return GetSimpleLayoutImpl( + info, + type, + rules, + outTypeLayout); +} + +SimpleLayoutInfo GetLayout(ExpressionType* inType, LayoutRulesImpl* rules) +{ + return GetLayoutImpl(inType, rules, nullptr); +} + +RefPtr<TypeLayout> CreateTypeLayout(ExpressionType* type, LayoutRulesImpl* rules) +{ + RefPtr<TypeLayout> typeLayout; + GetLayoutImpl(type, rules, &typeLayout); + return typeLayout; +} + +SimpleLayoutInfo GetLayout(ExpressionType* type, LayoutRule rule) +{ + LayoutRulesImpl* rulesImpl = GetLayoutRulesImpl(rule); + return GetLayout(type, rulesImpl); +} + +}} diff --git a/source/slang/type-layout.h b/source/slang/type-layout.h new file mode 100644 index 000000000..be54bbf53 --- /dev/null +++ b/source/slang/type-layout.h @@ -0,0 +1,550 @@ +#ifndef SLANG_TYPE_LAYOUT_H +#define SLANG_TYPE_LAYOUT_H + +#include "../core/basic.h" +#include "profile.h" +#include "syntax.h" + +#include "../../slang.h" + +namespace Slang { + +typedef intptr_t Int; +typedef uintptr_t UInt; + +namespace Compiler { + +// Forward declarations + +enum class BaseType; +class ExpressionType; + +// + +enum class LayoutRule +{ + Std140, + Std430, + HLSLConstantBuffer, + HLSLStructuredBuffer, +}; + +enum class LayoutRulesFamily +{ + HLSL, + GLSL, +}; + +// Layout appropriate to "just memory" scenarios, +// such as laying out the members of a constant buffer. +struct UniformLayoutInfo +{ + size_t size; + size_t alignment; + + UniformLayoutInfo() + : size(0) + , alignment(1) + {} + + UniformLayoutInfo( + size_t size, + size_t alignment) + : size(size) + , alignment(alignment) + {} +}; + +// Extended information required for an array of uniform data, +// including the "stride" of the array (the space between +// consecutive elements). +struct UniformArrayLayoutInfo : UniformLayoutInfo +{ + size_t elementStride; + + UniformArrayLayoutInfo() + : elementStride(0) + {} + + UniformArrayLayoutInfo( + size_t size, + size_t alignment, + size_t elementStride) + : UniformLayoutInfo(size, alignment) + , elementStride(elementStride) + {} +}; + +typedef slang::ParameterCategory LayoutResourceKind; + +// Layout information for a value that only consumes +// a single reosurce kind. +struct SimpleLayoutInfo +{ + // What kind of resource should we consume? + LayoutResourceKind kind; + + // How many resources of that kind? + size_t size; + + // only useful in the uniform case + size_t alignment; + + SimpleLayoutInfo() + : kind(LayoutResourceKind::None) + , size(0) + , alignment(1) + {} + + SimpleLayoutInfo( + UniformLayoutInfo uniformInfo) + : kind(LayoutResourceKind::Uniform) + , size(uniformInfo.size) + , alignment(uniformInfo.alignment) + {} + + SimpleLayoutInfo(LayoutResourceKind kind, size_t size, size_t alignment=1) + : kind(kind) + , size(size) + , alignment(alignment) + {} + + // Convert to layout for uniform data + UniformLayoutInfo getUniformLayout() + { + if(kind == LayoutResourceKind::Uniform) + { + return UniformLayoutInfo(size, alignment); + } + else + { + return UniformLayoutInfo(0, 1); + } + } +}; + +// Only useful in the case of a homogeneous array +struct SimpleArrayLayoutInfo : SimpleLayoutInfo +{ + // This field is only useful in the uniform case + size_t elementStride; + + // Convert to layout for uniform data + UniformArrayLayoutInfo getUniformLayout() + { + if(kind == LayoutResourceKind::Uniform) + { + return UniformArrayLayoutInfo(size, alignment, elementStride); + } + else + { + return UniformArrayLayoutInfo(0, 1, 0); + } + } +}; + +struct LayoutRulesImpl; + +// A reified reprsentation of a particular laid-out type +class TypeLayout : public RefObject +{ +public: + // The type that was laid out + RefPtr<ExpressionType> type; + ExpressionType* getType() { return type.Ptr(); } + + // The layout rules that were used to produce this type + LayoutRulesImpl* rules; + + struct ResourceInfo + { + // What kind of register was it? + LayoutResourceKind kind = LayoutResourceKind::None; + + // How many registers of the above kind did we use? + UInt count; + }; + + List<ResourceInfo> resourceInfos; + + // For uniform data, alignment matters, but not for + // any other resource category, so we don't waste + // the space storing it in the above array + UInt uniformAlignment = 1; + + ResourceInfo* FindResourceInfo(LayoutResourceKind kind) + { + for(auto& rr : resourceInfos) + { + if(rr.kind == kind) + return &rr; + } + return nullptr; + } + + ResourceInfo* findOrAddResourceInfo(LayoutResourceKind kind) + { + auto existing = FindResourceInfo(kind); + if(existing) return existing; + + ResourceInfo info; + info.kind = kind; + info.count = 0; + resourceInfos.Add(info); + return &resourceInfos.Last(); + } + + void addResourceUsage(ResourceInfo info) + { + if(info.count == 0) return; + + findOrAddResourceInfo(info.kind)->count += info.count; + } + + void addResourceUsage(LayoutResourceKind kind, UInt count) + { + ResourceInfo info; + info.kind = kind; + info.count = count; + addResourceUsage(info); + } +}; + +typedef unsigned int VarLayoutFlags; +enum VarLayoutFlag : VarLayoutFlags +{ + IsRedeclaration = 1 << 0, ///< This is a redeclaration of some shader parameter +}; + +// A reified layout for a particular variable, field, etc. +class VarLayout : public RefObject +{ +public: + // The variable we are laying out + VarDeclBaseRef varDecl; + VarDeclBase* getVariable() { return varDecl.GetDecl(); } + + String const& getName() { return getVariable()->getName(); } + + // The result of laying out the variable's type + RefPtr<TypeLayout> typeLayout; + TypeLayout* getTypeLayout() { return typeLayout.Ptr(); } + + // Additional flags + VarLayoutFlags flags = 0; + + // The start register(s) for any resources + struct ResourceInfo + { + // What kind of register was it? + LayoutResourceKind kind = LayoutResourceKind::None; + + // What binding space (HLSL) or set (Vulkan) are we placed in? + UInt space; + + // What is our starting register in that space? + // + // (In the case of uniform data, this is a byte offset) + UInt index; + }; + List<ResourceInfo> resourceInfos; + + ResourceInfo* FindResourceInfo(LayoutResourceKind kind) + { + for(auto& rr : resourceInfos) + { + if(rr.kind == kind) + return &rr; + } + return nullptr; + } + + ResourceInfo* AddResourceInfo(LayoutResourceKind kind) + { + ResourceInfo info; + info.kind = kind; + info.space = 0; + info.index = 0; + + resourceInfos.Add(info); + return &resourceInfos.Last(); + } + + ResourceInfo* findOrAddResourceInfo(LayoutResourceKind kind) + { + auto existing = FindResourceInfo(kind); + if(existing) return existing; + + return AddResourceInfo(kind); + } +}; + +// Type layout for a variable that has a constant-buffer type +class ParameterBlockTypeLayout : public TypeLayout +{ +public: + RefPtr<TypeLayout> elementTypeLayout; +}; + +// Type layout for a variable that has a constant-buffer type +class StructuredBufferTypeLayout : public TypeLayout +{ +public: + RefPtr<TypeLayout> elementTypeLayout; +}; + +// Specific case of type layout for an array +class ArrayTypeLayout : public TypeLayout +{ +public: + // The layout used for the element type + RefPtr<TypeLayout> elementTypeLayout; + + // the stride between elements when used in + // a uniform buffer + size_t uniformStride; +}; + +// Specific case of type layout for a struct +class StructTypeLayout : public TypeLayout +{ +public: + // An ordered list of layouts for the known fields + List<RefPtr<VarLayout>> fields; + + // Map a variable to its layout directly. + // + // Note that in the general case, there may be entries + // in the `fields` array that came from multiple + // translation units, and in cases where there are + // multiple declarations of the same parameter, only + // one will appear in `fields`, while all of + // them will be reflected in `mapVarToLayout`. + // + Dictionary<Decl*, RefPtr<VarLayout>> mapVarToLayout; +}; + +// Layout information for a single shader entry point +// within a program +class EntryPointLayout : public RefObject +{ +public: + // The corresponding function declaration + RefPtr<FunctionSyntaxNode> entryPoint; + + // The shader profile that was used to compile the entry point + Profile profile; +}; + +// Layout information for the global scope of a program +class ProgramLayout : public RefObject +{ +public: + // We store a layout for the declarations at the global + // scope. Note that this will *either* be a single + // `StructTypeLayout` with the fields stored directly, + // or it will be a single `ParameterBlockTypeLayout`, + // where the global-scope fields are the members of + // that constant buffer. + // + // The `struct` case will be used if there are no + // "naked" global-scope uniform variables, and the + // constant-buffer case will be used if there are + // (since a constant buffer will have to be allocated + // to store them). + // + RefPtr<TypeLayout> globalScopeLayout; + + // We catalog the requested entry points here, + // and any entry-point-specific parameter data + // will (eventually) belong there... + List<RefPtr<EntryPointLayout>> entryPoints; +}; + +// A modifier to be attached to syntax after we've computed layout +class ComputedLayoutModifier : public Modifier +{ +public: + RefPtr<TypeLayout> typeLayout; +}; + + +struct LayoutRulesFamilyImpl; + +// A delineation of shader parameter types into fine-grained +// categories that can then be mapped down to actual resources +// by a given set of rules. +// +// TODO(tfoley): `SlangParameterCategory` and `slang::ParameterCategory` +// are badly named, and need to be revised so they can't be confused +// with this concept. +enum class ShaderParameterKind +{ + ConstantBuffer, + TextureUniformBuffer, + ShaderStorageBuffer, + + StructuredBuffer, + MutableStructuredBuffer, + + SampledBuffer, + MutableSampledBuffer, + + RawBuffer, + MutableRawBuffer, + + Buffer, + MutableBuffer, + + Texture, + MutableTexture, + + TextureSampler, + MutableTextureSampler, + + InputRenderTarget, + + SamplerState, +}; + +struct SimpleLayoutRulesImpl +{ + // Get size and alignment for a single value of base type. + virtual SimpleLayoutInfo GetScalarLayout(BaseType baseType) = 0; + virtual SimpleLayoutInfo GetScalarLayout(slang::TypeReflection::ScalarType scalarType) = 0; + + // Get size and alignment for an array of elements + virtual SimpleArrayLayoutInfo GetArrayLayout(SimpleLayoutInfo elementInfo, size_t elementCount) = 0; + + // Get layout for a vector or matrix type + virtual SimpleLayoutInfo GetVectorLayout(SimpleLayoutInfo elementInfo, size_t elementCount) = 0; + virtual SimpleLayoutInfo GetMatrixLayout(SimpleLayoutInfo elementInfo, size_t rowCount, size_t columnCount) = 0; + + // Begin doing layout on a `struct` type + virtual UniformLayoutInfo BeginStructLayout() = 0; + + // Add a field to a `struct` type, and return the offset for the field + virtual size_t AddStructField(UniformLayoutInfo* ioStructInfo, UniformLayoutInfo fieldInfo) = 0; + + // End layout for a struct, and finalize its size/alignment. + virtual void EndStructLayout(UniformLayoutInfo* ioStructInfo) = 0; +}; + +struct ObjectLayoutRulesImpl +{ + // Compute layout info for an object type + virtual SimpleLayoutInfo GetObjectLayout(ShaderParameterKind kind) = 0; +}; + +struct LayoutRulesImpl +{ + LayoutRulesFamilyImpl* family; + SimpleLayoutRulesImpl* simpleRules; + ObjectLayoutRulesImpl* objectRules; + + // Forward `SimpleLayoutRulesImpl` interface + + SimpleLayoutInfo GetScalarLayout(BaseType baseType) + { + return simpleRules->GetScalarLayout(baseType); + } + + SimpleLayoutInfo GetScalarLayout(slang::TypeReflection::ScalarType scalarType) + { + return simpleRules->GetScalarLayout(scalarType); + } + + SimpleArrayLayoutInfo GetArrayLayout(SimpleLayoutInfo elementInfo, size_t elementCount) + { + return simpleRules->GetArrayLayout(elementInfo, elementCount); + } + + SimpleLayoutInfo GetVectorLayout(SimpleLayoutInfo elementInfo, size_t elementCount) + { + return simpleRules->GetVectorLayout(elementInfo, elementCount); + } + + SimpleLayoutInfo GetMatrixLayout(SimpleLayoutInfo elementInfo, size_t rowCount, size_t columnCount) + { + return simpleRules->GetMatrixLayout(elementInfo, rowCount, columnCount); + } + + UniformLayoutInfo BeginStructLayout() + { + return simpleRules->BeginStructLayout(); + } + + size_t AddStructField(UniformLayoutInfo* ioStructInfo, UniformLayoutInfo fieldInfo) + { + return simpleRules->AddStructField(ioStructInfo, fieldInfo); + } + + void EndStructLayout(UniformLayoutInfo* ioStructInfo) + { + return simpleRules->EndStructLayout(ioStructInfo); + } + + // Forward `ObjectLayoutRulesImpl` interface + + SimpleLayoutInfo GetObjectLayout(ShaderParameterKind kind) + { + return objectRules->GetObjectLayout(kind); + } + + // + + LayoutRulesFamilyImpl* getLayoutRulesFamily() { return family; } +}; + +struct LayoutRulesFamilyImpl +{ + virtual LayoutRulesImpl* getConstantBufferRules() = 0; + virtual LayoutRulesImpl* getTextureBufferRules() = 0; + virtual LayoutRulesImpl* getVaryingInputRules() = 0; + virtual LayoutRulesImpl* getVaryingOutputRules() = 0; + virtual LayoutRulesImpl* getSpecializationConstantRules() = 0; + virtual LayoutRulesImpl* getShaderStorageBufferRules() = 0; +}; + +LayoutRulesImpl* GetLayoutRulesImpl(LayoutRule rule); +LayoutRulesFamilyImpl* GetLayoutRulesFamilyImpl(LayoutRulesFamily rule); +LayoutRulesFamilyImpl* GetLayoutRulesFamilyImpl(SourceLanguage language); + +SimpleLayoutInfo GetLayout(ExpressionType* type, LayoutRulesImpl* rules); + +SimpleLayoutInfo GetLayout(ExpressionType* type, LayoutRule rule = LayoutRule::Std430); + +RefPtr<TypeLayout> CreateTypeLayout(ExpressionType* type, LayoutRulesImpl* rules); + +// + +// Create a type layout for a parameter block type. +RefPtr<ParameterBlockTypeLayout> +createParameterBlockTypeLayout( + RefPtr<ParameterBlockType> parameterBlockType, + LayoutRulesImpl* rules); + +// Create a type layout for a constant buffer type, +// in the case where we already know the layout +// for the element type. +RefPtr<ParameterBlockTypeLayout> +createParameterBlockTypeLayout( + RefPtr<ParameterBlockType> parameterBlockType, + RefPtr<TypeLayout> elementTypeLayout, + LayoutRulesImpl* rules); + + +// Create a type layout for a structured buffer type. +RefPtr<StructuredBufferTypeLayout> +createStructuredBufferTypeLayout( + ShaderParameterKind kind, + RefPtr<ExpressionType> structuredBufferType, + RefPtr<ExpressionType> elementType, + LayoutRulesImpl* rules); + + +// + +}} + +#endif
\ No newline at end of file |
