summaryrefslogtreecommitdiffstats
path: root/source/slang
diff options
context:
space:
mode:
authorTim Foley <tfoley@nvidia.com>2017-06-09 11:34:21 -0700
committerTim Foley <tfoley@nvidia.com>2017-06-09 13:44:59 -0700
commitfcf83dbf9effab3bd98bad2b83b2468b7eb05cfd (patch)
tree41047c94883b86ec085a81597391ce3ef557cd43 /source/slang
parent52e8d4b9a27ab0060f874c3a63ab531847be35c0 (diff)
Initial import of code.
Diffstat (limited to 'source/slang')
-rw-r--r--source/slang/check.cpp4973
-rw-r--r--source/slang/compiled-program.h96
-rw-r--r--source/slang/compiler.cpp659
-rw-r--r--source/slang/compiler.h156
-rw-r--r--source/slang/diagnostic-defs.h338
-rw-r--r--source/slang/diagnostics.cpp204
-rw-r--r--source/slang/diagnostics.h218
-rw-r--r--source/slang/emit.cpp2537
-rw-r--r--source/slang/emit.h24
-rw-r--r--source/slang/intrinsic-defs.h94
-rw-r--r--source/slang/lexer.cpp1012
-rw-r--r--source/slang/lexer.h101
-rw-r--r--source/slang/lookup.cpp311
-rw-r--r--source/slang/lookup.h41
-rw-r--r--source/slang/parameter-binding.cpp1252
-rw-r--r--source/slang/parameter-binding.h32
-rw-r--r--source/slang/parser.cpp3106
-rw-r--r--source/slang/parser.h23
-rw-r--r--source/slang/preprocessor.cpp2032
-rw-r--r--source/slang/preprocessor.h35
-rw-r--r--source/slang/profile-defs.h123
-rw-r--r--source/slang/profile.cpp20
-rw-r--r--source/slang/profile.h84
-rw-r--r--source/slang/reflection.cpp1404
-rw-r--r--source/slang/reflection.h39
-rw-r--r--source/slang/slang-stdlib.cpp1855
-rw-r--r--source/slang/slang-stdlib.h23
-rw-r--r--source/slang/slang.cpp699
-rw-r--r--source/slang/slang.natvis14
-rw-r--r--source/slang/slang.vcxproj427
-rw-r--r--source/slang/slang.vcxproj.filters48
-rw-r--r--source/slang/source-loc.h47
-rw-r--r--source/slang/syntax-visitors.h21
-rw-r--r--source/slang/syntax.cpp1484
-rw-r--r--source/slang/syntax.h2771
-rw-r--r--source/slang/token-defs.h93
-rw-r--r--source/slang/token.cpp22
-rw-r--r--source/slang/token.h50
-rw-r--r--source/slang/type-layout.cpp1125
-rw-r--r--source/slang/type-layout.h550
40 files changed, 28143 insertions, 0 deletions
diff --git a/source/slang/check.cpp b/source/slang/check.cpp
new file mode 100644
index 000000000..7d50c5978
--- /dev/null
+++ b/source/slang/check.cpp
@@ -0,0 +1,4973 @@
+#include "syntax-visitors.h"
+
+#include "lookup.h"
+#include "compiler.h"
+
+#include <assert.h>
+
+namespace Slang
+{
+ namespace Compiler
+ {
+ bool IsNumeric(BaseType t)
+ {
+ return t == BaseType::Int || t == BaseType::Float || t == BaseType::UInt;
+ }
+
+ String TranslateHLSLTypeNames(String name)
+ {
+ if (name == "float2" || name == "half2")
+ return "vec2";
+ else if (name == "float3" || name == "half3")
+ return "vec3";
+ else if (name == "float4" || name == "half4")
+ return "vec4";
+ else if (name == "half")
+ return "float";
+ else if (name == "int2")
+ return "ivec2";
+ else if (name == "int3")
+ return "ivec3";
+ else if (name == "int4")
+ return "ivec4";
+ else if (name == "uint2")
+ return "uvec2";
+ else if (name == "uint3")
+ return "uvec3";
+ else if (name == "uint4")
+ return "uvec4";
+ else if (name == "float3x3" || name == "half3x3")
+ return "mat3";
+ else if (name == "float4x4" || name == "half4x4")
+ return "mat4";
+ else
+ return name;
+ }
+
+ class SemanticsVisitor : public SyntaxVisitor
+ {
+ ProgramSyntaxNode * program = nullptr;
+ FunctionSyntaxNode * function = nullptr;
+ CompileOptions const* options = nullptr;
+
+ // lexical outer statements
+ List<StatementSyntaxNode*> outerStmts;
+ public:
+ SemanticsVisitor(
+ DiagnosticSink * pErr,
+ CompileOptions const& options)
+ : SyntaxVisitor(pErr)
+ , options(&options)
+ {
+ }
+
+ CompileOptions const& getOptions() { return *options; }
+
+ public:
+ // Translate Types
+ RefPtr<ExpressionType> typeResult;
+ RefPtr<ExpressionSyntaxNode> TranslateTypeNodeImpl(const RefPtr<ExpressionSyntaxNode> & node)
+ {
+ if (!node) return nullptr;
+ auto expr = node->Accept(this).As<ExpressionSyntaxNode>();
+ expr = ExpectATypeRepr(expr);
+ return expr;
+ }
+ RefPtr<ExpressionType> ExtractTypeFromTypeRepr(const RefPtr<ExpressionSyntaxNode>& typeRepr)
+ {
+ if (!typeRepr) return nullptr;
+ if (auto typeType = typeRepr->Type->As<TypeType>())
+ {
+ return typeType->type;
+ }
+ return ExpressionType::Error;
+ }
+ RefPtr<ExpressionType> TranslateTypeNode(const RefPtr<ExpressionSyntaxNode> & node)
+ {
+ if (!node) return nullptr;
+ auto typeRepr = TranslateTypeNodeImpl(node);
+ return ExtractTypeFromTypeRepr(typeRepr);
+ }
+ TypeExp TranslateTypeNode(TypeExp const& typeExp)
+ {
+ // HACK(tfoley): It seems that in some cases we end up re-checking
+ // syntax that we've already checked. We need to root-cause that
+ // issue, but for now a quick fix in this case is to early
+ // exist if we've already got a type associated here:
+ if (typeExp.type)
+ {
+ return typeExp;
+ }
+
+
+ auto typeRepr = TranslateTypeNodeImpl(typeExp.exp);
+
+ TypeExp result;
+ result.exp = typeRepr;
+ result.type = ExtractTypeFromTypeRepr(typeRepr);
+ return result;
+ }
+
+ RefPtr<ExpressionSyntaxNode> ConstructDeclRefExpr(
+ DeclRef declRef,
+ RefPtr<ExpressionSyntaxNode> baseExpr,
+ RefPtr<ExpressionSyntaxNode> originalExpr)
+ {
+ if (baseExpr)
+ {
+ auto expr = new MemberExpressionSyntaxNode();
+ expr->Position = originalExpr->Position;
+ expr->BaseExpression = baseExpr;
+ expr->MemberName = declRef.GetName();
+ expr->Type = GetTypeForDeclRef(declRef);
+ expr->declRef = declRef;
+ return expr;
+ }
+ else
+ {
+ auto expr = new VarExpressionSyntaxNode();
+ expr->Position = originalExpr->Position;
+ expr->Variable = declRef.GetName();
+ expr->Type = GetTypeForDeclRef(declRef);
+ expr->declRef = declRef;
+ return expr;
+ }
+ }
+
+ RefPtr<ExpressionSyntaxNode> ConstructDerefExpr(
+ RefPtr<ExpressionSyntaxNode> base,
+ RefPtr<ExpressionSyntaxNode> originalExpr)
+ {
+ auto ptrLikeType = base->Type->As<PointerLikeType>();
+ assert(ptrLikeType);
+
+ auto derefExpr = new DerefExpr();
+ derefExpr->Position = originalExpr->Position;
+ derefExpr->base = base;
+ derefExpr->Type = ptrLikeType->elementType;
+
+ // TODO(tfoley): handle l-value status here
+
+ return derefExpr;
+ }
+
+ RefPtr<ExpressionSyntaxNode> ConstructLookupResultExpr(
+ LookupResultItem const& item,
+ RefPtr<ExpressionSyntaxNode> baseExpr,
+ RefPtr<ExpressionSyntaxNode> originalExpr)
+ {
+ // If we collected any breadcrumbs, then these represent
+ // additional segments of the lookup path that we need
+ // to expand here.
+ auto bb = baseExpr;
+ for (auto breadcrumb = item.breadcrumbs; breadcrumb; breadcrumb = breadcrumb->next)
+ {
+ switch (breadcrumb->kind)
+ {
+ case LookupResultItem::Breadcrumb::Kind::Member:
+ bb = ConstructDeclRefExpr(breadcrumb->declRef, bb, originalExpr);
+ break;
+ case LookupResultItem::Breadcrumb::Kind::Deref:
+ bb = ConstructDerefExpr(bb, originalExpr);
+ break;
+ default:
+ SLANG_UNREACHABLE("all cases handle");
+ }
+ }
+
+ return ConstructDeclRefExpr(item.declRef, bb, originalExpr);
+ }
+
+ RefPtr<ExpressionSyntaxNode> createLookupResultExpr(
+ LookupResult const& lookupResult,
+ RefPtr<ExpressionSyntaxNode> baseExpr,
+ RefPtr<ExpressionSyntaxNode> originalExpr)
+ {
+ if (lookupResult.isOverloaded())
+ {
+ auto overloadedExpr = new OverloadedExpr();
+ overloadedExpr->Position = originalExpr->Position;
+ overloadedExpr->Type = ExpressionType::Overloaded;
+ overloadedExpr->base = baseExpr;
+ overloadedExpr->lookupResult2 = lookupResult;
+ return overloadedExpr;
+ }
+ else
+ {
+ return ConstructLookupResultExpr(lookupResult.item, baseExpr, originalExpr);
+ }
+ }
+
+ RefPtr<ExpressionSyntaxNode> ResolveOverloadedExpr(RefPtr<OverloadedExpr> overloadedExpr, LookupMask mask)
+ {
+ auto lookupResult = overloadedExpr->lookupResult2;
+ assert(lookupResult.isValid() && lookupResult.isOverloaded());
+
+ // Take the lookup result we had, and refine it based on what is expected in context.
+ lookupResult = refineLookup(lookupResult, mask);
+
+ if (!lookupResult.isValid())
+ {
+ // If we didn't find any symbols after filtering, then just
+ // use the original and report errors that way
+ return overloadedExpr;
+ }
+
+ if (lookupResult.isOverloaded())
+ {
+ // We had an ambiguity anyway, so report it.
+ getSink()->diagnose(overloadedExpr, Diagnostics::ambiguousReference, lookupResult.items[0].declRef.GetName());
+
+ for(auto item : lookupResult.items)
+ {
+ String declString = getDeclSignatureString(item);
+ getSink()->diagnose(item.declRef, Diagnostics::overloadCandidate, declString);
+ }
+
+ // TODO(tfoley): should we construct a new ErrorExpr here?
+ overloadedExpr->Type = ExpressionType::Error;
+ return overloadedExpr;
+ }
+
+ // otherwise, we had a single decl and it was valid, hooray!
+ return ConstructLookupResultExpr(lookupResult.item, overloadedExpr->base, overloadedExpr);
+ }
+
+ RefPtr<ExpressionSyntaxNode> ExpectATypeRepr(RefPtr<ExpressionSyntaxNode> expr)
+ {
+ if (auto overloadedExpr = expr.As<OverloadedExpr>())
+ {
+ expr = ResolveOverloadedExpr(overloadedExpr, LookupMask::Type);
+ }
+
+ if (auto typeType = expr->Type.type->As<TypeType>())
+ {
+ return expr;
+ }
+ else if (expr->Type.type->Equals(ExpressionType::Error))
+ {
+ return expr;
+ }
+
+ getSink()->diagnose(expr, Diagnostics::unimplemented, "expected a type");
+ // TODO: construct some kind of `ErrorExpr`?
+ return expr;
+ }
+
+ RefPtr<ExpressionType> ExpectAType(RefPtr<ExpressionSyntaxNode> expr)
+ {
+ auto typeRepr = ExpectATypeRepr(expr);
+ if (auto typeType = typeRepr->Type->As<TypeType>())
+ {
+ return typeType->type;
+ }
+ return ExpressionType::Error;
+ }
+
+ RefPtr<ExpressionType> ExtractGenericArgType(RefPtr<ExpressionSyntaxNode> exp)
+ {
+ return ExpectAType(exp);
+ }
+
+ RefPtr<IntVal> ExtractGenericArgInteger(RefPtr<ExpressionSyntaxNode> exp)
+ {
+ return CheckIntegerConstantExpression(exp.Ptr());
+ }
+
+ RefPtr<Val> ExtractGenericArgVal(RefPtr<ExpressionSyntaxNode> exp)
+ {
+ if (auto overloadedExpr = exp.As<OverloadedExpr>())
+ {
+ // assume that if it is overloaded, we want a type
+ exp = ResolveOverloadedExpr(overloadedExpr, LookupMask::Type);
+ }
+
+ if (auto typeType = exp->Type->As<TypeType>())
+ {
+ return typeType->type;
+ }
+ else if (exp->Type->Equals(ExpressionType::Error))
+ {
+ return exp->Type.type;
+ }
+ else
+ {
+ return ExtractGenericArgInteger(exp);
+ }
+ }
+
+ // Construct a type reprsenting the instantiation of
+ // the given generic declaration for the given arguments.
+ // The arguments should already be checked against
+ // the declaration.
+ RefPtr<ExpressionType> InstantiateGenericType(
+ GenericDeclRef genericDeclRef,
+ List<RefPtr<ExpressionSyntaxNode>> const& args)
+ {
+ RefPtr<Substitutions> subst = new Substitutions();
+ subst->genericDecl = genericDeclRef.GetDecl();
+ subst->outer = genericDeclRef.substitutions;
+
+ for (auto argExpr : args)
+ {
+ subst->args.Add(ExtractGenericArgVal(argExpr));
+ }
+
+ DeclRef innerDeclRef;
+ innerDeclRef.decl = genericDeclRef.GetInner();
+ innerDeclRef.substitutions = subst;
+
+ return DeclRefType::Create(innerDeclRef);
+ }
+
+ // Make sure a declaration has been checked, so we can refer to it.
+ // Note that this may lead to us recursively invoking checking,
+ // so this may not be the best way to handle things.
+ void EnsureDecl(RefPtr<Decl> decl, DeclCheckState state = DeclCheckState::CheckedHeader)
+ {
+ if (decl->IsChecked(state)) return;
+ if (decl->checkState == DeclCheckState::CheckingHeader)
+ {
+ // We tried to reference the same declaration while checking it!
+ throw "circularity";
+ }
+
+ if (DeclCheckState::CheckingHeader > decl->checkState)
+ {
+ decl->SetCheckState(DeclCheckState::CheckingHeader);
+ }
+
+ // TODO: not all of the `Visit` cases are ready to
+ // handle this being called on-the-fly
+ decl->Accept(this);
+
+ decl->SetCheckState(DeclCheckState::Checked);
+ }
+
+ void EnusreAllDeclsRec(RefPtr<Decl> decl)
+ {
+ EnsureDecl(decl, DeclCheckState::Checked);
+ if (auto containerDecl = decl.As<ContainerDecl>())
+ {
+ for (auto m : containerDecl->Members)
+ {
+ EnusreAllDeclsRec(m);
+ }
+ }
+ }
+
+ // A "proper" type is one that can be used as the type of an expression.
+ // Put simply, it can be a concrete type like `int`, or a generic
+ // type that is applied to arguments, like `Texture2D<float4>`.
+ // The type `void` is also a proper type, since we can have expressions
+ // that return a `void` result (e.g., many function calls).
+ //
+ // A "non-proper" type is any type that can't actually have values.
+ // A simple example of this in C++ is `std::vector` - you can't have
+ // a value of this type.
+ //
+ // Part of what this function does is give errors if somebody tries
+ // to use a non-proper type as the type of a variable (or anything
+ // else that needs a proper type).
+ //
+ // The other thing it handles is the fact that HLSL lets you use
+ // the name of a non-proper type, and then have the compiler fill
+ // in the default values for its type arguments (e.g., a variable
+ // given type `Texture2D` will actually have type `Texture2D<float4>`).
+ bool CoerceToProperTypeImpl(TypeExp const& typeExp, RefPtr<ExpressionType>* outProperType)
+ {
+ ExpressionType* type = typeExp.type.Ptr();
+ if (auto genericDeclRefType = type->As<GenericDeclRefType>())
+ {
+ // We are using a reference to a generic declaration as a concrete
+ // type. This means we should substitute in any default parameter values
+ // if they are available.
+ //
+ // TODO(tfoley): A more expressive type system would substitute in
+ // "fresh" variables and then solve for their values...
+ //
+
+ auto genericDeclRef = genericDeclRefType->GetDeclRef();
+ EnsureDecl(genericDeclRef.decl);
+ List<RefPtr<ExpressionSyntaxNode>> args;
+ for (RefPtr<Decl> member : genericDeclRef.GetDecl()->Members)
+ {
+ if (auto typeParam = member.As<GenericTypeParamDecl>())
+ {
+ if (!typeParam->initType.exp)
+ {
+ if (outProperType)
+ {
+ getSink()->diagnose(typeExp.exp.Ptr(), Diagnostics::unimplemented, "can't fill in default for generic type parameter");
+ *outProperType = ExpressionType::Error;
+ }
+ return false;
+ }
+
+ // TODO: this is one place where syntax should get cloned!
+ if(outProperType)
+ args.Add(typeParam->initType.exp);
+ }
+ else if (auto valParam = member.As<GenericValueParamDecl>())
+ {
+ if (!valParam->Expr)
+ {
+ if (outProperType)
+ {
+ getSink()->diagnose(typeExp.exp.Ptr(), Diagnostics::unimplemented, "can't fill in default for generic type parameter");
+ *outProperType = ExpressionType::Error;
+ }
+ return false;
+ }
+
+ // TODO: this is one place where syntax should get cloned!
+ if(outProperType)
+ args.Add(valParam->Expr);
+ }
+ else
+ {
+ // ignore non-parameter members
+ }
+ }
+
+ if (outProperType)
+ {
+ *outProperType = InstantiateGenericType(genericDeclRef, args);
+ }
+ return true;
+ }
+ else
+ {
+ // default case: we expect this to already be a proper type
+ if (outProperType)
+ {
+ *outProperType = type;
+ }
+ return true;
+ }
+ }
+
+
+
+ TypeExp CoerceToProperType(TypeExp const& typeExp)
+ {
+ TypeExp result = typeExp;
+ CoerceToProperTypeImpl(typeExp, &result.type);
+ return result;
+ }
+
+ bool CanCoerceToProperType(TypeExp const& typeExp)
+ {
+ return CoerceToProperTypeImpl(typeExp, nullptr);
+ }
+
+ // Check a type, and coerce it to be proper
+ TypeExp CheckProperType(TypeExp typeExp)
+ {
+ return CoerceToProperType(TranslateTypeNode(typeExp));
+ }
+
+ // For our purposes, a "usable" type is one that can be
+ // used to declare a function parameter, variable, etc.
+ // These turn out to be all the proper types except
+ // `void`.
+ //
+ // TODO(tfoley): consider just allowing `void` as a
+ // simple example of a "unit" type, and get rid of
+ // this check.
+ TypeExp CoerceToUsableType(TypeExp const& typeExp)
+ {
+ TypeExp result = CoerceToProperType(typeExp);
+ ExpressionType* type = result.type.Ptr();
+ if (auto basicType = type->As<BasicExpressionType>())
+ {
+ // TODO: `void` shouldn't be a basic type, to make this easier to avoid
+ if (basicType->BaseType == BaseType::Void)
+ {
+ // TODO(tfoley): pick the right diagnostic message
+ getSink()->diagnose(result.exp.Ptr(), Diagnostics::invalidTypeVoid);
+ result.type = ExpressionType::Error;
+ return result;
+ }
+ }
+ return result;
+ }
+
+ // Check a type, and coerce it to be usable
+ TypeExp CheckUsableType(TypeExp typeExp)
+ {
+ return CoerceToUsableType(TranslateTypeNode(typeExp));
+ }
+
+ RefPtr<ExpressionSyntaxNode> CheckTerm(RefPtr<ExpressionSyntaxNode> term)
+ {
+ if (!term) return nullptr;
+ return term->Accept(this).As<ExpressionSyntaxNode>();
+ }
+
+ RefPtr<ExpressionSyntaxNode> CreateErrorExpr(ExpressionSyntaxNode* expr)
+ {
+ expr->Type = ExpressionType::Error;
+ return expr;
+ }
+
+ bool IsErrorExpr(RefPtr<ExpressionSyntaxNode> expr)
+ {
+ // TODO: we may want other cases here...
+
+ if (expr->Type->Equals(ExpressionType::Error))
+ return true;
+
+ return false;
+ }
+
+ // Capture the "base" expression in case this is a member reference
+ RefPtr<ExpressionSyntaxNode> GetBaseExpr(RefPtr<ExpressionSyntaxNode> expr)
+ {
+ if (auto memberExpr = expr.As<MemberExpressionSyntaxNode>())
+ {
+ return memberExpr->BaseExpression;
+ }
+ else if(auto overloadedExpr = expr.As<OverloadedExpr>())
+ {
+ return overloadedExpr->base;
+ }
+ return nullptr;
+ }
+
+ public:
+
+ typedef unsigned int ConversionCost;
+ enum : ConversionCost
+ {
+ // No conversion at all
+ kConversionCost_None = 0,
+
+ // Conversion that is lossless and keeps the "kind" of the value the same
+ kConversionCost_RankPromotion = 100,
+
+ // Conversions that are lossless, but change "kind"
+ kConversionCost_UnsignedToSignedPromotion = 200,
+
+ // Conversion from signed->unsigned integer of same or greater size
+ kConversionCost_SignedToUnsignedConversion = 300,
+
+ // Cost of converting an integer to a floating-point type
+ kConversionCost_IntegerToFloatConversion = 400,
+
+ // Catch-all for conversions that should be discouraged
+ // (i.e., that really shouldn't be made implicitly)
+ //
+ // TODO: make these conversions not be allowed implicitly in "Slang mode"
+ kConversionCost_GeneralConversion = 900,
+
+ // Additional conversion cost to add when promoting from a scalar to
+ // a vector (this will be added to the cost, if any, of converting
+ // the element type of the vector)
+ kConversionCost_ScalarToVector = 1,
+ };
+
+ enum BaseTypeConversionKind : uint8_t
+ {
+ kBaseTypeConversionKind_Signed,
+ kBaseTypeConversionKind_Unsigned,
+ kBaseTypeConversionKind_Float,
+ kBaseTypeConversionKind_Error,
+ };
+
+ enum BaseTypeConversionRank : uint8_t
+ {
+ kBaseTypeConversionRank_Bool,
+ kBaseTypeConversionRank_Int8,
+ kBaseTypeConversionRank_Int16,
+ kBaseTypeConversionRank_Int32,
+ kBaseTypeConversionRank_IntPtr,
+ kBaseTypeConversionRank_Int64,
+ kBaseTypeConversionRank_Error,
+ };
+
+ struct BaseTypeConversionInfo
+ {
+ BaseTypeConversionKind kind;
+ BaseTypeConversionRank rank;
+ };
+ static BaseTypeConversionInfo GetBaseTypeConversionInfo(BaseType baseType)
+ {
+ switch (baseType)
+ {
+ #define CASE(TAG, KIND, RANK) \
+ case BaseType::TAG: { BaseTypeConversionInfo info = {kBaseTypeConversionKind_##KIND, kBaseTypeConversionRank_##RANK}; return info; } break
+
+ CASE(Bool, Unsigned, Bool);
+ CASE(Int, Signed, Int32);
+ CASE(UInt, Unsigned, Int32);
+ CASE(UInt64, Unsigned, Int64);
+ CASE(Float, Float, Int32);
+ CASE(Void, Error, Error);
+
+ #undef CASE
+
+ default:
+ break;
+ }
+ SLANG_UNREACHABLE("all cases handled");
+ }
+
+ bool ValuesAreEqual(
+ RefPtr<IntVal> left,
+ RefPtr<IntVal> right)
+ {
+ if(left == right) return true;
+
+ if(auto leftConst = left.As<ConstantIntVal>())
+ {
+ if(auto rightConst = right.As<ConstantIntVal>())
+ {
+ return leftConst->value == rightConst->value;
+ }
+ }
+
+ if(auto leftVar = left.As<GenericParamIntVal>())
+ {
+ if(auto rightVar = right.As<GenericParamIntVal>())
+ {
+ return leftVar->declRef.Equals(rightVar->declRef);
+ }
+ }
+
+ return false;
+ }
+
+ // Central engine for implementing implicit coercion logic
+ bool TryCoerceImpl(
+ RefPtr<ExpressionType> toType, // the target type for conversion
+ RefPtr<ExpressionSyntaxNode>* outToExpr, // (optional) a place to stuff the target expression
+ RefPtr<ExpressionType> fromType, // the source type for the conversion
+ RefPtr<ExpressionSyntaxNode> fromExpr, // the source expression
+ ConversionCost* outCost) // (optional) a place to stuff the conversion cost
+ {
+ // Easy case: the types are equal
+ if (toType->Equals(fromType))
+ {
+ if (outToExpr)
+ *outToExpr = fromExpr;
+ if (outCost)
+ *outCost = kConversionCost_None;
+ return true;
+ }
+
+ // If either type is an error, then let things pass.
+ if (toType->As<ErrorType>() || fromType->As<ErrorType>())
+ {
+ if (outToExpr)
+ *outToExpr = CreateImplicitCastExpr(toType, fromExpr);
+ if (outCost)
+ *outCost = kConversionCost_None;
+ return true;
+ }
+
+ // Coercion from an initializer list is allowed for many types
+ if( auto fromInitializerListExpr = fromExpr.As<InitializerListExpr>())
+ {
+ auto argCount = fromInitializerListExpr->args.Count();
+
+ // In the case where we need to build a reuslt expression,
+ // we will collect the new arguments here
+ List<RefPtr<ExpressionSyntaxNode>> coercedArgs;
+
+ if(auto toDeclRefType = toType->As<DeclRefType>())
+ {
+ auto toTypeDeclRef = toDeclRefType->declRef;
+ if(auto toStructDeclRef = toTypeDeclRef.As<StructDeclRef>())
+ {
+ // Trying to initialize a `struct` type given an initializer list.
+ // We will go through the fields in order and try to match them
+ // up with initializer arguments.
+
+
+ int argIndex = 0;
+ for(auto& fieldDeclRef : toStructDeclRef.GetMembersOfType<FieldDeclRef>())
+ {
+ if(argIndex >= argCount)
+ {
+ // We've consumed all the arguments, so we should stop
+ break;
+ }
+
+ auto arg = fromInitializerListExpr->args[argIndex++];
+
+ //
+ RefPtr<ExpressionSyntaxNode> coercedArg;
+ ConversionCost argCost;
+
+ bool argResult = TryCoerceImpl(
+ fieldDeclRef.GetType(),
+ outToExpr ? &coercedArg : nullptr,
+ arg->Type,
+ arg,
+ outCost ? &argCost : nullptr);
+
+ // No point in trying further if any argument fails
+ if(!argResult)
+ return false;
+
+ // TODO(tfoley): what to do with cost?
+ // This only matters if/when we allow an initializer list as an argument to
+ // an overloaded call.
+
+ if( outToExpr )
+ {
+ coercedArgs.Add(coercedArg);
+ }
+ }
+ }
+ }
+ else if(auto toArrayType = toType->As<ArrayExpressionType>())
+ {
+ // TODO(tfoley): If we can compute the size of the array statically,
+ // then we want to check that there aren't too many initializers present
+
+ auto toElementType = toArrayType->BaseType;
+
+ for(auto& arg : fromInitializerListExpr->args)
+ {
+ RefPtr<ExpressionSyntaxNode> coercedArg;
+ ConversionCost argCost;
+
+ bool argResult = TryCoerceImpl(
+ toElementType,
+ outToExpr ? &coercedArg : nullptr,
+ arg->Type,
+ arg,
+ outCost ? &argCost : nullptr);
+
+ // No point in trying further if any argument fails
+ if(!argResult)
+ return false;
+
+ if( outToExpr )
+ {
+ coercedArgs.Add(coercedArg);
+ }
+ }
+ }
+ else
+ {
+ // By default, we don't allow a type to be initialized using
+ // an initializer list.
+ return false;
+ }
+
+ // For now, coercion from an initializer list has no cost
+ if(outCost)
+ {
+ *outCost = kConversionCost_None;
+ }
+
+ // We were able to coerce all the arguments given, and so
+ // we need to construct a suitable expression to remember the result
+ if(outToExpr)
+ {
+ auto toInitializerListExpr = new InitializerListExpr();
+ toInitializerListExpr->Position = fromInitializerListExpr->Position;
+ toInitializerListExpr->Type = toType;
+ toInitializerListExpr->args = coercedArgs;
+
+
+ *outToExpr = toInitializerListExpr;
+ }
+
+ return true;
+ }
+
+ //
+
+ if (auto toBasicType = toType->AsBasicType())
+ {
+ if (auto fromBasicType = fromType->AsBasicType())
+ {
+ // Conversions between base types are always allowed,
+ // and the only question is what the cost will be.
+
+ auto toInfo = GetBaseTypeConversionInfo(toBasicType->BaseType);
+ auto fromInfo = GetBaseTypeConversionInfo(fromBasicType->BaseType);
+
+ // We expect identical types to have been dealt with already.
+ assert(toInfo.kind != fromInfo.kind || toInfo.rank != fromInfo.rank);
+
+ if (outToExpr)
+ *outToExpr = CreateImplicitCastExpr(toType, fromExpr);
+
+
+ if (outCost)
+ {
+ // Conversions within the same kind are easist to handle
+ if (toInfo.kind == fromInfo.kind)
+ {
+ // If we are converting to a "larger" type, then
+ // we are doing a lossless promotion, and otherwise
+ // we are doing a demotion.
+ if( toInfo.rank > fromInfo.rank)
+ *outCost = kConversionCost_RankPromotion;
+ else
+ *outCost = kConversionCost_GeneralConversion;
+ }
+ // If we are converting from an unsigned integer type to
+ // a signed integer type that is guaranteed to be larger,
+ // then that is also a lossless promotion.
+ else if(toInfo.kind == kBaseTypeConversionKind_Signed
+ && fromInfo.kind == kBaseTypeConversionKind_Unsigned
+ && toInfo.rank > fromInfo.rank)
+ {
+ // TODO: probably need to weed out cases involving
+ // "pointer-sized" integers if these are treated
+ // as distinct from 32- and 64-bit types.
+ // E.g., there is no guarantee that conversion
+ // from 32-bit unsigned to pointer-sized signed
+ // is lossless, because pointers could be 32-bit,
+ // and the same applies for conversion from
+ // `uintptr` to `uint64`.
+ *outCost = kConversionCost_UnsignedToSignedPromotion;
+ }
+ // Conversion from signed to unsigned is always lossy,
+ // but it is preferred over conversions from unsigned
+ // to signed, for same-size types.
+ else if(toInfo.kind == kBaseTypeConversionKind_Unsigned
+ && fromInfo.kind == kBaseTypeConversionKind_Signed
+ && toInfo.rank >= fromInfo.rank)
+ {
+ *outCost = kConversionCost_SignedToUnsignedConversion;
+ }
+ // Conversion from an integer to a floating-point type
+ // is never considered a promotion (even when the value
+ // would fit in the available bits).
+ // If the destination type is at least 32 bits we consider
+ // this a reasonably good conversion, though.
+ else if (toInfo.kind == kBaseTypeConversionKind_Float
+ && toInfo.rank >= kBaseTypeConversionRank_Int32)
+ {
+ *outCost = kConversionCost_IntegerToFloatConversion;
+ }
+ // All other cases are considered as "general" conversions,
+ // where we don't consider any one conversion better than
+ // any others.
+ else
+ {
+ *outCost = kConversionCost_GeneralConversion;
+ }
+ }
+
+ return true;
+ }
+ }
+
+ if (auto toVectorType = toType->AsVectorType())
+ {
+ if (auto fromVectorType = fromType->AsVectorType())
+ {
+ // Conversion between vector types.
+
+ // If element counts don't match, then bail:
+ if (!ValuesAreEqual(toVectorType->elementCount, fromVectorType->elementCount))
+ return false;
+
+ // Otherwise, if we can convert the element types, we are golden
+ ConversionCost elementCost;
+ if (CanCoerce(toVectorType->elementType, fromVectorType->elementType, &elementCost))
+ {
+ if (outToExpr)
+ *outToExpr = CreateImplicitCastExpr(toType, fromExpr);
+ if (outCost)
+ *outCost = elementCost;
+ return true;
+ }
+ }
+ else if (auto fromScalarType = fromType->AsBasicType())
+ {
+ // Conversion from scalar to vector.
+ // Should allow as long as we can coerce the scalar to our element type.
+ ConversionCost elementCost;
+ if (CanCoerce(toVectorType->elementType, fromScalarType, &elementCost))
+ {
+ if (outToExpr)
+ *outToExpr = CreateImplicitCastExpr(toType, fromExpr);
+ if (outCost)
+ *outCost = elementCost + kConversionCost_ScalarToVector;
+ return true;
+ }
+ }
+ }
+
+ // TODO: more cases!
+
+ return false;
+ }
+
+ // Check whether a type coercion is possible
+ bool CanCoerce(
+ RefPtr<ExpressionType> toType, // the target type for conversion
+ RefPtr<ExpressionType> fromType, // the source type for the conversion
+ ConversionCost* outCost = 0) // (optional) a place to stuff the conversion cost
+ {
+ return TryCoerceImpl(
+ toType,
+ nullptr,
+ fromType,
+ nullptr,
+ outCost);
+ }
+
+ RefPtr<ExpressionSyntaxNode> CreateImplicitCastExpr(
+ RefPtr<ExpressionType> toType,
+ RefPtr<ExpressionSyntaxNode> fromExpr)
+ {
+ auto castExpr = new TypeCastExpressionSyntaxNode();
+ castExpr->Position = fromExpr->Position;
+ castExpr->TargetType.type = toType;
+ castExpr->Type = toType;
+ castExpr->Expression = fromExpr;
+ return castExpr;
+ }
+
+
+ // Perform type coercion, and emit errors if it isn't possible
+ RefPtr<ExpressionSyntaxNode> Coerce(
+ RefPtr<ExpressionType> toType,
+ RefPtr<ExpressionSyntaxNode> fromExpr)
+ {
+ // If semantic checking is being suppressed, then we might see
+ // expressions without a type, and we need to ignore them.
+ if( !fromExpr->Type.type )
+ {
+ if(getOptions().flags & SLANG_COMPILE_FLAG_NO_CHECKING )
+ return fromExpr;
+ }
+
+ RefPtr<ExpressionSyntaxNode> expr;
+ if (!TryCoerceImpl(
+ toType,
+ &expr,
+ fromExpr->Type.Ptr(),
+ fromExpr.Ptr(),
+ nullptr))
+ {
+ if(!(getOptions().flags & SLANG_COMPILE_FLAG_NO_CHECKING))
+ {
+ getSink()->diagnose(fromExpr->Position, Diagnostics::typeMismatch, toType, fromExpr->Type);
+ }
+
+ // Note(tfoley): We don't call `CreateErrorExpr` here, because that would
+ // clobber the type on `fromExpr`, and an invariant here is that coercion
+ // really shouldn't *change* the expression that is passed in, but should
+ // introduce new AST nodes to coerce its value to a different type...
+ return CreateImplicitCastExpr(ExpressionType::Error, fromExpr);
+ }
+ return expr;
+ }
+
+ void CheckVarDeclCommon(RefPtr<VarDeclBase> varDecl)
+ {
+ // Check the type, if one was given
+ TypeExp type = CheckUsableType(varDecl->Type);
+
+ // TODO: Additional validation rules on types should go here,
+ // but we need to deal with the fact that some cases might be
+ // allowed in one context (e.g., an unsized array parameter)
+ // but not in othters (e.g., an unsized array field in a struct).
+
+ // Check the initializers, if one was given
+ RefPtr<ExpressionSyntaxNode> initExpr = CheckTerm(varDecl->Expr);
+
+ // If a type was given, ...
+ if (type.Ptr())
+ {
+ // then coerce any initializer to the type
+ if (initExpr)
+ {
+ initExpr = Coerce(type, initExpr);
+ }
+ }
+ else
+ {
+ // TODO: infer a type from the initializers
+ if (!initExpr)
+ {
+ getSink()->diagnose(varDecl, Diagnostics::unimplemented, "variable declaration with no type must have initializer");
+ }
+ else
+ {
+ getSink()->diagnose(varDecl, Diagnostics::unimplemented, "type inference for variable declaration");
+ }
+ }
+
+ varDecl->Type = type;
+ varDecl->Expr = initExpr;
+ }
+
+ void CheckGenericConstraintDecl(GenericTypeConstraintDecl* decl)
+ {
+ // TODO: are there any other validations we can do at this point?
+ //
+ // There probably needs to be a kind of "occurs check" to make
+ // sure that the constraint actually applies to at least one
+ // of the parameters of the generic.
+
+ decl->sub = TranslateTypeNode(decl->sub);
+ decl->sup = TranslateTypeNode(decl->sup);
+ }
+
+ virtual RefPtr<GenericDecl> VisitGenericDecl(GenericDecl* genericDecl) override
+ {
+ // check the parameters
+ for (auto m : genericDecl->Members)
+ {
+ if (auto typeParam = m.As<GenericTypeParamDecl>())
+ {
+ typeParam->initType = CheckProperType(typeParam->initType);
+ }
+ else if (auto valParam = m.As<GenericValueParamDecl>())
+ {
+ // TODO: some real checking here...
+ CheckVarDeclCommon(valParam);
+ }
+ else if(auto constraint = m.As<GenericTypeConstraintDecl>())
+ {
+ CheckGenericConstraintDecl(constraint.Ptr());
+ }
+ }
+
+ // check the nested declaration
+ // TODO: this needs to be done in an appropriate environment...
+ genericDecl->inner->Accept(this);
+ return genericDecl;
+ }
+
+ virtual void VisitTraitConformanceDecl(TraitConformanceDecl* conformanceDecl) override
+ {
+ // check the type being conformed to
+ auto base = conformanceDecl->base;
+ base = TranslateTypeNode(base);
+ conformanceDecl->base = base;
+
+ if(auto declRefType = base.type->As<DeclRefType>())
+ {
+ if(auto traitDeclRef = declRefType->declRef.As<TraitDeclRef>())
+ {
+ conformanceDecl->traitDeclRef = traitDeclRef;
+ return;
+ }
+ }
+
+ // We expected a trait here
+ getSink()->diagnose( conformanceDecl, Diagnostics::expectedATraitGot, base.type);
+ }
+
+ RefPtr<ConstantIntVal> checkConstantIntVal(
+ RefPtr<ExpressionSyntaxNode> expr)
+ {
+ // First type-check the expression as normal
+ expr = CheckExpr(expr);
+
+ auto intVal = CheckIntegerConstantExpression(expr.Ptr());
+ if(!intVal)
+ return nullptr;
+
+ auto constIntVal = intVal.As<ConstantIntVal>();
+ if(!constIntVal)
+ {
+ getSink()->diagnose(expr->Position, Diagnostics::expectedIntegerConstantNotLiteral);
+ return nullptr;
+ }
+ return constIntVal;
+ }
+
+ RefPtr<Modifier> checkModifier(
+ RefPtr<Modifier> m,
+ Decl* decl)
+ {
+ if(auto hlslUncheckedAttribute = m.As<HLSLUncheckedAttribute>())
+ {
+ // We have an HLSL `[name(arg,...)]` attribute, and we'd like
+ // to check that it is provides all the expected arguments
+ //
+ // For now we will do this in a completely ad hoc fashion,
+ // but it would be nice to have some generic routine to
+ // do the needed type checking/coercion.
+ if(hlslUncheckedAttribute->nameToken.Content == "numthreads")
+ {
+ if(hlslUncheckedAttribute->args.Count() != 3)
+ return m;
+
+ auto xVal = checkConstantIntVal(hlslUncheckedAttribute->args[0]);
+ auto yVal = checkConstantIntVal(hlslUncheckedAttribute->args[1]);
+ auto zVal = checkConstantIntVal(hlslUncheckedAttribute->args[2]);
+
+ if(!xVal) return m;
+ if(!yVal) return m;
+ if(!zVal) return m;
+
+ auto hlslNumThreadsAttribute = new HLSLNumThreadsAttribute();
+
+ hlslNumThreadsAttribute->Position = hlslUncheckedAttribute->Position;
+ hlslNumThreadsAttribute->nameToken = hlslUncheckedAttribute->nameToken;
+ hlslNumThreadsAttribute->args = hlslUncheckedAttribute->args;
+ hlslNumThreadsAttribute->x = xVal->value;
+ hlslNumThreadsAttribute->y = yVal->value;
+ hlslNumThreadsAttribute->z = zVal->value;
+
+ return hlslNumThreadsAttribute;
+ }
+ }
+
+ // Default behavior is to leave things as they are,
+ // and assume that modifiers are mostly already checked.
+ //
+ // TODO: This would be a good place to validate that
+ // a modifier is actually valid for the thing it is
+ // being applied to, and potentially to check that
+ // it isn't in conflict with any other modifiers
+ // on the same declaration.
+
+ return m;
+ }
+
+
+ void checkModifiers(Decl* decl)
+ {
+ // TODO(tfoley): need to make sure this only
+ // performs semantic checks on a `SharedModifier` once...
+
+ // The process of checking a modifier may produce a new modifier in its place,
+ // so we will build up a new linked list of modifiers that will replace
+ // the old list.
+ RefPtr<Modifier> resultModifiers;
+ RefPtr<Modifier>* resultModifierLink = &resultModifiers;
+
+ RefPtr<Modifier> modifier = decl->modifiers.first;
+ while(modifier)
+ {
+ // Because we are rewriting the list in place, we need to extract
+ // the next modifier here (not at the end of the loop).
+ auto next = modifier->next;
+
+ // We also go ahead and clobber the `next` field on the modifier
+ // itself, so that the default behavior of `checkModifier()` can
+ // be to return a single unlinked modifier.
+ modifier->next = nullptr;
+
+ auto checkedModifier = checkModifier(modifier, decl);
+ if(checkedModifier)
+ {
+ // If checking gave us a modifier to add, then we
+ // had better add it.
+
+ // Just in case `checkModifier` ever returns multiple
+ // modifiers, lets advance to the end of the list we
+ // are building.
+ while(*resultModifierLink)
+ resultModifierLink = &(*resultModifierLink)->next;
+
+ // attach the new modifier at the end of the list,
+ // and now set the "link" to point to its `next` field
+ *resultModifierLink = checkedModifier;
+ resultModifierLink = &checkedModifier->next;
+ }
+
+ // Move along to the next modifier
+ modifier = next;
+ }
+
+ // Whether we actually re-wrote anything or note, lets
+ // install the new list of modifiers on the declaration
+ decl->modifiers.first = resultModifiers;
+ }
+
+ virtual RefPtr<ProgramSyntaxNode> VisitProgram(ProgramSyntaxNode * programNode) override
+ {
+ // Try to register all the builtin decls
+ for (auto decl : programNode->Members)
+ {
+ auto inner = decl;
+ if (auto genericDecl = decl.As<GenericDecl>())
+ {
+ inner = genericDecl->inner;
+ }
+
+ if (auto builtinMod = inner->FindModifier<BuiltinTypeModifier>())
+ {
+ RegisterBuiltinDecl(decl, builtinMod);
+ }
+ if (auto magicMod = inner->FindModifier<MagicTypeModifier>())
+ {
+ RegisterMagicDecl(decl, magicMod);
+ }
+ }
+
+ //
+
+ HashSet<String> funcNames;
+ this->program = programNode;
+ this->function = nullptr;
+
+ for (auto & s : program->GetTypeDefs())
+ VisitTypeDefDecl(s.Ptr());
+ for (auto & s : program->GetStructs())
+ {
+ VisitStruct(s.Ptr());
+ }
+ for (auto & s : program->GetClasses())
+ {
+ VisitClass(s.Ptr());
+ }
+ // HACK(tfoley): Visiting all generic declarations here,
+ // because otherwise they won't get visited.
+ for (auto & g : program->GetMembersOfType<GenericDecl>())
+ {
+ VisitGenericDecl(g.Ptr());
+ }
+
+ for (auto & func : program->GetFunctions())
+ {
+ if (!func->IsChecked(DeclCheckState::Checked))
+ {
+ VisitFunctionDeclaration(func.Ptr());
+ }
+ }
+ for (auto & func : program->GetFunctions())
+ {
+ EnsureDecl(func);
+ }
+
+ if (sink->GetErrorCount() != 0)
+ return programNode;
+
+ // Force everything to be fully checked, just in case
+ // Note that we don't just call this on the program,
+ // because we'd end up recursing into this very code path...
+ for (auto d : programNode->Members)
+ {
+ EnusreAllDeclsRec(d);
+ }
+
+ // Do any semantic checking required on modifiers?
+ for (auto d : programNode->Members)
+ {
+ checkModifiers(d.Ptr());
+ }
+
+ return programNode;
+ }
+
+ virtual RefPtr<ClassSyntaxNode> VisitClass(ClassSyntaxNode * classNode) override
+ {
+ if (classNode->IsChecked(DeclCheckState::Checked))
+ return classNode;
+ classNode->SetCheckState(DeclCheckState::Checked);
+
+ for (auto field : classNode->GetFields())
+ {
+ field->Type = CheckUsableType(field->Type);
+ field->SetCheckState(DeclCheckState::Checked);
+ }
+ return classNode;
+ }
+
+ virtual RefPtr<StructSyntaxNode> VisitStruct(StructSyntaxNode * structNode) override
+ {
+ if (structNode->IsChecked(DeclCheckState::Checked))
+ return structNode;
+ structNode->SetCheckState(DeclCheckState::Checked);
+
+ for (auto field : structNode->GetFields())
+ {
+ field->Type = CheckUsableType(field->Type);
+ field->SetCheckState(DeclCheckState::Checked);
+ }
+ return structNode;
+ }
+
+ virtual RefPtr<TypeDefDecl> VisitTypeDefDecl(TypeDefDecl* decl) override
+ {
+ if (decl->IsChecked(DeclCheckState::Checked)) return decl;
+
+ decl->SetCheckState(DeclCheckState::CheckingHeader);
+ decl->Type = CheckProperType(decl->Type);
+ decl->SetCheckState(DeclCheckState::Checked);
+ return decl;
+ }
+
+ virtual RefPtr<FunctionSyntaxNode> VisitFunction(FunctionSyntaxNode *functionNode) override
+ {
+ if (functionNode->IsChecked(DeclCheckState::Checked))
+ return functionNode;
+
+ VisitFunctionDeclaration(functionNode);
+ functionNode->SetCheckState(DeclCheckState::Checked);
+
+ if (!functionNode->IsExtern())
+ {
+ this->function = functionNode;
+ if (functionNode->Body)
+ {
+ functionNode->Body->Accept(this);
+ }
+ this->function = nullptr;
+ }
+ return functionNode;
+ }
+
+ // Check if two functions have the same signature for the purposes
+ // of overload resolution.
+ bool DoFunctionSignaturesMatch(
+ FunctionSyntaxNode* fst,
+ FunctionSyntaxNode* snd)
+ {
+ // TODO(tfoley): This function won't do anything sensible for generics,
+ // so we need to figure out a plan for that...
+
+ // TODO(tfoley): This copies the parameter array, which is bad for performance.
+ auto fstParams = fst->GetParameters().ToArray();
+ auto sndParams = snd->GetParameters().ToArray();
+
+ // If the functions have different numbers of parameters, then
+ // their signatures trivially don't match.
+ auto fstParamCount = fstParams.Count();
+ auto sndParamCount = sndParams.Count();
+ if (fstParamCount != sndParamCount)
+ return false;
+
+ for (int ii = 0; ii < fstParamCount; ++ii)
+ {
+ auto fstParam = fstParams[ii];
+ auto sndParam = sndParams[ii];
+
+ // If a given parameter type doesn't match, then signatures don't match
+ if (!fstParam->Type.Equals(sndParam->Type))
+ return false;
+
+ // If one parameter is `out` and the other isn't, then they don't match
+ //
+ // Note(tfoley): we don't consider `out` and `inout` as distinct here,
+ // because there is no way for overload resolution to pick between them.
+ if (fstParam->HasModifier<OutModifier>() != sndParam->HasModifier<OutModifier>())
+ return false;
+ }
+
+ // Note(tfoley): return type doesn't enter into it, because we can't take
+ // calling context into account during overload resolution.
+
+ return true;
+ }
+
+ void ValidateFunctionRedeclaration(FunctionSyntaxNode* funcDecl)
+ {
+ auto parentDecl = funcDecl->ParentDecl;
+ assert(parentDecl);
+ if (!parentDecl) return;
+
+ // Look at previously-declared functions with the same name,
+ // in the same container
+ buildMemberDictionary(parentDecl);
+
+ for (auto prevDecl = funcDecl->nextInContainerWithSameName; prevDecl; prevDecl = prevDecl->nextInContainerWithSameName)
+ {
+ // Look through generics to the declaration underneath
+ auto prevGenericDecl = dynamic_cast<GenericDecl*>(prevDecl);
+ if (prevGenericDecl)
+ prevDecl = prevGenericDecl->inner.Ptr();
+
+ // We only care about previously-declared functions
+ // Note(tfoley): although we should really error out if the
+ // name is already in use for something else, like a variable...
+ auto prevFuncDecl = dynamic_cast<FunctionSyntaxNode*>(prevDecl);
+ if (!prevFuncDecl)
+ continue;
+
+ // If the parameter signatures don't match, then don't worry
+ if (!DoFunctionSignaturesMatch(funcDecl, prevFuncDecl))
+ continue;
+
+ // If we get this far, then we've got two declarations in the same
+ // scope, with the same name and signature.
+ //
+ // They might just be redeclarations, which we would want to allow.
+
+ // First, check if the return types match.
+ // TODO(tfolye): this code won't work for generics
+ if (!funcDecl->ReturnType.Equals(prevFuncDecl->ReturnType))
+ {
+ // Bad dedeclaration
+ getSink()->diagnose(funcDecl, Diagnostics::unimplemented, "redeclaration has a different return type");
+
+ // Don't bother emitting other errors at this point
+ break;
+ }
+
+ // TODO(tfoley): track the fact that there is redeclaration going on,
+ // so that we can detect it and react accordingly during overload resolution
+ // (e.g., by only considering one declaration as the canonical one...)
+
+ // If both have a body, then there is trouble
+ if (funcDecl->Body && prevFuncDecl->Body)
+ {
+ // Redefinition
+ getSink()->diagnose(funcDecl, Diagnostics::unimplemented, "function redefinition");
+
+ // Don't bother emitting other errors
+ break;
+ }
+
+ // TODO(tfoley): If both specific default argument expressions
+ // for the same value, then that is an error too...
+ }
+ }
+
+ void VisitFunctionDeclaration(FunctionSyntaxNode *functionNode)
+ {
+ if (functionNode->IsChecked(DeclCheckState::CheckedHeader)) return;
+ functionNode->SetCheckState(DeclCheckState::CheckingHeader);
+
+ this->function = functionNode;
+ auto returnType = CheckProperType(functionNode->ReturnType);
+ functionNode->ReturnType = returnType;
+ HashSet<String> paraNames;
+ for (auto & para : functionNode->GetParameters())
+ {
+ if (paraNames.Contains(para->Name.Content))
+ getSink()->diagnose(para, Diagnostics::parameterAlreadyDefined, para->Name);
+ else
+ paraNames.Add(para->Name.Content);
+ para->Type = CheckUsableType(para->Type);
+ if (para->Type.Equals(ExpressionType::GetVoid()))
+ getSink()->diagnose(para, Diagnostics::parameterCannotBeVoid);
+ }
+ this->function = NULL;
+ functionNode->SetCheckState(DeclCheckState::CheckedHeader);
+
+ // One last bit of validation: check if we are redeclaring an existing function
+ ValidateFunctionRedeclaration(functionNode);
+ }
+
+ virtual RefPtr<StatementSyntaxNode> VisitBlockStatement(BlockStatementSyntaxNode *stmt) override
+ {
+ for (auto & node : stmt->Statements)
+ {
+ node->Accept(this);
+ }
+ return stmt;
+ }
+
+ template<typename T>
+ T* FindOuterStmt()
+ {
+ int outerStmtCount = outerStmts.Count();
+ for (int ii = outerStmtCount - 1; ii >= 0; --ii)
+ {
+ auto outerStmt = outerStmts[ii];
+ auto found = dynamic_cast<T*>(outerStmt);
+ if (found)
+ return found;
+ }
+ return nullptr;
+ }
+
+ virtual RefPtr<StatementSyntaxNode> VisitBreakStatement(BreakStatementSyntaxNode *stmt) override
+ {
+ auto outer = FindOuterStmt<BreakableStmt>();
+ if (!outer)
+ {
+ getSink()->diagnose(stmt, Diagnostics::breakOutsideLoop);
+ }
+ stmt->parentStmt = outer;
+ return stmt;
+ }
+ virtual RefPtr<StatementSyntaxNode> VisitContinueStatement(ContinueStatementSyntaxNode *stmt) override
+ {
+ auto outer = FindOuterStmt<LoopStmt>();
+ if (!outer)
+ {
+ getSink()->diagnose(stmt, Diagnostics::continueOutsideLoop);
+ }
+ stmt->parentStmt = outer;
+ return stmt;
+ }
+
+ void PushOuterStmt(StatementSyntaxNode* stmt)
+ {
+ outerStmts.Add(stmt);
+ }
+
+ void PopOuterStmt(StatementSyntaxNode* /*stmt*/)
+ {
+ outerStmts.RemoveAt(outerStmts.Count() - 1);
+ }
+
+ virtual RefPtr<StatementSyntaxNode> VisitDoWhileStatement(DoWhileStatementSyntaxNode *stmt) override
+ {
+ PushOuterStmt(stmt);
+ if (stmt->Predicate != NULL)
+ stmt->Predicate = stmt->Predicate->Accept(this).As<ExpressionSyntaxNode>();
+ if (!stmt->Predicate->Type->Equals(ExpressionType::GetError()) &&
+ !stmt->Predicate->Type->Equals(ExpressionType::GetInt()) &&
+ !stmt->Predicate->Type->Equals(ExpressionType::GetBool()))
+ {
+ getSink()->diagnose(stmt, Diagnostics::whilePredicateTypeError);
+ }
+ stmt->Statement->Accept(this);
+
+ PopOuterStmt(stmt);
+ return stmt;
+ }
+ virtual RefPtr<StatementSyntaxNode> VisitForStatement(ForStatementSyntaxNode *stmt) override
+ {
+ PushOuterStmt(stmt);
+ if (stmt->InitialStatement)
+ {
+ stmt->InitialStatement = stmt->InitialStatement->Accept(this).As<StatementSyntaxNode>();
+ }
+ if (stmt->PredicateExpression)
+ {
+ stmt->PredicateExpression = stmt->PredicateExpression->Accept(this).As<ExpressionSyntaxNode>();
+ if (!stmt->PredicateExpression->Type->Equals(ExpressionType::GetBool()) &&
+ !stmt->PredicateExpression->Type->Equals(ExpressionType::GetInt()) &&
+ !stmt->PredicateExpression->Type->Equals(ExpressionType::GetUInt()))
+ {
+ getSink()->diagnose(stmt->PredicateExpression.Ptr(), Diagnostics::forPredicateTypeError);
+ }
+ }
+ if (stmt->SideEffectExpression)
+ {
+ stmt->SideEffectExpression = stmt->SideEffectExpression->Accept(this).As<ExpressionSyntaxNode>();
+ }
+ stmt->Statement->Accept(this);
+
+ PopOuterStmt(stmt);
+ return stmt;
+ }
+ virtual RefPtr<SwitchStmt> VisitSwitchStmt(SwitchStmt* stmt) override
+ {
+ PushOuterStmt(stmt);
+ // TODO(tfoley): need to coerce condition to an integral type...
+ stmt->condition = CheckExpr(stmt->condition);
+ stmt->body->Accept(this);
+ PopOuterStmt(stmt);
+ return stmt;
+ }
+ virtual RefPtr<CaseStmt> VisitCaseStmt(CaseStmt* stmt) override
+ {
+ auto expr = CheckExpr(stmt->expr);
+ auto switchStmt = FindOuterStmt<SwitchStmt>();
+
+ if (!switchStmt)
+ {
+ getSink()->diagnose(stmt, Diagnostics::caseOutsideSwitch);
+ }
+ else
+ {
+ // TODO: need to do some basic matching to ensure the type
+ // for the `case` is consistent with the type for the `switch`...
+ }
+
+ stmt->expr = expr;
+ stmt->parentStmt = switchStmt;
+
+ return stmt;
+ }
+ virtual RefPtr<DefaultStmt> VisitDefaultStmt(DefaultStmt* stmt) override
+ {
+ auto switchStmt = FindOuterStmt<SwitchStmt>();
+ if (!switchStmt)
+ {
+ getSink()->diagnose(stmt, Diagnostics::defaultOutsideSwitch);
+ }
+ stmt->parentStmt = switchStmt;
+ return stmt;
+ }
+ virtual RefPtr<StatementSyntaxNode> VisitIfStatement(IfStatementSyntaxNode *stmt) override
+ {
+ auto condition = stmt->Predicate;
+ condition = CheckTerm(condition);
+ condition = Coerce(ExpressionType::GetBool(), condition);
+
+ stmt->Predicate = condition;
+
+#if 0
+ if (stmt->Predicate != NULL)
+ stmt->Predicate = stmt->Predicate->Accept(this).As<ExpressionSyntaxNode>();
+ if (!stmt->Predicate->Type->Equals(ExpressionType::GetError())
+ && (!stmt->Predicate->Type->Equals(ExpressionType::GetInt()) &&
+ !stmt->Predicate->Type->Equals(ExpressionType::GetBool())))
+ getSink()->diagnose(stmt, Diagnostics::ifPredicateTypeError);
+#endif
+
+ if (stmt->PositiveStatement != NULL)
+ stmt->PositiveStatement->Accept(this);
+
+ if (stmt->NegativeStatement != NULL)
+ stmt->NegativeStatement->Accept(this);
+ return stmt;
+ }
+ virtual RefPtr<StatementSyntaxNode> VisitReturnStatement(ReturnStatementSyntaxNode *stmt) override
+ {
+ if (!stmt->Expression)
+ {
+ if (function && !function->ReturnType.Equals(ExpressionType::GetVoid()))
+ getSink()->diagnose(stmt, Diagnostics::returnNeedsExpression);
+ }
+ else
+ {
+ stmt->Expression = stmt->Expression->Accept(this).As<ExpressionSyntaxNode>();
+ if (!stmt->Expression->Type->Equals(ExpressionType::Error.Ptr()))
+ {
+ if (function)
+ {
+ stmt->Expression = Coerce(function->ReturnType, stmt->Expression);
+ }
+ else
+ {
+ // TODO(tfoley): this case currently gets triggered for member functions,
+ // which aren't being checked consistently (because of the whole symbol
+ // table idea getting in the way).
+
+// getSink()->diagnose(stmt, Diagnostics::unimplemented, "case for return stmt");
+ }
+ }
+ }
+ return stmt;
+ }
+
+ int GetMinBound(RefPtr<IntVal> val)
+ {
+ if (auto constantVal = val.As<ConstantIntVal>())
+ return constantVal->value;
+
+ // TODO(tfoley): Need to track intervals so that this isn't just a lie...
+ return 1;
+ }
+
+ void maybeInferArraySizeForVariable(Variable* varDecl)
+ {
+ // Not an array?
+ auto arrayType = varDecl->Type->AsArrayType();
+ if (!arrayType) return;
+
+ // Explicit element count given?
+ auto elementCount = arrayType->ArrayLength;
+ if (elementCount) return;
+
+ // No initializer?
+ auto initExpr = varDecl->Expr;
+ if(!initExpr) return;
+
+ // Is the initializer an initializer list?
+ if(auto initializerListExpr = initExpr.As<InitializerListExpr>())
+ {
+ auto argCount = initializerListExpr->args.Count();
+ elementCount = new ConstantIntVal(argCount);
+ }
+ // Is the type of the initializer an array type?
+ else if(auto arrayInitType = initExpr->Type->As<ArrayExpressionType>())
+ {
+ elementCount = arrayInitType->ArrayLength;
+ }
+ else
+ {
+ // Nothing to do: we couldn't infer a size
+ return;
+ }
+
+ // Create a new array type based on the size we found,
+ // and install it into our type.
+ auto newArrayType = new ArrayExpressionType();
+ newArrayType->BaseType = arrayType->BaseType;
+ newArrayType->ArrayLength = elementCount;
+
+ // Okay we are good to go!
+ varDecl->Type.type = newArrayType;
+ }
+
+ void ValidateArraySizeForVariable(Variable* varDecl)
+ {
+ auto arrayType = varDecl->Type->AsArrayType();
+ if (!arrayType) return;
+
+ auto elementCount = arrayType->ArrayLength;
+ if (!elementCount)
+ {
+ // Note(tfoley): For now we allow arrays of unspecified size
+ // everywhere, because some source languages (e.g., GLSL)
+ // allow them in specific cases.
+#if 0
+ getSink()->diagnose(varDecl, Diagnostics::invalidArraySize);
+#endif
+ return;
+ }
+
+ // TODO(tfoley): How to handle the case where bound isn't known?
+ if (GetMinBound(elementCount) <= 0)
+ {
+ getSink()->diagnose(varDecl, Diagnostics::invalidArraySize);
+ return;
+ }
+ }
+
+ virtual RefPtr<Variable> VisitDeclrVariable(Variable* varDecl)
+ {
+ TypeExp typeExp = CheckUsableType(varDecl->Type);
+#if 0
+ if (typeExp.type->GetBindableResourceType() != BindableResourceType::NonBindable)
+ {
+ // We don't want to allow bindable resource types as local variables (at least for now).
+ auto parentDecl = varDecl->ParentDecl;
+ if (auto parentScopeDecl = dynamic_cast<ScopeDecl*>(parentDecl))
+ {
+ getSink()->diagnose(varDecl->Type, Diagnostics::invalidTypeForLocalVariable);
+ }
+ }
+#endif
+ varDecl->Type = typeExp;
+ if (varDecl->Type.Equals(ExpressionType::GetVoid()))
+ getSink()->diagnose(varDecl, Diagnostics::invalidTypeVoid);
+
+ if(auto initExpr = varDecl->Expr)
+ {
+ initExpr = CheckTerm(initExpr);
+ varDecl->Expr = initExpr;
+ }
+
+ // If this is an array variable, then we first want to give
+ // it a chance to infer an array size from its initializer
+ //
+ // TODO(tfoley): May need to extend this to handle the
+ // multi-dimensional case...
+ maybeInferArraySizeForVariable(varDecl);
+ //
+ // Next we want to make sure that the declared (or inferred)
+ // size for the array meets whatever language-specific
+ // constraints we want to enforce (e.g., disallow empty
+ // arrays in specific cases)
+ ValidateArraySizeForVariable(varDecl);
+
+
+ if(auto initExpr = varDecl->Expr)
+ {
+ // TODO(tfoley): should coercion of initializer lists be special-cased
+ // here, or handled as a general case for coercion?
+
+ initExpr = Coerce(varDecl->Type, initExpr);
+ varDecl->Expr = initExpr;
+ }
+
+ varDecl->SetCheckState(DeclCheckState::Checked);
+
+ return varDecl;
+ }
+
+ virtual RefPtr<StatementSyntaxNode> VisitWhileStatement(WhileStatementSyntaxNode *stmt) override
+ {
+ PushOuterStmt(stmt);
+ stmt->Predicate = stmt->Predicate->Accept(this).As<ExpressionSyntaxNode>();
+ if (!stmt->Predicate->Type->Equals(ExpressionType::GetError()) &&
+ !stmt->Predicate->Type->Equals(ExpressionType::GetInt()) &&
+ !stmt->Predicate->Type->Equals(ExpressionType::GetBool()))
+ getSink()->diagnose(stmt, Diagnostics::whilePredicateTypeError2);
+
+ stmt->Statement->Accept(this);
+ PopOuterStmt(stmt);
+ return stmt;
+ }
+ virtual RefPtr<StatementSyntaxNode> VisitExpressionStatement(ExpressionStatementSyntaxNode *stmt) override
+ {
+ stmt->Expression = stmt->Expression->Accept(this).As<ExpressionSyntaxNode>();
+ return stmt;
+ }
+ virtual RefPtr<ExpressionSyntaxNode> VisitOperatorExpression(OperatorExpressionSyntaxNode *expr) override
+ {
+#if 0
+
+ for (int i = 0; i < expr->Arguments.Count(); i++)
+ expr->Arguments[i] = expr->Arguments[i]->Accept(this).As<ExpressionSyntaxNode>();
+ auto & leftType = expr->Arguments[0]->Type;
+ QualType rightType;
+ if (expr->Arguments.Count() == 2)
+ rightType = expr->Arguments[1]->Type;
+ RefPtr<ExpressionType> matchedType;
+ auto checkAssign = [&]()
+ {
+ if (!leftType.IsLeftValue &&
+ !leftType->Equals(ExpressionType::Error.Ptr()))
+ getSink()->diagnose(expr->Arguments[0].Ptr(), Diagnostics::assignNonLValue);
+ if (expr->Operator == Operator::AndAssign ||
+ expr->Operator == Operator::OrAssign ||
+ expr->Operator == Operator::XorAssign ||
+ expr->Operator == Operator::LshAssign ||
+ expr->Operator == Operator::RshAssign)
+ {
+#if 0
+ if (!(leftType->IsIntegral() && rightType->IsIntegral()))
+ {
+ // TODO(tfoley): This diagnostic shouldn't be handled here
+// getSink()->diagnose(expr, Diagnostics::bitOperationNonIntegral);
+ }
+#endif
+ }
+
+ // TODO(tfoley): Need to actual insert coercion here...
+ if(CanCoerce(leftType, expr->Type))
+ expr->Type = leftType;
+ else
+ expr->Type = ExpressionType::Error;
+ };
+#if 0
+ if (expr->Operator == Operator::Assign)
+ {
+ expr->Type = rightType;
+ checkAssign();
+ }
+ else
+#endif
+ {
+ expr->FunctionExpr = CheckExpr(expr->FunctionExpr);
+ CheckInvokeExprWithCheckedOperands(expr);
+ if (expr->Operator > Operator::Assign)
+ checkAssign();
+ }
+ return expr;
+#endif
+
+ // Treat operator application just like a function call
+ return VisitInvokeExpression(expr);
+ }
+ virtual RefPtr<ExpressionSyntaxNode> VisitConstantExpression(ConstantExpressionSyntaxNode *expr) override
+ {
+ switch (expr->ConstType)
+ {
+ case ConstantExpressionSyntaxNode::ConstantType::Int:
+ expr->Type = ExpressionType::GetInt();
+ break;
+ case ConstantExpressionSyntaxNode::ConstantType::Bool:
+ expr->Type = ExpressionType::GetBool();
+ break;
+ case ConstantExpressionSyntaxNode::ConstantType::Float:
+ expr->Type = ExpressionType::GetFloat();
+ break;
+ default:
+ expr->Type = ExpressionType::Error;
+ throw "Invalid constant type.";
+ break;
+ }
+ return expr;
+ }
+
+ IntVal* GetIntVal(ConstantExpressionSyntaxNode* expr)
+ {
+ // TODO(tfoley): don't keep allocating here!
+ return new ConstantIntVal(expr->IntValue);
+ }
+
+ RefPtr<IntVal> TryConstantFoldExpr(
+ InvokeExpressionSyntaxNode* invokeExpr)
+ {
+ // We need all the operands to the expression
+
+ // Check if the callee is an operation that is amenable to constant-folding.
+ //
+ // For right now we will look for calls to intrinsic functions, and then inspect
+ // their names (this is bad and slow).
+ auto funcDeclRefExpr = invokeExpr->FunctionExpr.As<DeclRefExpr>();
+ if (!funcDeclRefExpr) return nullptr;
+
+ auto funcDeclRef = funcDeclRefExpr->declRef;
+ auto intrinsicMod = funcDeclRef.GetDecl()->FindModifier<IntrinsicModifier>();
+ if (!intrinsicMod) return nullptr;
+
+ // Let's not constant-fold operations with more than a certain number of arguments, for simplicity
+ static const int kMaxArgs = 8;
+ if (invokeExpr->Arguments.Count() > kMaxArgs)
+ return nullptr;
+
+ // Before checking the operation name, let's look at the arguments
+ RefPtr<IntVal> argVals[kMaxArgs];
+ int constArgVals[kMaxArgs];
+ int argCount = 0;
+ bool allConst = true;
+ for (auto argExpr : invokeExpr->Arguments)
+ {
+ auto argVal = TryCheckIntegerConstantExpression(argExpr.Ptr());
+ if (!argVal)
+ return nullptr;
+
+ argVals[argCount] = argVal;
+
+ if (auto constArgVal = argVal.As<ConstantIntVal>())
+ {
+ constArgVals[argCount] = constArgVal->value;
+ }
+ else
+ {
+ allConst = false;
+ }
+ argCount++;
+ }
+
+ if (!allConst)
+ {
+ // TODO(tfoley): We probably want to support a very limited number of operations
+ // on "constants" that aren't actually known, to be able to handle a generic
+ // that takes an integer `N` but then constructs a vector of size `N+1`.
+ //
+ // The hard part there is implementing the rules for value unification in the
+ // presence of more complicated `IntVal` subclasses, like `SumIntVal`. You'd
+ // need inference to be smart enough to know that `2 + N` and `N + 2` are the
+ // same value, as are `N + M + 1 + 1` and `M + 2 + N`.
+ //
+ // For now we can just bail in this case.
+ return nullptr;
+ }
+
+ // At this point, all the operands had simple integer values, so we are golden.
+ int resultValue = 0;
+ auto opName = funcDeclRef.GetName();
+
+ // handle binary operators
+ if (opName == "-")
+ {
+ if (argCount == 1)
+ {
+ resultValue = -constArgVals[0];
+ }
+ else if (argCount == 2)
+ {
+ resultValue = constArgVals[0] - constArgVals[1];
+ }
+ }
+
+ // simple binary operators
+#define CASE(OP) \
+ else if(opName == #OP) do { \
+ if(argCount != 2) return nullptr; \
+ resultValue = constArgVals[0] OP constArgVals[1]; \
+ } while(0)
+
+ CASE(+); // TODO: this can also be unary...
+ CASE(*);
+#undef CASE
+
+ // binary operators with chance of divide-by-zero
+ // TODO: issue a suitable error in that case
+#define CASE(OP) \
+ else if(opName == #OP) do { \
+ if(argCount != 2) return nullptr; \
+ if(!constArgVals[1]) return nullptr; \
+ resultValue = constArgVals[0] OP constArgVals[1]; \
+ } while(0)
+
+ CASE(/);
+ CASE(%);
+#undef CASE
+
+ // TODO(tfoley): more cases
+ else
+ {
+ return nullptr;
+ }
+
+ RefPtr<IntVal> result = new ConstantIntVal(resultValue);
+ return result;
+ }
+
+ RefPtr<IntVal> TryConstantFoldExpr(
+ ExpressionSyntaxNode* expr)
+ {
+ // TODO(tfoley): more serious constant folding here
+ if (auto constExp = dynamic_cast<ConstantExpressionSyntaxNode*>(expr))
+ {
+ return GetIntVal(constExp);
+ }
+
+ // it is possible that we are referring to a generic value param
+ if (auto declRefExpr = dynamic_cast<DeclRefExpr*>(expr))
+ {
+ auto declRef = declRefExpr->declRef;
+
+ if (auto genericValParamRef = declRef.As<GenericValueParamDeclRef>())
+ {
+ // TODO(tfoley): handle the case of non-`int` value parameters...
+ return new GenericParamIntVal(genericValParamRef);
+ }
+
+ // We may also need to check for references to variables that
+ // are defined in a way that can be used as a constant expression:
+ if(auto varRef = declRef.As<VarDeclBaseRef>())
+ {
+ auto varDecl = varRef.GetDecl();
+
+ switch(sourceLanguage)
+ {
+ case SourceLanguage::Slang:
+ case SourceLanguage::HLSL:
+ // HLSL: `static const` is used to mark compile-time constant expressions
+ if(auto staticAttr = varDecl->FindModifier<HLSLStaticModifier>())
+ {
+ if(auto constAttr = varDecl->FindModifier<ConstModifier>())
+ {
+ // HLSL `static const` can be used as a constant expression
+ if(auto initExpr = varRef.getInitExpr())
+ {
+ return TryConstantFoldExpr(initExpr.Ptr());
+ }
+ }
+ }
+ break;
+
+ case SourceLanguage::GLSL:
+ // GLSL: `const` indicates compile-time constant expression
+ //
+ // TODO(tfoley): The current logic here isn't robust against
+ // GLSL "specialization constants" - we will extract the
+ // initializer for a `const` variable and use it to extract
+ // a value, when we really should be using an opaque
+ // reference to the variable.
+ if(auto constAttr = varDecl->FindModifier<ConstModifier>())
+ {
+ // We need to handle a "specialization constant" (with a `constant_id` layout modifier)
+ // differently from an ordinary compile-time constant. The latter can/should be reduced
+ // to a value, while the former should be kept as a symbolic reference
+
+ if(auto constantIDModifier = varDecl->FindModifier<GLSLConstantIDLayoutModifier>())
+ {
+ // Retain the specialization constant as a symbolic reference
+ //
+ // TODO(tfoley): handle the case of non-`int` value parameters...
+ //
+ // TODO(tfoley): this is cloned from the case above that handles generic value parameters
+ return new GenericParamIntVal(varRef);
+ }
+ else if(auto initExpr = varRef.getInitExpr())
+ {
+ // This is an ordinary constant, and not a specialization constant, so we
+ // can try to fold its value right now.
+ return TryConstantFoldExpr(initExpr.Ptr());
+ }
+ }
+ break;
+ }
+
+ }
+ }
+
+ if (auto invokeExpr = dynamic_cast<InvokeExpressionSyntaxNode*>(expr))
+ {
+ auto val = TryConstantFoldExpr(invokeExpr);
+ if (val)
+ return val;
+ }
+ else if(auto castExpr = dynamic_cast<TypeCastExpressionSyntaxNode*>(expr))
+ {
+ auto val = TryConstantFoldExpr(castExpr->Expression.Ptr());
+ if(val)
+ return val;
+ }
+
+ return nullptr;
+ }
+
+ // Try to check an integer constant expression, either returning the value,
+ // or NULL if the expression isn't recognized as a constant.
+ RefPtr<IntVal> TryCheckIntegerConstantExpression(ExpressionSyntaxNode* exp)
+ {
+ if (!exp->Type.type->Equals(ExpressionType::GetInt()))
+ {
+ return nullptr;
+ }
+
+
+
+ // Otherwise, we need to consider operations that we might be able to constant-fold...
+ return TryConstantFoldExpr(exp);
+ }
+
+ // Enforce that an expression resolves to an integer constant, and get its value
+ RefPtr<IntVal> CheckIntegerConstantExpression(ExpressionSyntaxNode* inExpr)
+ {
+ // First coerce the expression to the expected type
+ auto expr = Coerce(ExpressionType::GetInt(),inExpr);
+ auto result = TryCheckIntegerConstantExpression(expr.Ptr());
+ if (!result)
+ {
+ getSink()->diagnose(expr, Diagnostics::expectedIntegerConstantNotConstant);
+ }
+ return result;
+ }
+
+
+
+ RefPtr<ExpressionSyntaxNode> CheckSimpleSubscriptExpr(
+ RefPtr<IndexExpressionSyntaxNode> subscriptExpr,
+ RefPtr<ExpressionType> elementType)
+ {
+ auto baseExpr = subscriptExpr->BaseExpression;
+ auto indexExpr = subscriptExpr->IndexExpression;
+
+ if (!indexExpr->Type->Equals(ExpressionType::GetInt()) &&
+ !indexExpr->Type->Equals(ExpressionType::GetUInt()))
+ {
+ getSink()->diagnose(indexExpr, Diagnostics::subscriptIndexNonInteger);
+ return CreateErrorExpr(subscriptExpr.Ptr());
+ }
+
+ subscriptExpr->Type = elementType;
+
+ // TODO(tfoley): need to be more careful about this stuff
+ subscriptExpr->Type.IsLeftValue = baseExpr->Type.IsLeftValue;
+
+ return subscriptExpr;
+ }
+
+ // The way that we have designed out type system, pretyt much *every*
+ // type is a reference to some declaration in the standard library.
+ // That means that when we construct a new type on the fly, we need
+ // to make sure that it is wired up to reference the appropriate
+ // declaration, or else it won't compare as equal to other types
+ // that *do* reference the declaration.
+ //
+ // This function is used to construct a `vector<T,N>` type
+ // programmatically, so that it will work just like a type of
+ // that form constructed by the user.
+ RefPtr<VectorExpressionType> createVectorType(
+ RefPtr<ExpressionType> elementType,
+ RefPtr<IntVal> elementCount)
+ {
+ auto vectorGenericDecl = findMagicDecl("Vector").As<GenericDecl>();
+ auto vectorTypeDecl = vectorGenericDecl->inner;
+
+ auto substitutions = new Substitutions();
+ substitutions->genericDecl = vectorGenericDecl.Ptr();
+ substitutions->args.Add(elementType);
+ substitutions->args.Add(elementCount);
+
+ auto declRef = DeclRef(vectorTypeDecl.Ptr(), substitutions);
+
+ return DeclRefType::Create(declRef)->As<VectorExpressionType>();
+ }
+
+ virtual RefPtr<ExpressionSyntaxNode> VisitIndexExpression(IndexExpressionSyntaxNode* subscriptExpr) override
+ {
+ auto baseExpr = subscriptExpr->BaseExpression;
+ baseExpr = CheckExpr(baseExpr);
+
+ RefPtr<ExpressionSyntaxNode> indexExpr = subscriptExpr->IndexExpression;
+ if (indexExpr)
+ {
+ indexExpr = CheckExpr(indexExpr);
+ }
+
+ subscriptExpr->BaseExpression = baseExpr;
+ subscriptExpr->IndexExpression = indexExpr;
+
+ // If anything went wrong in the base expression,
+ // then just move along...
+ if (IsErrorExpr(baseExpr))
+ return CreateErrorExpr(subscriptExpr);
+
+ // Otherwise, we need to look at the type of the base expression,
+ // to figure out how subscripting should work.
+ auto baseType = baseExpr->Type.Ptr();
+ if (auto baseTypeType = baseType->As<TypeType>())
+ {
+ // We are trying to "index" into a type, so we have an expression like `float[2]`
+ // which should be interpreted as resolving to an array type.
+
+ RefPtr<IntVal> elementCount = nullptr;
+ if (indexExpr)
+ {
+ elementCount = CheckIntegerConstantExpression(indexExpr.Ptr());
+ }
+
+ auto elementType = CoerceToUsableType(TypeExp(baseExpr, baseTypeType->type));
+ auto arrayType = new ArrayExpressionType();
+ arrayType->BaseType = elementType;
+ arrayType->ArrayLength = elementCount;
+
+ typeResult = arrayType;
+ subscriptExpr->Type = new TypeType(arrayType);
+ return subscriptExpr;
+ }
+ else if (auto baseArrayType = baseType->As<ArrayExpressionType>())
+ {
+ return CheckSimpleSubscriptExpr(
+ subscriptExpr,
+ baseArrayType->BaseType);
+ }
+ else if (auto vecType = baseType->As<VectorExpressionType>())
+ {
+ return CheckSimpleSubscriptExpr(
+ subscriptExpr,
+ vecType->elementType);
+ }
+ else if (auto matType = baseType->As<MatrixExpressionType>())
+ {
+ // TODO(tfoley): We shouldn't go and recompute
+ // row types over and over like this... :(
+ auto rowType = createVectorType(
+ matType->getElementType(),
+ matType->getColumnCount());
+
+ return CheckSimpleSubscriptExpr(
+ subscriptExpr,
+ rowType);
+ }
+
+ // Default behavior is to look at all available `__subscript`
+ // declarations on the type and try to call one of them.
+
+ if (auto declRefType = baseType->AsDeclRefType())
+ {
+ if (auto aggTypeDeclRef = declRefType->declRef.As<AggTypeDeclRef>())
+ {
+ // Checking of the type must be complete before we can reference its members safely
+ EnsureDecl(aggTypeDeclRef.GetDecl(), DeclCheckState::Checked);
+
+ // Note(tfoley): The name used for lookup here is a bit magical, since
+ // it must match what the parser installed in subscript declarations.
+ LookupResult lookupResult = LookUpLocal("operator[]", aggTypeDeclRef);
+ if (!lookupResult.isValid())
+ {
+ goto fail;
+ }
+
+ RefPtr<ExpressionSyntaxNode> subscriptFuncExpr = createLookupResultExpr(
+ lookupResult, subscriptExpr->BaseExpression, subscriptExpr);
+
+ // Now that we know there is at least one subscript member,
+ // we will construct a reference to it and try to call it
+
+ RefPtr<InvokeExpressionSyntaxNode> subscriptCallExpr = new InvokeExpressionSyntaxNode();
+ subscriptCallExpr->Position = subscriptExpr->Position;
+ subscriptCallExpr->FunctionExpr = subscriptFuncExpr;
+
+ // TODO(tfoley): This path can support multiple arguments easily
+ subscriptCallExpr->Arguments.Add(subscriptExpr->IndexExpression);
+
+ return CheckInvokeExprWithCheckedOperands(subscriptCallExpr.Ptr());
+ }
+ }
+
+ fail:
+ {
+ getSink()->diagnose(subscriptExpr, Diagnostics::subscriptNonArray, baseType);
+ return CreateErrorExpr(subscriptExpr);
+ }
+ }
+
+ bool MatchArguments(FunctionSyntaxNode * functionNode, List <RefPtr<ExpressionSyntaxNode>> &args)
+ {
+ if (functionNode->GetParameters().Count() != args.Count())
+ return false;
+ int i = 0;
+ for (auto param : functionNode->GetParameters())
+ {
+ if (!param->Type.Equals(args[i]->Type.Ptr()))
+ return false;
+ i++;
+ }
+ return true;
+ }
+
+ // Coerce an expression to a specific type that it is expected to have in context
+ RefPtr<ExpressionSyntaxNode> CoerceExprToType(
+ RefPtr<ExpressionSyntaxNode> expr,
+ RefPtr<ExpressionType> type)
+ {
+ // TODO(tfoley): clean this up so there is only one version...
+ return Coerce(type, expr);
+ }
+
+ // Resolve a call to a function, represented here
+ // by a symbol with a `FuncType` type.
+ RefPtr<ExpressionSyntaxNode> ResolveFunctionApp(
+ RefPtr<FuncType> funcType,
+ InvokeExpressionSyntaxNode* /*appExpr*/)
+ {
+ // TODO(tfoley): Actual checking logic needs to go here...
+#if 0
+ auto& args = appExpr->Arguments;
+ List<RefPtr<ParameterSyntaxNode>> params;
+ RefPtr<ExpressionType> resultType;
+ if (auto funcDeclRef = funcType->declRef)
+ {
+ EnsureDecl(funcDeclRef.GetDecl());
+
+ params = funcDeclRef->GetParameters().ToArray();
+ resultType = funcDecl->ReturnType;
+ }
+ else if (auto funcSym = funcType->Func)
+ {
+ auto funcDecl = funcSym->SyntaxNode;
+ EnsureDecl(funcDecl);
+
+ params = funcDecl->GetParameters().ToArray();
+ resultType = funcDecl->ReturnType;
+ }
+ else if (auto componentFuncSym = funcType->Component)
+ {
+ auto componentFuncDecl = componentFuncSym->Implementations.First()->SyntaxNode;
+ params = componentFuncDecl->GetParameters().ToArray();
+ resultType = componentFuncDecl->Type;
+ }
+
+ auto argCount = args.Count();
+ auto paramCount = params.Count();
+ if (argCount != paramCount)
+ {
+ getSink()->diagnose(appExpr, Diagnostics::unimplemented, "wrong number of arguments for call");
+ appExpr->Type = ExpressionType::Error;
+ return appExpr;
+ }
+
+ for (int ii = 0; ii < argCount; ++ii)
+ {
+ auto arg = args[ii];
+ auto param = params[ii];
+
+ arg = CoerceExprToType(arg, param->Type);
+
+ args[ii] = arg;
+ }
+
+ assert(resultType);
+ appExpr->Type = resultType;
+ return appExpr;
+#else
+ throw "unimplemented";
+#endif
+ }
+
+ // Resolve a constructor call, formed by apply a type to arguments
+ RefPtr<ExpressionSyntaxNode> ResolveConstructorApp(
+ RefPtr<ExpressionType> type,
+ InvokeExpressionSyntaxNode* appExpr)
+ {
+ // TODO(tfoley): Actual checking logic needs to go here...
+
+ appExpr->Type = type;
+ return appExpr;
+ }
+
+
+ //
+
+ virtual void VisitExtensionDecl(ExtensionDecl* decl) override
+ {
+ if (decl->IsChecked(DeclCheckState::Checked)) return;
+
+ decl->SetCheckState(DeclCheckState::CheckingHeader);
+ decl->targetType = CheckProperType(decl->targetType);
+
+ // TODO: need to check that the target type names a declaration...
+
+ if (auto targetDeclRefType = decl->targetType->As<DeclRefType>())
+ {
+ // Attach our extension to that type as a candidate...
+ if (auto aggTypeDeclRef = targetDeclRefType->declRef.As<AggTypeDeclRef>())
+ {
+ auto aggTypeDecl = aggTypeDeclRef.GetDecl();
+ decl->nextCandidateExtension = aggTypeDecl->candidateExtensions;
+ aggTypeDecl->candidateExtensions = decl;
+ }
+ else
+ {
+ getSink()->diagnose(decl->targetType.exp, Diagnostics::unimplemented, "expected a nominal type here");
+ }
+ }
+ else if (decl->targetType->Equals(ExpressionType::Error))
+ {
+ // there was an error, so ignore
+ }
+ else
+ {
+ getSink()->diagnose(decl->targetType.exp, Diagnostics::unimplemented, "expected a nominal type here");
+ }
+
+ decl->SetCheckState(DeclCheckState::CheckedHeader);
+
+ // now check the members of the extension
+ for (auto m : decl->Members)
+ {
+ EnsureDecl(m);
+ }
+
+ decl->SetCheckState(DeclCheckState::Checked);
+ }
+
+ virtual void VisitConstructorDecl(ConstructorDecl* decl) override
+ {
+ if (decl->IsChecked(DeclCheckState::Checked)) return;
+ decl->SetCheckState(DeclCheckState::CheckingHeader);
+
+ for (auto& paramDecl : decl->GetParameters())
+ {
+ paramDecl->Type = CheckUsableType(paramDecl->Type);
+ }
+ decl->SetCheckState(DeclCheckState::CheckedHeader);
+
+ // TODO(tfoley): check body
+ decl->SetCheckState(DeclCheckState::Checked);
+ }
+
+
+ virtual void visitSubscriptDecl(SubscriptDecl* decl) override
+ {
+ if (decl->IsChecked(DeclCheckState::Checked)) return;
+ decl->SetCheckState(DeclCheckState::CheckingHeader);
+
+ for (auto& paramDecl : decl->GetParameters())
+ {
+ paramDecl->Type = CheckUsableType(paramDecl->Type);
+ }
+
+ decl->ReturnType = CheckUsableType(decl->ReturnType);
+
+ decl->SetCheckState(DeclCheckState::CheckedHeader);
+
+ decl->SetCheckState(DeclCheckState::Checked);
+ }
+
+ virtual void visitAccessorDecl(AccessorDecl* decl) override
+ {
+ // TODO: check the body!
+
+ decl->SetCheckState(DeclCheckState::Checked);
+ }
+
+
+ //
+
+ struct Constraint
+ {
+ Decl* decl; // the declaration of the thing being constraints
+ RefPtr<Val> val; // the value to which we are constraining it
+ bool satisfied = false; // Has this constraint been met?
+ };
+
+ // A collection of constraints that will need to be satisified (solved)
+ // in order for checking to suceed.
+ struct ConstraintSystem
+ {
+ List<Constraint> constraints;
+ };
+
+ RefPtr<ExpressionType> TryJoinVectorAndScalarType(
+ RefPtr<VectorExpressionType> vectorType,
+ RefPtr<BasicExpressionType> scalarType)
+ {
+ // Join( vector<T,N>, S ) -> vetor<Join(T,S), N>
+ //
+ // That is, the join of a vector and a scalar type is
+ // a vector type with a joined element type.
+ auto joinElementType = TryJoinTypes(
+ vectorType->elementType,
+ scalarType);
+ if(!joinElementType)
+ return nullptr;
+
+ return createVectorType(
+ joinElementType,
+ vectorType->elementCount);
+ }
+
+ bool DoesTypeConformToTrait(
+ RefPtr<ExpressionType> type,
+ TraitDeclRef traitDeclRef)
+ {
+ // for now look up a conformance member...
+ if(auto declRefType = type->As<DeclRefType>())
+ {
+ if( auto aggTypeDeclRef = declRefType->declRef.As<AggTypeDeclRef>() )
+ {
+ for( auto conformanceRef : aggTypeDeclRef.GetMembersOfType<TraitConformanceDeclRef>())
+ {
+ EnsureDecl(conformanceRef.GetDecl());
+
+ if(traitDeclRef.Equals(conformanceRef.GetTraitDeclRef()))
+ return true;
+ }
+ }
+ }
+
+ // default is failure
+ return false;
+ }
+
+ RefPtr<ExpressionType> TryJoinTypeWithTrait(
+ RefPtr<ExpressionType> type,
+ TraitDeclRef traitDeclRef)
+ {
+ // The most basic test here should be: does the type declare conformance to the trait.
+ if(DoesTypeConformToTrait(type, traitDeclRef))
+ return type;
+
+ // There is a more nuanced case if `type` is a builtin type, and we need to make it
+ // conform to a trait that some but not all builtin types support (the main problem
+ // here is when an operation wants an integer type, but one of our operands is a `float`.
+ // The HLSL rules will allow that, with implicit conversion, but our default join rules
+ // will end up picking `float` and we don't want that...).
+
+ // For now we don't handle the hard case and just bail
+ return nullptr;
+ }
+
+ // Try to compute the "join" between two types
+ RefPtr<ExpressionType> TryJoinTypes(
+ RefPtr<ExpressionType> left,
+ RefPtr<ExpressionType> right)
+ {
+ // Easy case: they are the same type!
+ if (left->Equals(right))
+ return left;
+
+ // We can join two basic types by picking the "better" of the two
+ if (auto leftBasic = left->As<BasicExpressionType>())
+ {
+ if (auto rightBasic = right->As<BasicExpressionType>())
+ {
+ auto leftFlavor = leftBasic->BaseType;
+ auto rightFlavor = rightBasic->BaseType;
+
+ // TODO(tfoley): Need a special-case rule here that if
+ // either operand is of type `half`, then we promote
+ // to at least `float`
+
+ // Return the one that had higher rank...
+ if (leftFlavor > rightFlavor)
+ return left;
+ else
+ {
+ assert(rightFlavor > leftFlavor);
+ return right;
+ }
+ }
+
+ // We can also join a vector and a scalar
+ if(auto rightVector = right->As<VectorExpressionType>())
+ {
+ return TryJoinVectorAndScalarType(rightVector, leftBasic);
+ }
+ }
+
+ // We can join two vector types by joining their element types
+ // (and also their sizes...)
+ if( auto leftVector = left->As<VectorExpressionType>())
+ {
+ if(auto rightVector = right->As<VectorExpressionType>())
+ {
+ // Check if the vector sizes match
+ if(!leftVector->elementCount->EqualsVal(rightVector->elementCount.Ptr()))
+ return nullptr;
+
+ // Try to join the element types
+ auto joinElementType = TryJoinTypes(
+ leftVector->elementType,
+ rightVector->elementType);
+ if(!joinElementType)
+ return nullptr;
+
+ return createVectorType(
+ joinElementType,
+ leftVector->elementCount);
+ }
+
+ // We can also join a vector and a scalar
+ if(auto rightBasic = right->As<BasicExpressionType>())
+ {
+ return TryJoinVectorAndScalarType(leftVector, rightBasic);
+ }
+ }
+
+ // HACK: trying to work trait types in here...
+ if(auto leftDeclRefType = left->As<DeclRefType>())
+ {
+ if( auto leftTraitRef = leftDeclRefType->declRef.As<TraitDeclRef>() )
+ {
+ //
+ return TryJoinTypeWithTrait(right, leftTraitRef);
+ }
+ }
+ if(auto rightDeclRefType = right->As<DeclRefType>())
+ {
+ if( auto rightTraitRef = rightDeclRefType->declRef.As<TraitDeclRef>() )
+ {
+ //
+ return TryJoinTypeWithTrait(left, rightTraitRef);
+ }
+ }
+
+ // TODO: all the cases for vectors apply to matrices too!
+
+ // Default case is that we just fail.
+ return nullptr;
+ }
+
+ // Try to solve a system of generic constraints.
+ // The `system` argument provides the constraints.
+ // The `varSubst` argument provides the list of constraint
+ // variables that were created for the system.
+ //
+ // Returns a new substitution representing the values that
+ // we solved for along the way.
+ RefPtr<Substitutions> TrySolveConstraintSystem(
+ ConstraintSystem* system,
+ GenericDeclRef genericDeclRef)
+ {
+ // For now the "solver" is going to be ridiculously simplistic.
+
+ // The generic itself will have some constraints, so we need to try and solve those too
+ for( auto constraintDeclRef : genericDeclRef.GetMembersOfType<GenericTypeConstraintDeclRef>() )
+ {
+ if(!TryUnifyTypes(*system, constraintDeclRef.GetSub(), constraintDeclRef.GetSup()))
+ return nullptr;
+ }
+
+ // We will loop over the generic parameters, and for
+ // each we will try to find a way to satisfy all
+ // the constraints for that parameter
+ List<RefPtr<Val>> args;
+ for (auto m : genericDeclRef.GetMembers())
+ {
+ if (auto typeParam = m.As<GenericTypeParamDeclRef>())
+ {
+ RefPtr<ExpressionType> type = nullptr;
+ for (auto& c : system->constraints)
+ {
+ if (c.decl != typeParam.GetDecl())
+ continue;
+
+ auto cType = c.val.As<ExpressionType>();
+ assert(cType.Ptr());
+
+ if (!type)
+ {
+ type = cType;
+ }
+ else
+ {
+ auto joinType = TryJoinTypes(type, cType);
+ if (!joinType)
+ {
+ // failure!
+ return nullptr;
+ }
+ type = joinType;
+ }
+
+ c.satisfied = true;
+ }
+
+ if (!type)
+ {
+ // failure!
+ return nullptr;
+ }
+ args.Add(type);
+ }
+ else if (auto valParam = m.As<GenericValueParamDeclRef>())
+ {
+ // TODO(tfoley): maybe support more than integers some day?
+ // TODO(tfoley): figure out how this needs to interact with
+ // compile-time integers that aren't just constants...
+ RefPtr<IntVal> val = nullptr;
+ for (auto& c : system->constraints)
+ {
+ if (c.decl != valParam.GetDecl())
+ continue;
+
+ auto cVal = c.val.As<IntVal>();
+ assert(cVal.Ptr());
+
+ if (!val)
+ {
+ val = cVal;
+ }
+ else
+ {
+ if(!val->EqualsVal(cVal.Ptr()))
+ {
+ // failure!
+ return nullptr;
+ }
+ }
+
+ c.satisfied = true;
+ }
+
+ if (!val)
+ {
+ // failure!
+ return nullptr;
+ }
+ args.Add(val);
+ }
+ else
+ {
+ // ignore anything that isn't a generic parameter
+ }
+ }
+
+ // Make sure we haven't constructed any spurious constraints
+ // that we aren't able to satisfy:
+ for (auto c : system->constraints)
+ {
+ if (!c.satisfied)
+ {
+ return nullptr;
+ }
+ }
+
+ // Consruct a reference to the extension with our constraint variables
+ // as the
+ RefPtr<Substitutions> solvedSubst = new Substitutions();
+ solvedSubst->genericDecl = genericDeclRef.GetDecl();
+ solvedSubst->outer = genericDeclRef.substitutions;
+ solvedSubst->args = args;
+
+ return solvedSubst;
+
+
+#if 0
+ List<RefPtr<Val>> solvedArgs;
+ for (auto varArg : varSubst->args)
+ {
+ if (auto typeVar = dynamic_cast<ConstraintVarType*>(varArg.Ptr()))
+ {
+ RefPtr<ExpressionType> type = nullptr;
+ for (auto& c : system->constraints)
+ {
+ if (c.decl != typeVar->declRef.GetDecl())
+ continue;
+
+ auto cType = c.val.As<ExpressionType>();
+ assert(cType.Ptr());
+
+ if (!type)
+ {
+ type = cType;
+ }
+ else
+ {
+ if (!type->Equals(cType))
+ {
+ // failure!
+ return nullptr;
+ }
+ }
+
+ c.satisfied = true;
+ }
+
+ if (!type)
+ {
+ // failure!
+ return nullptr;
+ }
+ solvedArgs.Add(type);
+ }
+ else if (auto valueVar = dynamic_cast<ConstraintVarInt*>(varArg.Ptr()))
+ {
+ // TODO(tfoley): maybe support more than integers some day?
+ RefPtr<IntVal> val = nullptr;
+ for (auto& c : system->constraints)
+ {
+ if (c.decl != valueVar->declRef.GetDecl())
+ continue;
+
+ auto cVal = c.val.As<IntVal>();
+ assert(cVal.Ptr());
+
+ if (!val)
+ {
+ val = cVal;
+ }
+ else
+ {
+ if (val->value != cVal->value)
+ {
+ // failure!
+ return nullptr;
+ }
+ }
+
+ c.satisfied = true;
+ }
+
+ if (!val)
+ {
+ // failure!
+ return nullptr;
+ }
+ solvedArgs.Add(val);
+ }
+ else
+ {
+ // ignore anything that isn't a generic parameter
+ }
+ }
+
+ // Make sure we haven't constructed any spurious constraints
+ // that we aren't able to satisfy:
+ for (auto c : system->constraints)
+ {
+ if (!c.satisfied)
+ {
+ return nullptr;
+ }
+ }
+
+ RefPtr<Substitutions> newSubst = new Substitutions();
+ newSubst->genericDecl = varSubst->genericDecl;
+ newSubst->outer = varSubst->outer;
+ newSubst->args = solvedArgs;
+ return newSubst;
+
+#endif
+ }
+
+ //
+
+ struct OverloadCandidate
+ {
+ enum class Flavor
+ {
+ Func,
+ Generic,
+ UnspecializedGeneric,
+ };
+ Flavor flavor;
+
+ enum class Status
+ {
+ GenericArgumentInferenceFailed,
+ Unchecked,
+ ArityChecked,
+ FixityChecked,
+ TypeChecked,
+ Appicable,
+ };
+ Status status = Status::Unchecked;
+
+ // Reference to the declaration being applied
+ LookupResultItem item;
+
+ // The type of the result expression if this candidate is selected
+ RefPtr<ExpressionType> resultType;
+
+ // A system for tracking constraints introduced on generic parameters
+ ConstraintSystem constraintSystem;
+
+ // How much conversion cost should be considered for this overload,
+ // when ranking candidates.
+ ConversionCost conversionCostSum = kConversionCost_None;
+ };
+
+
+
+ // State related to overload resolution for a call
+ // to an overloaded symbol
+ struct OverloadResolveContext
+ {
+ enum class Mode
+ {
+ // We are just checking if a candidate works or not
+ JustTrying,
+
+ // We want to actually update the AST for a chosen candidate
+ ForReal,
+ };
+
+ RefPtr<AppExprBase> appExpr;
+ RefPtr<ExpressionSyntaxNode> baseExpr;
+
+ // Are we still trying out candidates, or are we
+ // checking the chosen one for real?
+ Mode mode = Mode::JustTrying;
+
+ // We store one candidate directly, so that we don't
+ // need to do dynamic allocation on the list every time
+ OverloadCandidate bestCandidateStorage;
+ OverloadCandidate* bestCandidate = nullptr;
+
+ // Full list of all candidates being considered, in the ambiguous case
+ List<OverloadCandidate> bestCandidates;
+ };
+
+ struct ParamCounts
+ {
+ int required;
+ int allowed;
+ };
+
+ // count the number of parameters required/allowed for a callable
+ ParamCounts CountParameters(FilteredMemberRefList<ParamDeclRef> params)
+ {
+ ParamCounts counts = { 0, 0 };
+ for (auto param : params)
+ {
+ counts.allowed++;
+
+ // No initializer means no default value
+ //
+ // TODO(tfoley): The logic here is currently broken in two ways:
+ //
+ // 1. We are assuming that once one parameter has a default, then all do.
+ // This can/should be validated earlier, so that we can assume it here.
+ //
+ // 2. We are not handling the possibility of multiple declarations for
+ // a single function, where we'd need to merge default parameters across
+ // all the declarations.
+ if (!param.GetDecl()->Expr)
+ {
+ counts.required++;
+ }
+ }
+ return counts;
+ }
+
+ // count the number of parameters required/allowed for a generic
+ ParamCounts CountParameters(GenericDeclRef genericRef)
+ {
+ ParamCounts counts = { 0, 0 };
+ for (auto m : genericRef.GetDecl()->Members)
+ {
+ if (auto typeParam = m.As<GenericTypeParamDecl>())
+ {
+ counts.allowed++;
+ if (!typeParam->initType.Ptr())
+ {
+ counts.required++;
+ }
+ }
+ else if (auto valParam = m.As<GenericValueParamDecl>())
+ {
+ counts.allowed++;
+ if (!valParam->Expr)
+ {
+ counts.required++;
+ }
+ }
+ }
+ return counts;
+ }
+
+ bool TryCheckOverloadCandidateArity(
+ OverloadResolveContext& context,
+ OverloadCandidate const& candidate)
+ {
+ int argCount = context.appExpr->Arguments.Count();
+ ParamCounts paramCounts = { 0, 0 };
+ switch (candidate.flavor)
+ {
+ case OverloadCandidate::Flavor::Func:
+ paramCounts = CountParameters(candidate.item.declRef.As<CallableDeclRef>().GetParameters());
+ break;
+
+ case OverloadCandidate::Flavor::Generic:
+ paramCounts = CountParameters(candidate.item.declRef.As<GenericDeclRef>());
+ break;
+
+ default:
+ assert(!"unexpected");
+ break;
+ }
+
+ if (argCount >= paramCounts.required && argCount <= paramCounts.allowed)
+ return true;
+
+ // Emit an error message if we are checking this call for real
+ if (context.mode != OverloadResolveContext::Mode::JustTrying)
+ {
+ if (argCount < paramCounts.required)
+ {
+ getSink()->diagnose(context.appExpr, Diagnostics::notEnoughArguments, argCount, paramCounts.required);
+ }
+ else
+ {
+ assert(argCount > paramCounts.allowed);
+ getSink()->diagnose(context.appExpr, Diagnostics::tooManyArguments, argCount, paramCounts.allowed);
+ }
+ }
+
+ return false;
+ }
+
+ bool TryCheckOverloadCandidateFixity(
+ OverloadResolveContext& context,
+ OverloadCandidate const& candidate)
+ {
+ auto expr = context.appExpr;
+
+ auto decl = candidate.item.declRef.decl;
+
+ if(auto prefixExpr = expr.As<PrefixExpr>())
+ {
+ if(decl->HasModifier<PrefixModifier>())
+ return true;
+
+ if (context.mode != OverloadResolveContext::Mode::JustTrying)
+ {
+ getSink()->diagnose(context.appExpr, Diagnostics::expectedPrefixOperator);
+ getSink()->diagnose(decl, Diagnostics::seeDefinitionOf, decl->getName());
+ }
+
+ return false;
+ }
+ else if(auto postfixExpr = expr.As<PostfixExpr>())
+ {
+ if(decl->HasModifier<PostfixModifier>())
+ return true;
+
+ if (context.mode != OverloadResolveContext::Mode::JustTrying)
+ {
+ getSink()->diagnose(context.appExpr, Diagnostics::expectedPostfixOperator);
+ getSink()->diagnose(decl, Diagnostics::seeDefinitionOf, decl->getName());
+ }
+
+ return false;
+ }
+ else
+ {
+ return true;
+ }
+
+ return false;
+ }
+
+ bool TryCheckGenericOverloadCandidateTypes(
+ OverloadResolveContext& context,
+ OverloadCandidate& candidate)
+ {
+ auto& args = context.appExpr->Arguments;
+
+ auto genericDeclRef = candidate.item.declRef.As<GenericDeclRef>();
+
+ int aa = 0;
+ for (auto memberRef : genericDeclRef.GetMembers())
+ {
+ if (auto typeParamRef = memberRef.As<GenericTypeParamDeclRef>())
+ {
+ auto arg = args[aa++];
+
+ if (context.mode == OverloadResolveContext::Mode::JustTrying)
+ {
+ if (!CanCoerceToProperType(TypeExp(arg)))
+ {
+ return false;
+ }
+ }
+ else
+ {
+ TypeExp typeExp = CoerceToProperType(TypeExp(arg));
+ }
+ }
+ else if (auto valParamRef = memberRef.As<GenericValueParamDeclRef>())
+ {
+ auto arg = args[aa++];
+
+ if (context.mode == OverloadResolveContext::Mode::JustTrying)
+ {
+ ConversionCost cost = kConversionCost_None;
+ if (!CanCoerce(valParamRef.GetType(), arg->Type, &cost))
+ {
+ return false;
+ }
+ candidate.conversionCostSum += cost;
+ }
+ else
+ {
+ arg = Coerce(valParamRef.GetType(), arg);
+ auto val = ExtractGenericArgInteger(arg);
+ }
+ }
+ else
+ {
+ continue;
+ }
+ }
+
+ return true;
+ }
+
+ bool TryCheckOverloadCandidateTypes(
+ OverloadResolveContext& context,
+ OverloadCandidate& candidate)
+ {
+ auto& args = context.appExpr->Arguments;
+ int argCount = args.Count();
+
+ List<ParamDeclRef> params;
+ switch (candidate.flavor)
+ {
+ case OverloadCandidate::Flavor::Func:
+ params = candidate.item.declRef.As<CallableDeclRef>().GetParameters().ToArray();
+ break;
+
+ case OverloadCandidate::Flavor::Generic:
+ return TryCheckGenericOverloadCandidateTypes(context, candidate);
+
+ default:
+ assert(!"unexpected");
+ break;
+ }
+
+ // Note(tfoley): We might have fewer arguments than parameters in the
+ // case where one or more parameters had defaults.
+ assert(argCount <= params.Count());
+
+ for (int ii = 0; ii < argCount; ++ii)
+ {
+ auto& arg = args[ii];
+ auto param = params[ii];
+
+ if (context.mode == OverloadResolveContext::Mode::JustTrying)
+ {
+ ConversionCost cost = kConversionCost_None;
+ if (!CanCoerce(param.GetType(), arg->Type, &cost))
+ {
+ return false;
+ }
+ candidate.conversionCostSum += cost;
+ }
+ else
+ {
+ arg = Coerce(param.GetType(), arg);
+ }
+ }
+ return true;
+ }
+
+ bool TryCheckOverloadCandidateDirections(
+ OverloadResolveContext& /*context*/,
+ OverloadCandidate const& /*candidate*/)
+ {
+ // TODO(tfoley): check `in` and `out` markers, as needed.
+ return true;
+ }
+
+ // Try to check an overload candidate, but bail out
+ // if any step fails
+ void TryCheckOverloadCandidate(
+ OverloadResolveContext& context,
+ OverloadCandidate& candidate)
+ {
+ if (!TryCheckOverloadCandidateArity(context, candidate))
+ return;
+
+ candidate.status = OverloadCandidate::Status::ArityChecked;
+ if (!TryCheckOverloadCandidateFixity(context, candidate))
+ return;
+
+ candidate.status = OverloadCandidate::Status::FixityChecked;
+ if (!TryCheckOverloadCandidateTypes(context, candidate))
+ return;
+
+ candidate.status = OverloadCandidate::Status::TypeChecked;
+ if (!TryCheckOverloadCandidateDirections(context, candidate))
+ return;
+
+ candidate.status = OverloadCandidate::Status::Appicable;
+ }
+
+ // Create the representation of a given generic applied to some arguments
+ RefPtr<ExpressionSyntaxNode> CreateGenericDeclRef(
+ RefPtr<ExpressionSyntaxNode> baseExpr,
+ RefPtr<AppExprBase> appExpr)
+ {
+ auto baseDeclRefExpr = baseExpr.As<DeclRefExpr>();
+ if (!baseDeclRefExpr)
+ {
+ assert(!"unexpected");
+ return CreateErrorExpr(appExpr.Ptr());
+ }
+ auto baseGenericRef = baseDeclRefExpr->declRef.As<GenericDeclRef>();
+ if (!baseGenericRef)
+ {
+ assert(!"unexpected");
+ return CreateErrorExpr(appExpr.Ptr());
+ }
+
+ RefPtr<Substitutions> subst = new Substitutions();
+ subst->genericDecl = baseGenericRef.GetDecl();
+ subst->outer = baseGenericRef.substitutions;
+
+ for (auto arg : appExpr->Arguments)
+ {
+ subst->args.Add(ExtractGenericArgVal(arg));
+ }
+
+ DeclRef innerDeclRef(baseGenericRef.GetInner(), subst);
+
+ return ConstructDeclRefExpr(
+ innerDeclRef,
+ nullptr,
+ appExpr);
+ }
+
+ // Take an overload candidate that previously got through
+ // `TryCheckOverloadCandidate` above, and try to finish
+ // up the work and turn it into a real expression.
+ //
+ // If the candidate isn't actually applicable, this is
+ // where we'd start reporting the issue(s).
+ RefPtr<ExpressionSyntaxNode> CompleteOverloadCandidate(
+ OverloadResolveContext& context,
+ OverloadCandidate& candidate)
+ {
+ // special case for generic argument inference failure
+ if (candidate.status == OverloadCandidate::Status::GenericArgumentInferenceFailed)
+ {
+ String callString = GetCallSignatureString(context.appExpr);
+ getSink()->diagnose(
+ context.appExpr,
+ Diagnostics::genericArgumentInferenceFailed,
+ callString);
+
+ String declString = getDeclSignatureString(candidate.item);
+ getSink()->diagnose(candidate.item.declRef, Diagnostics::genericSignatureTried, declString);
+ goto error;
+ }
+
+ context.mode = OverloadResolveContext::Mode::ForReal;
+ context.appExpr->Type = ExpressionType::Error;
+
+ if (!TryCheckOverloadCandidateArity(context, candidate))
+ goto error;
+
+ if (!TryCheckOverloadCandidateFixity(context, candidate))
+ goto error;
+
+ if (!TryCheckOverloadCandidateTypes(context, candidate))
+ goto error;
+
+ if (!TryCheckOverloadCandidateDirections(context, candidate))
+ goto error;
+
+ {
+ auto baseExpr = ConstructLookupResultExpr(
+ candidate.item, context.baseExpr, context.appExpr->FunctionExpr);
+
+ switch(candidate.flavor)
+ {
+ case OverloadCandidate::Flavor::Func:
+ context.appExpr->FunctionExpr = baseExpr;
+ context.appExpr->Type = candidate.resultType;
+
+ // A call may yield an l-value, and we should take a look at the candidate to be sure
+ if(auto subscriptDeclRef = candidate.item.declRef.As<SubscriptDeclRef>())
+ {
+ for(auto setter : subscriptDeclRef.GetDecl()->GetMembersOfType<SetterDecl>())
+ {
+ context.appExpr->Type.IsLeftValue = true;
+ }
+ }
+
+ // TODO: there may be other cases that confer l-value-ness
+
+ return context.appExpr;
+ break;
+
+ case OverloadCandidate::Flavor::Generic:
+ return CreateGenericDeclRef(baseExpr, context.appExpr);
+ break;
+
+ default:
+ assert(!"unexpected");
+ break;
+ }
+ }
+
+
+ error:
+ return CreateErrorExpr(context.appExpr.Ptr());
+ }
+
+ // Implement a comparison operation between overload candidates,
+ // so that the better candidate compares as less-than the other
+ int CompareOverloadCandidates(
+ OverloadCandidate* left,
+ OverloadCandidate* right)
+ {
+ // If one candidate got further along in validation, pick it
+ if (left->status != right->status)
+ return int(right->status) - int(left->status);
+
+ // If both candidates are applicable, then we need to compare
+ // the costs of their type conversion sequences
+ if(left->status == OverloadCandidate::Status::Appicable)
+ {
+ if (left->conversionCostSum != right->conversionCostSum)
+ return left->conversionCostSum - right->conversionCostSum;
+ }
+
+ return 0;
+ }
+
+ void AddOverloadCandidateInner(
+ OverloadResolveContext& context,
+ OverloadCandidate& candidate)
+ {
+ // Filter our existing candidates, to remove any that are worse than our new one
+
+ bool keepThisCandidate = true; // should this candidate be kept?
+
+ if (context.bestCandidates.Count() != 0)
+ {
+ // We have multiple candidates right now, so filter them.
+ bool anyFiltered = false;
+ // Note that we are querying the list length on every iteration,
+ // because we might remove things.
+ for (int cc = 0; cc < context.bestCandidates.Count(); ++cc)
+ {
+ int cmp = CompareOverloadCandidates(&candidate, &context.bestCandidates[cc]);
+ if (cmp < 0)
+ {
+ // our new candidate is better!
+
+ // remove it from the list (by swapping in a later one)
+ context.bestCandidates.FastRemoveAt(cc);
+ // and then reduce our index so that we re-visit the same index
+ --cc;
+
+ anyFiltered = true;
+ }
+ else if(cmp > 0)
+ {
+ // our candidate is worse!
+ keepThisCandidate = false;
+ }
+ }
+ // It should not be possible that we removed some existing candidate *and*
+ // chose not to keep this candidate (otherwise the better-ness relation
+ // isn't transitive). Therefore we confirm that we either chose to keep
+ // this candidate (in which case filtering is okay), or we didn't filter
+ // anything.
+ assert(keepThisCandidate || !anyFiltered);
+ }
+ else if(context.bestCandidate)
+ {
+ // There's only one candidate so far
+ int cmp = CompareOverloadCandidates(&candidate, context.bestCandidate);
+ if(cmp < 0)
+ {
+ // our new candidate is better!
+ context.bestCandidate = nullptr;
+ }
+ else if (cmp > 0)
+ {
+ // our candidate is worse!
+ keepThisCandidate = false;
+ }
+ }
+
+ // If our candidate isn't good enough, then drop it
+ if (!keepThisCandidate)
+ return;
+
+ // Otherwise we want to keep the candidate
+ if (context.bestCandidates.Count() > 0)
+ {
+ // There were already multiple candidates, and we are adding one more
+ context.bestCandidates.Add(candidate);
+ }
+ else if (context.bestCandidate)
+ {
+ // There was a unique best candidate, but now we are ambiguous
+ context.bestCandidates.Add(*context.bestCandidate);
+ context.bestCandidates.Add(candidate);
+ context.bestCandidate = nullptr;
+ }
+ else
+ {
+ // This is the only candidate worthe keeping track of right now
+ context.bestCandidateStorage = candidate;
+ context.bestCandidate = &context.bestCandidateStorage;
+ }
+ }
+
+ void AddOverloadCandidate(
+ OverloadResolveContext& context,
+ OverloadCandidate& candidate)
+ {
+ // Try the candidate out, to see if it is applicable at all.
+ TryCheckOverloadCandidate(context, candidate);
+
+ // Now (potentially) add it to the set of candidate overloads to consider.
+ AddOverloadCandidateInner(context, candidate);
+ }
+
+ void AddFuncOverloadCandidate(
+ LookupResultItem item,
+ CallableDeclRef funcDeclRef,
+ OverloadResolveContext& context)
+ {
+ EnsureDecl(funcDeclRef.GetDecl());
+
+ OverloadCandidate candidate;
+ candidate.flavor = OverloadCandidate::Flavor::Func;
+ candidate.item = item;
+ candidate.resultType = funcDeclRef.GetResultType();
+
+ AddOverloadCandidate(context, candidate);
+ }
+
+ void AddFuncOverloadCandidate(
+ RefPtr<FuncType> /*funcType*/,
+ OverloadResolveContext& /*context*/)
+ {
+#if 0
+ if (funcType->decl)
+ {
+ AddFuncOverloadCandidate(funcType->decl, context);
+ }
+ else if (funcType->Func)
+ {
+ AddFuncOverloadCandidate(funcType->Func->SyntaxNode, context);
+ }
+ else if (funcType->Component)
+ {
+ AddComponentFuncOverloadCandidate(funcType->Component, context);
+ }
+#else
+ throw "unimplemented";
+#endif
+ }
+
+ void AddCtorOverloadCandidate(
+ LookupResultItem typeItem,
+ RefPtr<ExpressionType> type,
+ ConstructorDeclRef ctorDeclRef,
+ OverloadResolveContext& context)
+ {
+ EnsureDecl(ctorDeclRef.GetDecl());
+
+ // `typeItem` refers to the type being constructed (the thing
+ // that was applied as a function) so we need to construct
+ // a `LookupResultItem` that refers to the constructor instead
+
+ LookupResultItem ctorItem;
+ ctorItem.declRef = ctorDeclRef;
+ ctorItem.breadcrumbs = new LookupResultItem::Breadcrumb(LookupResultItem::Breadcrumb::Kind::Member, typeItem.declRef, typeItem.breadcrumbs);
+
+ OverloadCandidate candidate;
+ candidate.flavor = OverloadCandidate::Flavor::Func;
+ candidate.item = ctorItem;
+ candidate.resultType = type;
+
+ AddOverloadCandidate(context, candidate);
+ }
+
+ // If the given declaration has generic parameters, then
+ // return the corresponding `GenericDecl` that holds the
+ // parameters, etc.
+ GenericDecl* GetOuterGeneric(Decl* decl)
+ {
+ auto parentDecl = decl->ParentDecl;
+ if (!parentDecl) return nullptr;
+ auto parentGeneric = dynamic_cast<GenericDecl*>(parentDecl);
+ return parentGeneric;
+ }
+
+ // Try to find a unification for two values
+ bool TryUnifyVals(
+ ConstraintSystem& constraints,
+ RefPtr<Val> fst,
+ RefPtr<Val> snd)
+ {
+ // if both values are types, then unify types
+ if (auto fstType = fst.As<ExpressionType>())
+ {
+ if (auto sndType = snd.As<ExpressionType>())
+ {
+ return TryUnifyTypes(constraints, fstType, sndType);
+ }
+ }
+
+ // if both values are constant integers, then compare them
+ if (auto fstIntVal = fst.As<ConstantIntVal>())
+ {
+ if (auto sndIntVal = snd.As<ConstantIntVal>())
+ {
+ return fstIntVal->value == sndIntVal->value;
+ }
+ }
+
+ // Check if both are integer values in general
+ if (auto fstInt = fst.As<IntVal>())
+ {
+ if (auto sndInt = snd.As<IntVal>())
+ {
+ auto fstParam = fstInt.As<GenericParamIntVal>();
+ auto sndParam = sndInt.As<GenericParamIntVal>();
+
+ if (fstParam)
+ TryUnifyIntParam(constraints, fstParam->declRef, sndInt);
+ if (sndParam)
+ TryUnifyIntParam(constraints, sndParam->declRef, fstInt);
+
+ if (fstParam || sndParam)
+ return true;
+ }
+ }
+
+ throw "unimplemented";
+
+ // default: fail
+ return false;
+ }
+
+ bool TryUnifySubstitutions(
+ ConstraintSystem& constraints,
+ RefPtr<Substitutions> fst,
+ RefPtr<Substitutions> snd)
+ {
+ // They must both be NULL or non-NULL
+ if (!fst || !snd)
+ return fst == snd;
+
+ // They must be specializing the same generic
+ if (fst->genericDecl != snd->genericDecl)
+ return false;
+
+ // Their arguments must unify
+ assert(fst->args.Count() == snd->args.Count());
+ int argCount = fst->args.Count();
+ for (int aa = 0; aa < argCount; ++aa)
+ {
+ if (!TryUnifyVals(constraints, fst->args[aa], snd->args[aa]))
+ return false;
+ }
+
+ // Their "base" specializations must unify
+ if (!TryUnifySubstitutions(constraints, fst->outer, snd->outer))
+ return false;
+
+ return true;
+ }
+
+ bool TryUnifyTypeParam(
+ ConstraintSystem& constraints,
+ RefPtr<GenericTypeParamDecl> typeParamDecl,
+ RefPtr<ExpressionType> type)
+ {
+ // We want to constrain the given type parameter
+ // to equal the given type.
+ Constraint constraint;
+ constraint.decl = typeParamDecl.Ptr();
+ constraint.val = type;
+
+ constraints.constraints.Add(constraint);
+
+ return true;
+ }
+
+ bool TryUnifyIntParam(
+ ConstraintSystem& constraints,
+ RefPtr<GenericValueParamDecl> paramDecl,
+ RefPtr<IntVal> val)
+ {
+ // We want to constrain the given parameter to equal the given value.
+ Constraint constraint;
+ constraint.decl = paramDecl.Ptr();
+ constraint.val = val;
+
+ constraints.constraints.Add(constraint);
+
+ return true;
+ }
+
+ bool TryUnifyIntParam(
+ ConstraintSystem& constraints,
+ VarDeclBaseRef const& varRef,
+ RefPtr<IntVal> val)
+ {
+ if(auto genericValueParamRef = varRef.As<GenericValueParamDeclRef>())
+ {
+ return TryUnifyIntParam(constraints, genericValueParamRef.GetDecl(), val);
+ }
+ else
+ {
+ return false;
+ }
+ }
+
+ bool TryUnifyTypesByStructuralMatch(
+ ConstraintSystem& constraints,
+ RefPtr<ExpressionType> fst,
+ RefPtr<ExpressionType> snd)
+ {
+ if (auto fstDeclRefType = fst->As<DeclRefType>())
+ {
+ auto fstDeclRef = fstDeclRefType->declRef;
+
+ if (auto typeParamDecl = dynamic_cast<GenericTypeParamDecl*>(fstDeclRef.GetDecl()))
+ return TryUnifyTypeParam(constraints, typeParamDecl, snd);
+
+ if (auto sndDeclRefType = snd->As<DeclRefType>())
+ {
+ auto sndDeclRef = sndDeclRefType->declRef;
+
+ if (auto typeParamDecl = dynamic_cast<GenericTypeParamDecl*>(sndDeclRef.GetDecl()))
+ return TryUnifyTypeParam(constraints, typeParamDecl, fst);
+
+ // can't be unified if they refer to differnt declarations.
+ if (fstDeclRef.GetDecl() != sndDeclRef.GetDecl()) return false;
+
+ // next we need to unify the substitutions applied
+ // to each decalration reference.
+ if (!TryUnifySubstitutions(
+ constraints,
+ fstDeclRef.substitutions,
+ sndDeclRef.substitutions))
+ {
+ return false;
+ }
+
+ return true;
+ }
+ }
+
+ return false;
+ }
+
+ bool TryUnifyTypes(
+ ConstraintSystem& constraints,
+ RefPtr<ExpressionType> fst,
+ RefPtr<ExpressionType> snd)
+ {
+ if (fst->Equals(snd)) return true;
+
+ // An error type can unify with anything, just so we avoid cascading errors.
+
+ if (auto fstErrorType = fst->As<ErrorType>())
+ return true;
+
+ if (auto sndErrorType = snd->As<ErrorType>())
+ return true;
+
+ // A generic parameter type can unify with anything.
+ // TODO: there actually needs to be some kind of "occurs check" sort
+ // of thing here...
+
+ if (auto fstDeclRefType = fst->As<DeclRefType>())
+ {
+ auto fstDeclRef = fstDeclRefType->declRef;
+
+ if (auto typeParamDecl = dynamic_cast<GenericTypeParamDecl*>(fstDeclRef.GetDecl()))
+ return TryUnifyTypeParam(constraints, typeParamDecl, snd);
+ }
+
+ if (auto sndDeclRefType = snd->As<DeclRefType>())
+ {
+ auto sndDeclRef = sndDeclRefType->declRef;
+
+ if (auto typeParamDecl = dynamic_cast<GenericTypeParamDecl*>(sndDeclRef.GetDecl()))
+ return TryUnifyTypeParam(constraints, typeParamDecl, fst);
+ }
+
+ // If we can unify the types structurally, then we are golden
+ if(TryUnifyTypesByStructuralMatch(constraints, fst, snd))
+ return true;
+
+ // Now we need to consider cases where coercion might
+ // need to be applied. For now we can try to do this
+ // in a completely ad hoc fashion, but eventually we'd
+ // want to do it more formally.
+
+ if(auto fstVectorType = fst->As<VectorExpressionType>())
+ {
+ if(auto sndScalarType = snd->As<BasicExpressionType>())
+ {
+ return TryUnifyTypes(
+ constraints,
+ fstVectorType->elementType,
+ sndScalarType);
+ }
+ }
+
+ if(auto fstScalarType = fst->As<BasicExpressionType>())
+ {
+ if(auto sndVectorType = snd->As<VectorExpressionType>())
+ {
+ return TryUnifyTypes(
+ constraints,
+ fstScalarType,
+ sndVectorType->elementType);
+ }
+ }
+
+ // TODO: the same thing for vectors...
+
+ return false;
+ }
+
+ // Is the candidate extension declaration actually applicable to the given type
+ ExtensionDeclRef ApplyExtensionToType(
+ ExtensionDecl* extDecl,
+ RefPtr<ExpressionType> type)
+ {
+ if (auto extGenericDecl = GetOuterGeneric(extDecl))
+ {
+ ConstraintSystem constraints;
+
+ if (!TryUnifyTypes(constraints, extDecl->targetType, type))
+ return DeclRef().As<ExtensionDeclRef>();
+
+ auto constraintSubst = TrySolveConstraintSystem(&constraints, DeclRef(extGenericDecl, nullptr).As<GenericDeclRef>());
+ if (!constraintSubst)
+ {
+ return DeclRef().As<ExtensionDeclRef>();
+ }
+
+ // Consruct a reference to the extension with our constraint variables
+ // set as they were found by solving the constraint system.
+ ExtensionDeclRef extDeclRef = DeclRef(extDecl, constraintSubst).As<ExtensionDeclRef>();
+
+ // We expect/require that the result of unification is such that
+ // the target types are now equal
+ assert(extDeclRef.GetTargetType()->Equals(type));
+
+ return extDeclRef;
+ }
+ else
+ {
+ // The easy case is when the extension isn't generic:
+ // either it applies to the type or not.
+ if (!type->Equals(extDecl->targetType))
+ return DeclRef().As<ExtensionDeclRef>();
+ return DeclRef(extDecl, nullptr).As<ExtensionDeclRef>();
+ }
+ }
+
+ bool TryUnifyArgAndParamTypes(
+ ConstraintSystem& system,
+ RefPtr<ExpressionSyntaxNode> argExpr,
+ ParamDeclRef paramDeclRef)
+ {
+ // TODO(tfoley): potentially need a bit more
+ // nuance in case where argument might be
+ // an overload group...
+ return TryUnifyTypes(system, argExpr->Type, paramDeclRef.GetType());
+ }
+
+ // Take a generic declaration and try to specialize its parameters
+ // so that the resulting inner declaration can be applicable in
+ // a particular context...
+ DeclRef SpecializeGenericForOverload(
+ GenericDeclRef genericDeclRef,
+ OverloadResolveContext& context)
+ {
+ ConstraintSystem constraints;
+
+ // Construct a reference to the inner declaration that has any generic
+ // parameter substitutions in place already, but *not* any substutions
+ // for the generic declaration we are currently trying to infer.
+ auto innerDecl = genericDeclRef.GetInner();
+ DeclRef unspecializedInnerRef = DeclRef(innerDecl, genericDeclRef.substitutions);
+
+ // Check what type of declaration we are dealing with, and then try
+ // to match it up with the arguments accordingly...
+ if (auto funcDeclRef = unspecializedInnerRef.As<CallableDeclRef>())
+ {
+ auto& args = context.appExpr->Arguments;
+ auto params = funcDeclRef.GetParameters().ToArray();
+
+ int argCount = args.Count();
+ int paramCount = params.Count();
+
+ // Bail out on mismatch.
+ // TODO(tfoley): need more nuance here
+ if (argCount != paramCount)
+ {
+ return DeclRef(nullptr, nullptr);
+ }
+
+ for (int aa = 0; aa < argCount; ++aa)
+ {
+#if 0
+ if (!TryUnifyArgAndParamTypes(constraints, args[aa], params[aa]))
+ return DeclRef(nullptr, nullptr);
+#else
+ // The question here is whether failure to "unify" an argument
+ // and parameter should lead to immediate failure.
+ //
+ // The case that is interesting is if we want to unify, say:
+ // `vector<float,N>` and `vector<int,3>`
+ //
+ // It is clear that we should solve with `N = 3`, and then
+ // a later step may find that the resulting types aren't
+ // actually a match.
+ //
+ // A more refined approach to "unification" could of course
+ // see that `int` can convert to `float` and use that fact.
+ // (and indeed we already use something like this to unify
+ // `float` and `vector<T,3>`)
+ //
+ // So the question is then whether a mismatch during the
+ // unification step should be taken as an immediate failure...
+
+ TryUnifyArgAndParamTypes(constraints, args[aa], params[aa]);
+#endif
+ }
+ }
+ else
+ {
+ // TODO(tfoley): any other cases needed here?
+ return DeclRef(nullptr, nullptr);
+ }
+
+ auto constraintSubst = TrySolveConstraintSystem(&constraints, genericDeclRef);
+ if (!constraintSubst)
+ {
+ // constraint solving failed
+ return DeclRef(nullptr, nullptr);
+ }
+
+ // We can now construct a reference to the inner declaration using
+ // the solution to our constraints.
+ return DeclRef(innerDecl, constraintSubst);
+ }
+
+ void AddAggTypeOverloadCandidates(
+ LookupResultItem typeItem,
+ RefPtr<ExpressionType> type,
+ AggTypeDeclRef aggTypeDeclRef,
+ OverloadResolveContext& context)
+ {
+ for (auto ctorDeclRef : aggTypeDeclRef.GetMembersOfType<ConstructorDeclRef>())
+ {
+ // now work through this candidate...
+ AddCtorOverloadCandidate(typeItem, type, ctorDeclRef, context);
+ }
+
+ // Now walk through any extensions we can find for this types
+ for (auto ext = aggTypeDeclRef.GetCandidateExtensions(); ext; ext = ext->nextCandidateExtension)
+ {
+ auto extDeclRef = ApplyExtensionToType(ext, type);
+ if (!extDeclRef)
+ continue;
+
+ for (auto ctorDeclRef : extDeclRef.GetMembersOfType<ConstructorDeclRef>())
+ {
+ // TODO(tfoley): `typeItem` here should really reference the extension...
+
+ // now work through this candidate...
+ AddCtorOverloadCandidate(typeItem, type, ctorDeclRef, context);
+ }
+
+ // Also check for generic constructors
+ for (auto genericDeclRef : extDeclRef.GetMembersOfType<GenericDeclRef>())
+ {
+ if (auto ctorDecl = genericDeclRef.GetDecl()->inner.As<ConstructorDecl>())
+ {
+ DeclRef innerRef = SpecializeGenericForOverload(genericDeclRef, context);
+ if (!innerRef)
+ continue;
+
+ ConstructorDeclRef innerCtorRef = innerRef.As<ConstructorDeclRef>();
+
+ AddCtorOverloadCandidate(typeItem, type, innerCtorRef, context);
+
+ // TODO(tfoley): need a way to do the solving step for the constraint system
+ }
+ }
+ }
+ }
+
+ void AddTypeOverloadCandidates(
+ RefPtr<ExpressionType> type,
+ OverloadResolveContext& context)
+ {
+ if (auto declRefType = type->As<DeclRefType>())
+ {
+ if (auto aggTypeDeclRef = declRefType->declRef.As<AggTypeDeclRef>())
+ {
+ AddAggTypeOverloadCandidates(LookupResultItem(aggTypeDeclRef), type, aggTypeDeclRef, context);
+ }
+ }
+ }
+
+ void AddDeclRefOverloadCandidates(
+ LookupResultItem item,
+ OverloadResolveContext& context)
+ {
+ auto declRef = item.declRef;
+
+ if (auto funcDeclRef = item.declRef.As<CallableDeclRef>())
+ {
+ AddFuncOverloadCandidate(item, funcDeclRef, context);
+ }
+ else if (auto aggTypeDeclRef = item.declRef.As<AggTypeDeclRef>())
+ {
+ auto type = DeclRefType::Create(aggTypeDeclRef);
+ AddAggTypeOverloadCandidates(item, type, aggTypeDeclRef, context);
+ }
+ else if (auto genericDeclRef = item.declRef.As<GenericDeclRef>())
+ {
+ // Try to infer generic arguments, based on the context
+ DeclRef innerRef = SpecializeGenericForOverload(genericDeclRef, context);
+
+ if (innerRef)
+ {
+ // If inference works, then we've now got a
+ // specialized declaration reference we can apply.
+
+ LookupResultItem innerItem;
+ innerItem.breadcrumbs = item.breadcrumbs;
+ innerItem.declRef = innerRef;
+
+ AddDeclRefOverloadCandidates(innerItem, context);
+ }
+ else
+ {
+ // If inference failed, then we need to create
+ // a candidate that can be used to reflect that fact
+ // (so we can report a good error)
+ OverloadCandidate candidate;
+ candidate.item = item;
+ candidate.flavor = OverloadCandidate::Flavor::UnspecializedGeneric;
+ candidate.status = OverloadCandidate::Status::GenericArgumentInferenceFailed;
+
+ AddOverloadCandidateInner(context, candidate);
+ }
+ }
+ else if( auto typeDefDeclRef = item.declRef.As<TypeDefDeclRef>() )
+ {
+ AddTypeOverloadCandidates(typeDefDeclRef.GetType(), context);
+ }
+ else
+ {
+ // TODO(tfoley): any other cases needed here?
+ }
+ }
+
+ void AddOverloadCandidates(
+ RefPtr<ExpressionSyntaxNode> funcExpr,
+ OverloadResolveContext& context)
+ {
+ auto funcExprType = funcExpr->Type;
+
+ if (auto funcDeclRefExpr = funcExpr.As<DeclRefExpr>())
+ {
+ // The expression referenced a function declaration
+ AddDeclRefOverloadCandidates(LookupResultItem(funcDeclRefExpr->declRef), context);
+ }
+ else if (auto funcType = funcExprType->As<FuncType>())
+ {
+ // TODO(tfoley): deprecate this path...
+ AddFuncOverloadCandidate(funcType, context);
+ }
+ else if (auto overloadedExpr = funcExpr.As<OverloadedExpr>())
+ {
+ auto lookupResult = overloadedExpr->lookupResult2;
+ assert(lookupResult.isOverloaded());
+ for(auto item : lookupResult.items)
+ {
+ AddDeclRefOverloadCandidates(item, context);
+ }
+ }
+ else if (auto typeType = funcExprType->As<TypeType>())
+ {
+ // If none of the above cases matched, but we are
+ // looking at a type, then I suppose we have
+ // a constructor call on our hands.
+ //
+ // TODO(tfoley): are there any meaningful types left
+ // that aren't declaration references?
+ AddTypeOverloadCandidates(typeType->type, context);
+ return;
+ }
+ }
+
+ void formatType(StringBuilder& sb, RefPtr<ExpressionType> type)
+ {
+ sb << type->ToString();
+ }
+
+ void formatVal(StringBuilder& sb, RefPtr<Val> val)
+ {
+ sb << val->ToString();
+ }
+
+ void formatDeclPath(StringBuilder& sb, DeclRef declRef)
+ {
+ // Find the parent declaration
+ auto parentDeclRef = declRef.GetParent();
+
+ // If the immediate parent is a generic, then we probably
+ // want the declaration above that...
+ auto parentGenericDeclRef = parentDeclRef.As<GenericDeclRef>();
+ if(parentGenericDeclRef)
+ {
+ parentDeclRef = parentGenericDeclRef.GetParent();
+ }
+
+ // Depending on what the parent is, we may want to format things specially
+ if(auto aggTypeDeclRef = parentDeclRef.As<AggTypeDeclRef>())
+ {
+ formatDeclPath(sb, aggTypeDeclRef);
+ sb << ".";
+ }
+
+ sb << declRef.GetName();
+
+ // If the parent declaration is a generic, then we need to print out its
+ // signature
+ if( parentGenericDeclRef )
+ {
+ assert(declRef.substitutions);
+ assert(declRef.substitutions->genericDecl == parentGenericDeclRef.GetDecl());
+
+ sb << "<";
+ bool first = true;
+ for(auto arg : declRef.substitutions->args)
+ {
+ if(!first) sb << ", ";
+ formatVal(sb, arg);
+ first = false;
+ }
+ sb << ">";
+ }
+ }
+
+ void formatDeclParams(StringBuilder& sb, DeclRef declRef)
+ {
+ if (auto funcDeclRef = declRef.As<CallableDeclRef>())
+ {
+
+ // This is something callable, so we need to also print parameter types for overloading
+ sb << "(";
+
+ bool first = true;
+ for (auto paramDeclRef : funcDeclRef.GetParameters())
+ {
+ if (!first) sb << ", ";
+
+ formatType(sb, paramDeclRef.GetType());
+
+ first = false;
+
+ }
+
+ sb << ")";
+ }
+ else if(auto genericDeclRef = declRef.As<GenericDeclRef>())
+ {
+ sb << "<";
+ bool first = true;
+ for (auto paramDeclRef : genericDeclRef.GetMembers())
+ {
+ if(auto genericTypeParam = paramDeclRef.As<GenericTypeParamDeclRef>())
+ {
+ if (!first) sb << ", ";
+ first = false;
+
+ sb << genericTypeParam.GetName();
+ }
+ else if(auto genericValParam = paramDeclRef.As<GenericValueParamDeclRef>())
+ {
+ if (!first) sb << ", ";
+ first = false;
+
+ formatType(sb, genericValParam.GetType());
+ sb << " ";
+ sb << genericValParam.GetName();
+ }
+ else
+ {}
+ }
+ sb << ">";
+
+ formatDeclParams(sb, DeclRef(genericDeclRef.GetInner(), genericDeclRef.substitutions));
+ }
+ else
+ {
+ }
+ }
+
+ void formatDeclSignature(StringBuilder& sb, DeclRef declRef)
+ {
+ formatDeclPath(sb, declRef);
+ formatDeclParams(sb, declRef);
+ }
+
+ String getDeclSignatureString(DeclRef declRef)
+ {
+ StringBuilder sb;
+ formatDeclSignature(sb, declRef);
+ return sb.ProduceString();
+ }
+
+ String getDeclSignatureString(LookupResultItem item)
+ {
+ return getDeclSignatureString(item.declRef);
+ }
+
+ String GetCallSignatureString(RefPtr<AppExprBase> expr)
+ {
+ StringBuilder argsListBuilder;
+ argsListBuilder << "(";
+ bool first = true;
+ for (auto a : expr->Arguments)
+ {
+ if (!first) argsListBuilder << ", ";
+ argsListBuilder << a->Type->ToString();
+ first = false;
+ }
+ argsListBuilder << ")";
+ return argsListBuilder.ProduceString();
+ }
+
+
+ RefPtr<ExpressionSyntaxNode> ResolveInvoke(InvokeExpressionSyntaxNode * expr)
+ {
+ // Look at the base expression for the call, and figure out how to invoke it.
+ auto funcExpr = expr->FunctionExpr;
+ auto funcExprType = funcExpr->Type;
+
+ // If we are trying to apply an erroroneous expression, then just bail out now.
+ if(IsErrorExpr(funcExpr))
+ {
+ return CreateErrorExpr(expr);
+ }
+
+ OverloadResolveContext context;
+ context.appExpr = expr;
+ if (auto funcMemberExpr = funcExpr.As<MemberExpressionSyntaxNode>())
+ {
+ context.baseExpr = funcMemberExpr->BaseExpression;
+ }
+ else if (auto funcOverloadExpr = funcExpr.As<OverloadedExpr>())
+ {
+ context.baseExpr = funcOverloadExpr->base;
+ }
+ AddOverloadCandidates(funcExpr, context);
+
+ if (context.bestCandidates.Count() > 0)
+ {
+ // Things were ambiguous.
+
+ // It might be that things were only ambiguous because
+ // one of the argument expressions had an error, and
+ // so a bunch of candidates could match at that position.
+ //
+ // If any argument was an error, we skip out on printing
+ // another message, to avoid cascading errors.
+ for (auto arg : expr->Arguments)
+ {
+ if (IsErrorExpr(arg))
+ {
+ return CreateErrorExpr(expr);
+ }
+ }
+
+ String funcName;
+ if (auto baseVar = funcExpr.As<VarExpressionSyntaxNode>())
+ funcName = baseVar->Variable;
+ else if(auto baseMemberRef = funcExpr.As<MemberExpressionSyntaxNode>())
+ funcName = baseMemberRef->MemberName;
+
+ String argsList = GetCallSignatureString(expr);
+
+ if (context.bestCandidates[0].status != OverloadCandidate::Status::Appicable)
+ {
+ // There were multple equally-good candidates, but none actually usable.
+ // We will construct a diagnostic message to help out.
+ if (funcName.Length() != 0)
+ {
+ getSink()->diagnose(expr, Diagnostics::noApplicableOverloadForNameWithArgs, funcName, argsList);
+ }
+ else
+ {
+ getSink()->diagnose(expr, Diagnostics::noApplicableWithArgs, argsList);
+ }
+ }
+ else
+ {
+ // There were multiple applicable candidates, so we need to report them.
+
+ if (funcName.Length() != 0)
+ {
+ getSink()->diagnose(expr, Diagnostics::ambiguousOverloadForNameWithArgs, funcName, argsList);
+ }
+ else
+ {
+ getSink()->diagnose(expr, Diagnostics::ambiguousOverloadWithArgs, argsList);
+ }
+ }
+
+ int candidateCount = context.bestCandidates.Count();
+ int maxCandidatesToPrint = 10; // don't show too many candidates at once...
+ int candidateIndex = 0;
+ for (auto candidate : context.bestCandidates)
+ {
+ String declString = getDeclSignatureString(candidate.item);
+
+ declString = declString + "[" + String(candidate.conversionCostSum) + "]";
+
+ getSink()->diagnose(candidate.item.declRef, Diagnostics::overloadCandidate, declString);
+
+ candidateIndex++;
+ if (candidateIndex == maxCandidatesToPrint)
+ break;
+ }
+ if (candidateIndex != candidateCount)
+ {
+ getSink()->diagnose(expr, Diagnostics::moreOverloadCandidates, candidateCount - candidateIndex);
+ }
+
+ return CreateErrorExpr(expr);
+ }
+ else if (context.bestCandidate)
+ {
+ // There was one best candidate, even if it might not have been
+ // applicable in the end.
+ // We will report errors for this one candidate, then, to give
+ // the user the most help we can.
+ return CompleteOverloadCandidate(context, *context.bestCandidate);
+ }
+ else
+ {
+ // Nothing at all was found that we could even consider invoking
+ getSink()->diagnose(expr->FunctionExpr, Diagnostics::expectedFunction);
+ expr->Type = ExpressionType::Error;
+ return expr;
+ }
+ }
+
+ void AddGenericOverloadCandidate(
+ LookupResultItem baseItem,
+ OverloadResolveContext& context)
+ {
+ if (auto genericDeclRef = baseItem.declRef.As<GenericDeclRef>())
+ {
+ EnsureDecl(genericDeclRef.GetDecl());
+
+ OverloadCandidate candidate;
+ candidate.flavor = OverloadCandidate::Flavor::Generic;
+ candidate.item = baseItem;
+ candidate.resultType = nullptr;
+
+ AddOverloadCandidate(context, candidate);
+ }
+ }
+
+ void AddGenericOverloadCandidates(
+ RefPtr<ExpressionSyntaxNode> baseExpr,
+ OverloadResolveContext& context)
+ {
+ if(auto baseDeclRefExpr = baseExpr.As<DeclRefExpr>())
+ {
+ auto declRef = baseDeclRefExpr->declRef;
+ AddGenericOverloadCandidate(LookupResultItem(declRef), context);
+ }
+ else if (auto overloadedExpr = baseExpr.As<OverloadedExpr>())
+ {
+ // We are referring to a bunch of declarations, each of which might be generic
+ LookupResult result;
+ for (auto item : overloadedExpr->lookupResult2.items)
+ {
+ AddGenericOverloadCandidate(item, context);
+ }
+ }
+ else
+ {
+ // any other cases?
+ }
+ }
+
+ RefPtr<ExpressionSyntaxNode> VisitGenericApp(GenericAppExpr * genericAppExpr) override
+ {
+ // We are applying a generic to arguments, but there might be multiple generic
+ // declarations with the same name, so this becomes a specialized case of
+ // overload resolution.
+
+
+ // Start by checking the base expression and arguments.
+ auto& baseExpr = genericAppExpr->FunctionExpr;
+ baseExpr = CheckTerm(baseExpr);
+ auto& args = genericAppExpr->Arguments;
+ for (auto& arg : args)
+ {
+ arg = CheckTerm(arg);
+ }
+
+ // If there was an error in the base expression, or in any of
+ // the arguments, then just bail.
+ if (IsErrorExpr(baseExpr))
+ {
+ return CreateErrorExpr(genericAppExpr);
+ }
+ for (auto argExpr : args)
+ {
+ if (IsErrorExpr(argExpr))
+ {
+ return CreateErrorExpr(genericAppExpr);
+ }
+ }
+
+ // Otherwise, let's start looking at how to find an overload...
+
+ OverloadResolveContext context;
+ context.appExpr = genericAppExpr;
+ context.baseExpr = GetBaseExpr(baseExpr);
+
+ AddGenericOverloadCandidates(baseExpr, context);
+
+ if (context.bestCandidates.Count() > 0)
+ {
+ // Things were ambiguous.
+ if (context.bestCandidates[0].status != OverloadCandidate::Status::Appicable)
+ {
+ // There were multple equally-good candidates, but none actually usable.
+ // We will construct a diagnostic message to help out.
+
+ // TODO(tfoley): print a reasonable message here...
+
+ getSink()->diagnose(genericAppExpr, Diagnostics::unimplemented, "no applicable generic");
+
+ return CreateErrorExpr(genericAppExpr);
+ }
+ else
+ {
+ // There were multiple viable candidates, but that isn't an error: we just need
+ // to complete all of them and create an overloaded expression as a result.
+
+ LookupResult result;
+ for (auto candidate : context.bestCandidates)
+ {
+ auto candidateExpr = CompleteOverloadCandidate(context, candidate);
+ }
+
+ throw "what now?";
+// auto overloadedExpr = new OverloadedExpr();
+// return overloadedExpr;
+ }
+ }
+ else if (context.bestCandidate)
+ {
+ // There was one best candidate, even if it might not have been
+ // applicable in the end.
+ // We will report errors for this one candidate, then, to give
+ // the user the most help we can.
+ return CompleteOverloadCandidate(context, *context.bestCandidate);
+ }
+ else
+ {
+ // Nothing at all was found that we could even consider invoking
+ getSink()->diagnose(genericAppExpr, Diagnostics::unimplemented, "expected a generic");
+ return CreateErrorExpr(genericAppExpr);
+ }
+
+
+#if TIMREMOVED
+
+ if (IsErrorExpr(base))
+ {
+ return CreateErrorExpr(typeNode);
+ }
+ else if(auto baseDeclRefExpr = base.As<DeclRefExpr>())
+ {
+ auto declRef = baseDeclRefExpr->declRef;
+
+ if (auto genericDeclRef = declRef.As<GenericDeclRef>())
+ {
+ int argCount = typeNode->Args.Count();
+ int argIndex = 0;
+ for (RefPtr<Decl> member : genericDeclRef.GetDecl()->Members)
+ {
+ if (auto typeParam = member.As<GenericTypeParamDecl>())
+ {
+ if (argIndex == argCount)
+ {
+ // Too few arguments!
+
+ }
+
+ // TODO: checking!
+ }
+ else if (auto valParam = member.As<GenericValueParamDecl>())
+ {
+ // TODO: checking
+ }
+ else
+ {
+
+ }
+ }
+ if (argIndex != argCount)
+ {
+ // Too many arguments!
+ }
+
+ // Now instantiate the declaration given those arguments
+ auto type = InstantiateGenericType(genericDeclRef, args);
+ typeResult = type;
+ typeNode->Type = new TypeExpressionType(type);
+ return typeNode;
+ }
+ }
+ else if (auto overloadedExpr = base.As<OverloadedExpr>())
+ {
+ // We are referring to a bunch of declarations, each of which might be generic
+ LookupResult result;
+ for (auto item : overloadedExpr->lookupResult2.items)
+ {
+ auto applied = TryApplyGeneric(item, typeNode);
+ if (!applied)
+ continue;
+
+ AddToLookupResult(result, appliedItem);
+ }
+ }
+
+ // TODO: correct diagnostic here!
+ getSink()->diagnose(typeNode, Diagnostics::expectedAGeneric, base->Type);
+ return CreateErrorExpr(typeNode);
+#endif
+ }
+
+ RefPtr<ExpressionSyntaxNode> VisitSharedTypeExpr(SharedTypeExpr* expr) override
+ {
+ if (!expr->Type.Ptr())
+ {
+ expr->base = CheckProperType(expr->base);
+ expr->Type = expr->base.exp->Type;
+ }
+ return expr;
+ }
+
+
+
+
+ RefPtr<ExpressionSyntaxNode> CheckExpr(RefPtr<ExpressionSyntaxNode> expr)
+ {
+ return expr->Accept(this).As<ExpressionSyntaxNode>();
+ }
+
+ RefPtr<ExpressionSyntaxNode> CheckInvokeExprWithCheckedOperands(InvokeExpressionSyntaxNode *expr)
+ {
+
+ auto rs = ResolveInvoke(expr);
+ if (auto invoke = dynamic_cast<InvokeExpressionSyntaxNode*>(rs.Ptr()))
+ {
+ // if this is still an invoke expression, test arguments passed to inout/out parameter are LValues
+ if(auto funcType = invoke->FunctionExpr->Type->As<FuncType>())
+ {
+ List<RefPtr<ParameterSyntaxNode>> paramsStorage;
+ List<RefPtr<ParameterSyntaxNode>> * params = nullptr;
+ if (auto func = funcType->declRef.GetDecl())
+ {
+ paramsStorage = func->GetParameters().ToArray();
+ params = &paramsStorage;
+ }
+ if (params)
+ {
+ for (int i = 0; i < (*params).Count(); i++)
+ {
+ if ((*params)[i]->HasModifier<OutModifier>())
+ {
+ if (i < expr->Arguments.Count() && expr->Arguments[i]->Type->AsBasicType() &&
+ !expr->Arguments[i]->Type.IsLeftValue)
+ {
+ getSink()->diagnose(expr->Arguments[i], Diagnostics::argumentExpectedLValue, (*params)[i]->Name);
+ }
+ }
+ }
+ }
+ }
+ }
+ return rs;
+ }
+
+ virtual RefPtr<ExpressionSyntaxNode> VisitInvokeExpression(InvokeExpressionSyntaxNode *expr) override
+ {
+ // check the base expression first
+ expr->FunctionExpr = CheckExpr(expr->FunctionExpr);
+
+ // Next check the argument expressions
+ for (auto & arg : expr->Arguments)
+ {
+ arg = CheckExpr(arg);
+ }
+
+ return CheckInvokeExprWithCheckedOperands(expr);
+ }
+
+
+ virtual RefPtr<ExpressionSyntaxNode> VisitVarExpression(VarExpressionSyntaxNode *expr) override
+ {
+ // If we've already resolved this expression, don't try again.
+ if (expr->declRef)
+ return expr;
+
+ expr->Type = ExpressionType::Error;
+
+ auto lookupResult = LookUp(expr->Variable, expr->scope);
+ if (lookupResult.isValid())
+ {
+ return createLookupResultExpr(
+ lookupResult,
+ nullptr,
+ expr);
+ }
+
+ getSink()->diagnose(expr, Diagnostics::undefinedIdentifier2, expr->Variable);
+
+ return expr;
+ }
+ virtual RefPtr<ExpressionSyntaxNode> VisitTypeCastExpression(TypeCastExpressionSyntaxNode * expr) override
+ {
+ expr->Expression = expr->Expression->Accept(this).As<ExpressionSyntaxNode>();
+ auto targetType = CheckProperType(expr->TargetType);
+ expr->TargetType = targetType;
+
+ // The way to perform casting depends on the types involved
+ if (expr->Expression->Type->Equals(ExpressionType::Error.Ptr()))
+ {
+ // If the expression being casted has an error type, then just silently succeed
+ expr->Type = targetType;
+ return expr;
+ }
+ else if (auto targetArithType = targetType->AsArithmeticType())
+ {
+ if (auto exprArithType = expr->Expression->Type->AsArithmeticType())
+ {
+ // Both source and destination types are arithmetic, so we might
+ // have a valid cast
+ auto targetScalarType = targetArithType->GetScalarType();
+ auto exprScalarType = exprArithType->GetScalarType();
+
+ if (!IsNumeric(exprScalarType->BaseType)) goto fail;
+ if (!IsNumeric(targetScalarType->BaseType)) goto fail;
+
+ // TODO(tfoley): this checking is incomplete here, and could
+ // lead to downstream compilation failures
+ expr->Type = targetType;
+ return expr;
+ }
+ }
+
+ fail:
+ // Default: in no other case succeds, then the cast failed and we emit a diagnostic.
+ getSink()->diagnose(expr, Diagnostics::invalidTypeCast, expr->Expression->Type, targetType->ToString());
+ expr->Type = ExpressionType::Error;
+ return expr;
+ }
+#if TIMREMOVED
+ virtual RefPtr<ExpressionSyntaxNode> VisitSelectExpression(SelectExpressionSyntaxNode * expr) override
+ {
+ auto selectorExpr = expr->SelectorExpr;
+ selectorExpr = CheckExpr(selectorExpr);
+ selectorExpr = Coerce(ExpressionType::GetBool(), selectorExpr);
+ expr->SelectorExpr = selectorExpr;
+
+ // TODO(tfoley): We need a general purpose "join" on types for inferring
+ // generic argument types for builtins/intrinsics, so this should really
+ // be using the exact same logic...
+ //
+ expr->Expr0 = expr->Expr0->Accept(this).As<ExpressionSyntaxNode>();
+ expr->Expr1 = expr->Expr1->Accept(this).As<ExpressionSyntaxNode>();
+ if (!expr->Expr0->Type->Equals(expr->Expr1->Type.Ptr()))
+ {
+ getSink()->diagnose(expr, Diagnostics::selectValuesTypeMismatch);
+ }
+ expr->Type = expr->Expr0->Type;
+ return expr;
+ }
+#endif
+
+ // Get the type to use when referencing a declaration
+ QualType GetTypeForDeclRef(DeclRef declRef)
+ {
+ return getTypeForDeclRef(
+ this,
+ getSink(),
+ declRef,
+ &typeResult);
+ }
+
+ RefPtr<ExpressionSyntaxNode> MaybeDereference(RefPtr<ExpressionSyntaxNode> inExpr)
+ {
+ RefPtr<ExpressionSyntaxNode> expr = inExpr;
+ for (;;)
+ {
+ auto& type = expr->Type;
+ if (auto pointerLikeType = type->As<PointerLikeType>())
+ {
+ type = pointerLikeType->elementType;
+
+ auto derefExpr = new DerefExpr();
+ derefExpr->base = expr;
+ derefExpr->Type = pointerLikeType->elementType;
+
+ // TODO(tfoley): deal with l-value-ness here
+
+ expr = derefExpr;
+ continue;
+ }
+
+ // Default case: just use the expression as-is
+ return expr;
+ }
+ }
+
+ RefPtr<ExpressionSyntaxNode> CheckSwizzleExpr(
+ MemberExpressionSyntaxNode* memberRefExpr,
+ RefPtr<ExpressionType> baseElementType,
+ int baseElementCount)
+ {
+ RefPtr<SwizzleExpr> swizExpr = new SwizzleExpr();
+ swizExpr->Position = memberRefExpr->Position;
+ swizExpr->base = memberRefExpr->BaseExpression;
+
+ int limitElement = baseElementCount;
+
+ int elementIndices[4];
+ int elementCount = 0;
+
+ bool elementUsed[4] = { false, false, false, false };
+ bool anyDuplicates = false;
+ bool anyError = false;
+
+ for (int i = 0; i < memberRefExpr->MemberName.Length(); i++)
+ {
+ auto ch = memberRefExpr->MemberName[i];
+ int elementIndex = -1;
+ switch (ch)
+ {
+ case 'x': case 'r': elementIndex = 0; break;
+ case 'y': case 'g': elementIndex = 1; break;
+ case 'z': case 'b': elementIndex = 2; break;
+ case 'w': case 'a': elementIndex = 3; break;
+ default:
+ // An invalid character in the swizzle is an error
+ getSink()->diagnose(swizExpr, Diagnostics::unimplemented, "invalid component name for swizzle");
+ anyError = true;
+ continue;
+ }
+
+ // TODO(tfoley): GLSL requires that all component names
+ // come from the same "family"...
+
+ // Make sure the index is in range for the source type
+ if (elementIndex >= limitElement)
+ {
+ getSink()->diagnose(swizExpr, Diagnostics::unimplemented, "swizzle component out of range for type");
+ anyError = true;
+ continue;
+ }
+
+ // Check if we've seen this index before
+ for (int ee = 0; ee < elementCount; ee++)
+ {
+ if (elementIndices[ee] == elementIndex)
+ anyDuplicates = true;
+ }
+
+ // add to our list...
+ elementIndices[elementCount++] = elementIndex;
+ }
+
+ for (int ee = 0; ee < elementCount; ++ee)
+ {
+ swizExpr->elementIndices[ee] = elementIndices[ee];
+ }
+ swizExpr->elementCount = elementCount;
+
+ if (anyError)
+ {
+ swizExpr->Type = ExpressionType::Error;
+ }
+ else if (elementCount == 1)
+ {
+ // single-component swizzle produces a scalar
+ //
+ // Note(tfoley): the official HLSL rules seem to be that it produces
+ // a one-component vector, which is then implicitly convertible to
+ // a scalar, but that seems like it just adds complexity.
+ swizExpr->Type = baseElementType;
+ }
+ else
+ {
+ // TODO(tfoley): would be nice to "re-sugar" type
+ // here if the input type had a sugared name...
+ swizExpr->Type = createVectorType(
+ baseElementType,
+ new ConstantIntVal(elementCount));
+ }
+
+ // A swizzle can be used as an l-value as long as there
+ // were no duplicates in the list of components
+ swizExpr->Type.IsLeftValue = !anyDuplicates;
+
+ return swizExpr;
+ }
+
+ RefPtr<ExpressionSyntaxNode> CheckSwizzleExpr(
+ MemberExpressionSyntaxNode* memberRefExpr,
+ RefPtr<ExpressionType> baseElementType,
+ RefPtr<IntVal> baseElementCount)
+ {
+ if (auto constantElementCount = baseElementCount.As<ConstantIntVal>())
+ {
+ return CheckSwizzleExpr(memberRefExpr, baseElementType, constantElementCount->value);
+ }
+ else
+ {
+ getSink()->diagnose(memberRefExpr, Diagnostics::unimplemented, "swizzle on vector of unknown size");
+ return CreateErrorExpr(memberRefExpr);
+ }
+ }
+
+
+ virtual RefPtr<ExpressionSyntaxNode> VisitMemberExpression(MemberExpressionSyntaxNode * expr) override
+ {
+ expr->BaseExpression = CheckExpr(expr->BaseExpression);
+
+ expr->BaseExpression = MaybeDereference(expr->BaseExpression);
+
+ auto & baseType = expr->BaseExpression->Type;
+
+ // Note: Checking for vector types before declaration-reference types,
+ // because vectors are also declaration reference types...
+ if (auto baseVecType = baseType->AsVectorType())
+ {
+ return CheckSwizzleExpr(
+ expr,
+ baseVecType->elementType,
+ baseVecType->elementCount);
+ }
+ else if(auto baseScalarType = baseType->AsBasicType())
+ {
+ // Treat scalar like a 1-element vector when swizzling
+ return CheckSwizzleExpr(
+ expr,
+ baseScalarType,
+ 1);
+ }
+ else if (auto declRefType = baseType->AsDeclRefType())
+ {
+ if (auto aggTypeDeclRef = declRefType->declRef.As<AggTypeDeclRef>())
+ {
+ // Checking of the type must be complete before we can reference its members safely
+ EnsureDecl(aggTypeDeclRef.GetDecl(), DeclCheckState::Checked);
+
+
+ LookupResult lookupResult = LookUpLocal(expr->MemberName, aggTypeDeclRef);
+ if (!lookupResult.isValid())
+ {
+ goto fail;
+ }
+
+ return createLookupResultExpr(
+ lookupResult,
+ expr->BaseExpression,
+ expr);
+#if 0
+ DeclRef memberDeclRef(lookupResult.decl, aggTypeDeclRef.substitutions);
+ return ConstructDeclRefExpr(memberDeclRef, expr->BaseExpression, expr);
+#endif
+
+#if 0
+
+
+ // TODO(tfoley): It is unfortunate that the lookup strategy
+ // here isn't unified with the ordinary `Scope` case.
+ // In particular, if we add support for "transparent" declarations,
+ // etc. here then we would need to add them in ordinary lookup
+ // as well.
+
+ Decl* memberDecl = nullptr; // The first declaration we found, if any
+ Decl* secondDecl = nullptr; // Another declaration with the same name, if any
+ for (auto m : aggTypeDeclRef.GetMembers())
+ {
+ if (m.GetName() != expr->MemberName)
+ continue;
+
+ if (!memberDecl)
+ {
+ memberDecl = m.GetDecl();
+ }
+ else
+ {
+ secondDecl = m.GetDecl();
+ break;
+ }
+ }
+
+ // If we didn't find any member, then we signal an error
+ if (!memberDecl)
+ {
+ expr->Type = ExpressionType::Error;
+ getSink()->diagnose(expr, Diagnostics::noMemberOfNameInType, expr->MemberName, baseType);
+ return expr;
+ }
+
+ // If we found only a single member, then we are fine
+ if (!secondDecl)
+ {
+ // TODO: need to
+ DeclRef memberDeclRef(memberDecl, aggTypeDeclRef.substitutions);
+
+ expr->declRef = memberDeclRef;
+ expr->Type = GetTypeForDeclRef(memberDeclRef);
+
+ // When referencing a member variable, the result is an l-value
+ // if and only if the base expression was.
+ if (auto memberVarDecl = dynamic_cast<VarDeclBase*>(memberDecl))
+ {
+ expr->Type.IsLeftValue = expr->BaseExpression->Type.IsLeftValue;
+ }
+ return expr;
+ }
+
+ // We found multiple members with the same name, and need
+ // to resolve the embiguity at some point...
+ expr->Type = ExpressionType::Error;
+ getSink()->diagnose(expr, Diagnostics::unimplemented, "ambiguous member reference");
+ return expr;
+
+#endif
+
+#if 0
+
+ StructField* field = structDecl->FindField(expr->MemberName);
+ if (!field)
+ {
+ expr->Type = ExpressionType::Error;
+ getSink()->diagnose(expr, Diagnostics::noMemberOfNameInType, expr->MemberName, baseType);
+ }
+ else
+ expr->Type = field->Type;
+
+ // A reference to a struct member is an l-value if the reference to the struct
+ // value was also an l-value.
+ expr->Type.IsLeftValue = expr->BaseExpression->Type.IsLeftValue;
+ return expr;
+#endif
+ }
+
+ // catch-all
+ fail:
+ getSink()->diagnose(expr, Diagnostics::noMemberOfNameInType, expr->MemberName, baseType);
+ expr->Type = ExpressionType::Error;
+ return expr;
+ }
+ // All remaining cases assume we have a `BasicType`
+ else if (!baseType->AsBasicType())
+ expr->Type = ExpressionType::Error;
+ else
+ expr->Type = ExpressionType::Error;
+ if (!baseType->Equals(ExpressionType::Error.Ptr()) &&
+ expr->Type->Equals(ExpressionType::Error.Ptr()))
+ {
+ getSink()->diagnose(expr, Diagnostics::typeHasNoPublicMemberOfName, baseType, expr->MemberName);
+ }
+ return expr;
+ }
+ SemanticsVisitor & operator = (const SemanticsVisitor &) = delete;
+
+
+ //
+
+ virtual RefPtr<ExpressionSyntaxNode> visitInitializerListExpr(InitializerListExpr* expr) override
+ {
+ // When faced with an initializer list, we first just check the sub-expressions blindly.
+ // Actually making them conform to a desired type will wait for when we know the desired
+ // type based on context.
+
+ for( auto& arg : expr->args )
+ {
+ arg = CheckTerm(arg);
+ }
+
+ expr->Type = ExpressionType::getInitializerListType();
+
+ return expr;
+ }
+ };
+
+ SyntaxVisitor * CreateSemanticsVisitor(
+ DiagnosticSink * err,
+ CompileOptions const& options)
+ {
+ return new SemanticsVisitor(err, options);
+ }
+
+ //
+
+ // Get the type to use when referencing a declaration
+ QualType getTypeForDeclRef(
+ SemanticsVisitor* sema,
+ DiagnosticSink* sink,
+ DeclRef declRef,
+ RefPtr<ExpressionType>* outTypeResult)
+ {
+ if( sema )
+ {
+ sema->EnsureDecl(declRef.GetDecl());
+ }
+
+ // We need to insert an appropriate type for the expression, based on
+ // what we found.
+ if (auto varDeclRef = declRef.As<VarDeclBaseRef>())
+ {
+ QualType qualType;
+ qualType.type = varDeclRef.GetType();
+ qualType.IsLeftValue = true; // TODO(tfoley): allow explicit `const` or `let` variables
+ return qualType;
+ }
+ else if (auto typeAliasDeclRef = declRef.As<TypeDefDeclRef>())
+ {
+ auto type = new NamedExpressionType(typeAliasDeclRef);
+ *outTypeResult = type;
+ return new TypeType(type);
+ }
+ else if (auto aggTypeDeclRef = declRef.As<AggTypeDeclRef>())
+ {
+ auto type = DeclRefType::Create(aggTypeDeclRef);
+ *outTypeResult = type;
+ return new TypeType(type);
+ }
+ else if (auto simpleTypeDeclRef = declRef.As<SimpleTypeDeclRef>())
+ {
+ auto type = DeclRefType::Create(simpleTypeDeclRef);
+ *outTypeResult = type;
+ return new TypeType(type);
+ }
+ else if (auto genericDeclRef = declRef.As<GenericDeclRef>())
+ {
+ auto type = new GenericDeclRefType(genericDeclRef);
+ *outTypeResult = type;
+ return new TypeType(type);
+ }
+ else if (auto funcDeclRef = declRef.As<CallableDeclRef>())
+ {
+ auto type = new FuncType();
+ type->declRef = funcDeclRef;
+ return type;
+ }
+
+ if( sink )
+ {
+ sink->diagnose(declRef, Diagnostics::unimplemented, "cannot form reference to this kind of declaration");
+ }
+ return ExpressionType::Error;
+ }
+
+ QualType getTypeForDeclRef(
+ DeclRef declRef)
+ {
+ RefPtr<ExpressionType> typeResult;
+ return getTypeForDeclRef(nullptr, nullptr, declRef, &typeResult);
+ }
+
+ }
+} \ No newline at end of file
diff --git a/source/slang/compiled-program.h b/source/slang/compiled-program.h
new file mode 100644
index 000000000..7a86dfe90
--- /dev/null
+++ b/source/slang/compiled-program.h
@@ -0,0 +1,96 @@
+#ifndef BAKER_SL_COMPILED_PROGRAM_H
+#define BAKER_SL_COMPILED_PROGRAM_H
+
+#include "../core/basic.h"
+#include "diagnostics.h"
+#include "syntax.h"
+#include "type-layout.h"
+
+namespace Slang
+{
+ namespace Compiler
+ {
+#if 0
+ class ShaderMetaData
+ {
+ public:
+ CoreLib::String ShaderName;
+ CoreLib::EnumerableDictionary<CoreLib::String, CoreLib::RefPtr<ILModuleParameterSet>> ParameterSets; // bindingName->DescSet
+ };
+
+ class StageSource
+ {
+ public:
+ String MainCode;
+ List<unsigned char> BinaryCode;
+ };
+
+ class CompiledShaderSource
+ {
+ public:
+ EnumerableDictionary<String, StageSource> Stages;
+ ShaderMetaData MetaData;
+ };
+#endif
+
+ void IndentString(StringBuilder & sb, String src);
+
+ struct EntryPointResult
+ {
+ String outputSource;
+ };
+
+ struct TranslationUnitResult
+ {
+ String outputSource;
+ List<EntryPointResult> entryPoints;
+ };
+
+ class CompileResult
+ {
+ public:
+ DiagnosticSink* mSink = nullptr;
+
+#if 0
+ String ScheduleFile;
+ RefPtr<ILProgram> Program;
+ EnumerableDictionary<String, CompiledShaderSource> CompiledSource; // shader -> stage -> code
+#endif
+
+ // Per-translation-unit results
+ List<TranslationUnitResult> translationUnits;
+
+#if 0
+ void PrintDiagnostics()
+ {
+ for (int i = 0; i < sink.diagnostics.Count(); i++)
+ {
+ fprintf(stderr, "%S(%d): %s %d: %S\n",
+ sink.diagnostics[i].Position.FileName.ToWString(),
+ sink.diagnostics[i].Position.Line,
+ getSeverityName(sink.diagnostics[i].severity),
+ sink.diagnostics[i].ErrorID,
+ sink.diagnostics[i].Message.ToWString());
+ }
+ }
+#endif
+
+ CompileResult()
+ {}
+ ~CompileResult()
+ {
+ }
+ DiagnosticSink * GetErrorWriter()
+ {
+ return mSink;
+ }
+ int GetErrorCount()
+ {
+ return mSink->GetErrorCount();
+ }
+ };
+
+ }
+}
+
+#endif \ No newline at end of file
diff --git a/source/slang/compiler.cpp b/source/slang/compiler.cpp
new file mode 100644
index 000000000..e78de9133
--- /dev/null
+++ b/source/slang/compiler.cpp
@@ -0,0 +1,659 @@
+// Compiler.cpp : Defines the entry point for the console application.
+//
+#include "../core/basic.h"
+#include "../core/slang-io.h"
+#include "compiler.h"
+#include "lexer.h"
+#include "parameter-binding.h"
+#include "parser.h"
+#include "preprocessor.h"
+#include "syntax-visitors.h"
+#include "slang-stdlib.h"
+
+#include "reflection.h"
+#include "emit.h"
+
+// Utilities for pass-through modes
+#include "../../tools/glslang/glslang.h"
+
+
+#ifdef _WIN32
+#define WIN32_LEAN_AND_MEAN
+#define NOMINMAX
+#include <Windows.h>
+#undef WIN32_LEAN_AND_MEAN
+#undef NOMINMAX
+#include <d3dcompiler.h>
+#endif
+
+#ifdef CreateDirectory
+#undef CreateDirectory
+#endif
+
+using namespace CoreLib::Basic;
+using namespace CoreLib::IO;
+using namespace Slang::Compiler;
+
+namespace Slang
+{
+ namespace Compiler
+ {
+ //
+
+ Profile Profile::LookUp(char const* name)
+ {
+ #define PROFILE(TAG, NAME, STAGE, VERSION) if(strcmp(name, #NAME) == 0) return Profile::TAG;
+ #define PROFILE_ALIAS(TAG, NAME) if(strcmp(name, #NAME) == 0) return Profile::TAG;
+ #include "profile-defs.h"
+
+ return Profile::Unknown;
+ }
+
+
+
+ //
+
+ int compilerInstances = 0;
+
+ class ShaderCompilerImpl : public ShaderCompiler
+ {
+ public:
+
+ // Actual context for compilation... :(
+ struct ExtraContext
+ {
+ CompileOptions const* options = nullptr;
+ TranslationUnitOptions const* translationUnitOptions = nullptr;
+
+ CompileResult* compileResult = nullptr;
+
+ RefPtr<ProgramSyntaxNode> programSyntax;
+ ProgramLayout* programLayout;
+
+ String sourceText;
+ String sourcePath;
+
+ CompileOptions const& getOptions() { return *options; }
+ TranslationUnitOptions const& getTranslationUnitOptions() { return *translationUnitOptions; }
+ };
+
+
+ String EmitHLSL(ExtraContext& context)
+ {
+ if (context.getOptions().passThrough != PassThroughMode::None)
+ {
+ return context.sourceText;
+ }
+ else
+ {
+ // TODO(tfoley): probably need a way to customize the emit logic...
+ return emitProgram(
+ context.programSyntax.Ptr(),
+ context.programLayout,
+ CodeGenTarget::HLSL);
+ }
+ }
+
+ String emitGLSLForEntryPoint(ExtraContext& context, EntryPointOption const& entryPoint)
+ {
+ if (context.getOptions().passThrough != PassThroughMode::None)
+ {
+ return context.sourceText;
+ }
+ else
+ {
+ // TODO(tfoley): probably need a way to customize the emit logic...
+ return emitProgram(
+ context.programSyntax.Ptr(),
+ context.programLayout,
+ CodeGenTarget::GLSL);
+ }
+ }
+
+ char const* GetHLSLProfileName(Profile profile)
+ {
+ switch(profile.raw)
+ {
+ #define PROFILE(TAG, NAME, STAGE, VERSION) case Profile::TAG: return #NAME;
+ #include "profile-defs.h"
+
+ default:
+ // TODO: emit an error here!
+ return "unknown";
+ }
+ }
+
+#ifdef _WIN32
+ void* GetD3DCompilerDLL()
+ {
+ // TODO(tfoley): let user specify version of d3dcompiler DLL to use.
+ static HMODULE d3dCompiler = LoadLibraryA("d3dcompiler_47");
+ // TODO(tfoley): handle case where we can't find it gracefully
+ assert(d3dCompiler);
+ return d3dCompiler;
+ }
+
+ List<uint8_t> EmitDXBytecodeForEntryPoint(
+ ExtraContext& context,
+ EntryPointOption const& entryPoint)
+ {
+ static pD3DCompile D3DCompile_ = nullptr;
+ if (!D3DCompile_)
+ {
+ HMODULE d3dCompiler = (HMODULE)GetD3DCompilerDLL();
+ assert(d3dCompiler);
+
+ D3DCompile_ = (pD3DCompile)GetProcAddress(d3dCompiler, "D3DCompile");
+ assert(D3DCompile_);
+ }
+
+ // The HLSL compiler will try to "canonicalize" our input file path,
+ // and we don't want it to do that, because they it won't report
+ // the same locations on error messages that we would.
+ //
+ // To work around that, we prepend a custom `#line` directive.
+
+ String rawHlslCode = EmitHLSL(context);
+
+ StringBuilder hlslCodeBuilder;
+ hlslCodeBuilder << "#line 1 \"";
+ for(auto c : context.sourcePath)
+ {
+ char buffer[] = { c, 0 };
+ switch(c)
+ {
+ default:
+ hlslCodeBuilder << buffer;
+ break;
+
+ case '\\':
+ hlslCodeBuilder << "\\\\";
+ }
+ }
+ hlslCodeBuilder << "\"\n";
+ hlslCodeBuilder << rawHlslCode;
+
+ auto hlslCode = hlslCodeBuilder.ProduceString();
+
+ ID3DBlob* codeBlob;
+ ID3DBlob* diagnosticsBlob;
+ HRESULT hr = D3DCompile_(
+ hlslCode.begin(),
+ hlslCode.Length(),
+ context.sourcePath.begin(),
+ nullptr,
+ nullptr,
+ entryPoint.name.begin(),
+ GetHLSLProfileName(entryPoint.profile),
+ 0,
+ 0,
+ &codeBlob,
+ &diagnosticsBlob);
+ List<uint8_t> data;
+ if (codeBlob)
+ {
+ data.AddRange((uint8_t const*)codeBlob->GetBufferPointer(), (int)codeBlob->GetBufferSize());
+ codeBlob->Release();
+ }
+ if (diagnosticsBlob)
+ {
+ // TODO(tfoley): need a better policy for how we translate diagnostics
+ // back into the Slang world (although we should always try to generate
+ // HLSL that doesn't produce any diagnostics...)
+ String diagnostics = (char const*) diagnosticsBlob->GetBufferPointer();
+ fprintf(stderr, "%s", diagnostics.begin());
+ OutputDebugStringA(diagnostics.begin());
+ diagnosticsBlob->Release();
+ }
+ if (FAILED(hr))
+ {
+ // TODO(tfoley): What to do on failure?
+ }
+ return data;
+ }
+
+ List<uint8_t> EmitDXBytecode(
+ ExtraContext& context)
+ {
+ if(context.getTranslationUnitOptions().entryPoints.Count() != 1)
+ {
+ if(context.getTranslationUnitOptions().entryPoints.Count() == 0)
+ {
+ // TODO(tfoley): need to write diagnostics into this whole thing...
+ fprintf(stderr, "no entry point specified\n");
+ }
+ else
+ {
+ fprintf(stderr, "multiple entry points specified\n");
+ }
+ return List<uint8_t>();
+ }
+
+ return EmitDXBytecodeForEntryPoint(context, context.getTranslationUnitOptions().entryPoints[0]);
+ }
+
+ String EmitDXBytecodeAssemblyForEntryPoint(
+ ExtraContext& context,
+ EntryPointOption const& entryPoint)
+ {
+ static pD3DDisassemble D3DDisassemble_ = nullptr;
+ if (!D3DDisassemble_)
+ {
+ HMODULE d3dCompiler = (HMODULE)GetD3DCompilerDLL();
+ assert(d3dCompiler);
+
+ D3DDisassemble_ = (pD3DDisassemble)GetProcAddress(d3dCompiler, "D3DDisassemble");
+ assert(D3DDisassemble_);
+ }
+
+ List<uint8_t> dxbc = EmitDXBytecodeForEntryPoint(context, entryPoint);
+ if (!dxbc.Count())
+ {
+ return "";
+ }
+
+ ID3DBlob* codeBlob;
+ HRESULT hr = D3DDisassemble_(
+ &dxbc[0],
+ dxbc.Count(),
+ 0,
+ nullptr,
+ &codeBlob);
+
+ String result;
+ if (codeBlob)
+ {
+ result = String((char const*) codeBlob->GetBufferPointer());
+ codeBlob->Release();
+ }
+ if (FAILED(hr))
+ {
+ // TODO(tfoley): need to figure out what to diagnose here...
+ }
+ return result;
+ }
+
+
+ String EmitDXBytecodeAssembly(
+ ExtraContext& context)
+ {
+ if(context.getTranslationUnitOptions().entryPoints.Count() == 0)
+ {
+ // TODO(tfoley): need to write diagnostics into this whole thing...
+ fprintf(stderr, "no entry point specified\n");
+ return "";
+ }
+
+ StringBuilder sb;
+ for (auto entryPoint : context.getTranslationUnitOptions().entryPoints)
+ {
+ sb << EmitDXBytecodeAssemblyForEntryPoint(context, entryPoint);
+ }
+ return sb.ProduceString();
+ }
+
+
+ HMODULE getGLSLCompilerDLL()
+ {
+ // TODO(tfoley): let user specify version of glslang DLL to use.
+ static HMODULE glslCompiler = LoadLibraryA("glslang");
+ // TODO(tfoley): handle case where we can't find it gracefully
+ assert(glslCompiler);
+ return glslCompiler;
+ }
+
+
+ String emitSPIRVAssemblyForEntryPoint(
+ ExtraContext& context,
+ EntryPointOption const& entryPoint)
+ {
+ String rawGLSL = emitGLSLForEntryPoint(context, entryPoint);
+
+ static glslang_CompileFunc glslang_compile = nullptr;
+ if (!glslang_compile)
+ {
+ HMODULE glslCompiler = getGLSLCompilerDLL();
+ assert(glslCompiler);
+
+ glslang_compile = (glslang_CompileFunc)GetProcAddress(glslCompiler, "glslang_compile");
+ assert(glslang_compile);
+ }
+
+ StringBuilder diagnosticBuilder;
+ StringBuilder outputBuilder;
+
+ auto outputFunc = [](char const* text, void* userData)
+ {
+ *(StringBuilder*)userData << text;
+ };
+
+ glslang_CompileRequest request;
+ request.sourcePath = context.sourcePath.begin();
+ request.sourceText = rawGLSL.begin();
+ request.slangStage = (SlangStage) entryPoint.profile.GetStage();
+
+ request.diagnosticFunc = outputFunc;
+ request.diagnosticUserData = &diagnosticBuilder;
+
+ request.outputFunc = outputFunc;
+ request.outputUserData = &outputBuilder;
+
+ int err = glslang_compile(&request);
+
+ String diagnostics = diagnosticBuilder.ProduceString();
+ String output = outputBuilder.ProduceString();
+
+ if(err)
+ {
+ OutputDebugStringA(diagnostics.Buffer());
+ fprintf(stderr, "%s", diagnostics.Buffer());
+ exit(1);
+ }
+
+ return output;
+ }
+#endif
+
+ String emitSPIRVAssembly(
+ ExtraContext& context)
+ {
+ if(context.getTranslationUnitOptions().entryPoints.Count() == 0)
+ {
+ // TODO(tfoley): need to write diagnostics into this whole thing...
+ fprintf(stderr, "no entry point specified\n");
+ return "";
+ }
+
+ StringBuilder sb;
+ for (auto entryPoint : context.getTranslationUnitOptions().entryPoints)
+ {
+ sb << emitSPIRVAssemblyForEntryPoint(context, entryPoint);
+ }
+ return sb.ProduceString();
+ }
+
+ // Do emit logic for a single entry point
+ EntryPointResult emitEntryPoint(ExtraContext& context, EntryPointOption& entryPoint)
+ {
+ EntryPointResult result;
+
+ switch (context.getOptions().Target)
+ {
+ case CodeGenTarget::GLSL:
+ {
+ String code = emitGLSLForEntryPoint(context, entryPoint);
+ result.outputSource = code;
+ }
+ break;
+
+ case CodeGenTarget::DXBytecode:
+ {
+ auto code = EmitDXBytecodeForEntryPoint(context, entryPoint);
+
+ // TODO(tfoley): Need to figure out an appropriate interface
+ // for returning binary code, in addition to source.
+#if 0
+ if (context.compileResult)
+ {
+ StringBuilder sb;
+ sb.Append((char*) code.begin(), code.Count());
+
+ String codeString = sb.ProduceString();
+ result.outputSource = codeString;
+ }
+ else
+#endif
+ {
+ int col = 0;
+ for(auto ii : code)
+ {
+ if(col != 0) fputs(" ", stdout);
+ fprintf(stdout, "%02X", ii);
+ col++;
+ if(col == 8)
+ {
+ fputs("\n", stdout);
+ col = 0;
+ }
+ }
+ if(col != 0)
+ {
+ fputs("\n", stdout);
+ }
+ }
+ return result;
+ }
+ break;
+
+ case CodeGenTarget::DXBytecodeAssembly:
+ {
+ String code = EmitDXBytecodeAssemblyForEntryPoint(context, entryPoint);
+ result.outputSource = code;
+ }
+ break;
+
+ case CodeGenTarget::SPIRVAssembly:
+ {
+ String code = emitSPIRVAssemblyForEntryPoint(context, entryPoint);
+ result.outputSource = code;
+ }
+ break;
+
+ // Note(tfoley): We currently hit this case when compiling the stdlib
+ case CodeGenTarget::Unknown:
+ break;
+
+ default:
+ throw "unimplemented";
+ }
+
+ return result;
+
+
+ }
+
+ TranslationUnitResult emitTranslationUnitEntryPoints(ExtraContext& context)
+ {
+ TranslationUnitResult result;
+
+ for (auto& entryPoint : context.getTranslationUnitOptions().entryPoints)
+ {
+ EntryPointResult entryPointResult = emitEntryPoint(context, entryPoint);
+
+ result.entryPoints.Add(entryPointResult);
+ }
+
+ // The result for the translation unit will just be the concatenation
+ // of the results for each entry point. This doesn't actually make
+ // much sense, but it is good enough for now.
+ StringBuilder sb;
+ for (auto& entryPointResult : result.entryPoints)
+ {
+ sb << entryPointResult.outputSource;
+ }
+
+ result.outputSource = sb.ProduceString();
+
+ return result;
+ }
+
+ // Do emit logic for an entire translation unit, which might
+ // have zero or more entry points
+ TranslationUnitResult emitTranslationUnit(ExtraContext& context)
+ {
+ // Most of our code generation targets will require us
+ // to proceed through one entry point at a time, but
+ // in some cases we can emit an entire translation unit
+ // in one go.
+
+ switch (context.getOptions().Target)
+ {
+ default:
+ // The default behavior is going to loop over all the entry
+ // points, and then collect an aggregate result.
+ return emitTranslationUnitEntryPoints(context);
+
+ case CodeGenTarget::HLSL:
+ // When targetting HLSL, we can emit the entire translation unit
+ // as a single HLSL program, and include all the entry points.
+ {
+
+ String hlsl = EmitHLSL(context);
+
+ TranslationUnitResult result;
+ result.outputSource = hlsl;
+
+ // Because the user might ask for per-entry-point source,
+ // we will just attach the same string as the result for
+ // each entry point.
+ for( auto& entryPoint : context.getTranslationUnitOptions().entryPoints )
+ {
+ (void)entryPoint;
+
+ EntryPointResult entryPointResult;
+ entryPointResult.outputSource = hlsl;
+ result.entryPoints.Add(entryPointResult);
+ }
+
+ return result;
+ }
+ break;
+ }
+ }
+
+ TranslationUnitResult DoNewEmitLogic(ExtraContext& context)
+ {
+ TranslationUnitResult result = emitTranslationUnit(context);
+
+ // As a bit of a hack, we include a mode where we just
+ // print things to standard output, so that we can see them
+ //
+ // TODO(tfoley): Is this path ever needed/used?
+ if( !context.compileResult )
+ {
+ fprintf(stdout, "%s", result.outputSource.Buffer());
+ }
+
+ return result;
+ }
+
+ void DoNewEmitLogic(
+ ExtraContext& context,
+ CollectionOfTranslationUnits* collectionOfTranslationUnits)
+ {
+ switch (context.getOptions().Target)
+ {
+ default:
+ // For most targets, we will do things per-translation-unit
+ for( auto translationUnit : collectionOfTranslationUnits->translationUnits )
+ {
+ ExtraContext innerContext = context;
+ innerContext.translationUnitOptions = &translationUnit.options;
+ innerContext.programSyntax = translationUnit.SyntaxNode;
+ innerContext.sourcePath = "slang"; // don't have this any more!
+ innerContext.sourceText = "";
+
+ TranslationUnitResult translationUnitResult = DoNewEmitLogic(innerContext);
+ context.compileResult->translationUnits.Add(translationUnitResult);
+ }
+ break;
+
+ case CodeGenTarget::ReflectionJSON:
+ {
+ String reflectionJSON = emitReflectionJSON(context.programLayout);
+
+ // HACK(tfoley): just print it out since that is what people probably expect.
+ // TODO: need a way to control where output gets routed across all possible targets.
+ fprintf(stdout, "%s", reflectionJSON.begin());
+ }
+ break;
+ }
+ }
+
+ virtual void Compile(
+ CompileResult& result,
+ CollectionOfTranslationUnits* collectionOfTranslationUnits,
+ const CompileOptions& options) override
+ {
+ RefPtr<SyntaxVisitor> visitor = CreateSemanticsVisitor(result.GetErrorWriter(), options);
+ try
+ {
+ for( auto& translationUnit : collectionOfTranslationUnits->translationUnits )
+ {
+ visitor->setSourceLanguage(translationUnit.options.sourceLanguage);
+
+ translationUnit.SyntaxNode->Accept(visitor.Ptr());
+ }
+ if (result.GetErrorCount() > 0)
+ return;
+
+ // Do binding generation, and then reflection (globally)
+ // before we move on to any code-generation activites.
+ GenerateParameterBindings(collectionOfTranslationUnits);
+
+
+ // HACK(tfoley): for right now I just want to pretty-print an AST
+ // into another language, so the whole compiler back-end is just
+ // getting in the way.
+ //
+ // I'm going to bypass it for now and see what I can do:
+
+ ExtraContext extra;
+ extra.options = &options;
+ extra.programLayout = collectionOfTranslationUnits->layout.Ptr();
+ extra.compileResult = &result;
+ DoNewEmitLogic(extra, collectionOfTranslationUnits);
+ }
+ catch (int)
+ {
+ }
+ catch (...)
+ {
+ throw;
+ }
+ return;
+ }
+
+ ShaderCompilerImpl()
+ {
+ if (compilerInstances == 0)
+ {
+ BasicExpressionType::Init();
+ }
+ compilerInstances++;
+ }
+
+ ~ShaderCompilerImpl()
+ {
+ compilerInstances--;
+ if (compilerInstances == 0)
+ {
+ BasicExpressionType::Finalize();
+ SlangStdLib::Finalize();
+ }
+ }
+
+ virtual TranslationUnitResult PassThrough(
+ String const& sourceText,
+ String const& sourcePath,
+ const CompileOptions & options,
+ TranslationUnitOptions const& translationUnitOptions) override
+ {
+ ExtraContext extra;
+ extra.options = &options;
+ extra.translationUnitOptions = &translationUnitOptions;
+ extra.sourcePath = sourcePath;
+ extra.sourceText = sourceText;
+
+ return DoNewEmitLogic(extra);
+ }
+
+ };
+
+ ShaderCompiler * CreateShaderCompiler()
+ {
+ return new ShaderCompilerImpl();
+ }
+
+ }
+}
diff --git a/source/slang/compiler.h b/source/slang/compiler.h
new file mode 100644
index 000000000..f35d6603c
--- /dev/null
+++ b/source/slang/compiler.h
@@ -0,0 +1,156 @@
+#ifndef RASTER_SHADER_COMPILER_H
+#define RASTER_SHADER_COMPILER_H
+
+#include "../core/basic.h"
+
+#include "compiled-program.h"
+#include "diagnostics.h"
+#include "profile.h"
+#include "syntax.h"
+#include "type-layout.h"
+
+#include "../../slang.h"
+
+namespace Slang
+{
+ namespace Compiler
+ {
+ class ILConstOperand;
+ struct IncludeHandler;
+
+ enum class CompilerMode
+ {
+ ProduceLibrary,
+ ProduceShader,
+ GenerateChoice
+ };
+
+ enum class StageTarget
+ {
+ Unknown,
+ VertexShader,
+ HullShader,
+ DomainShader,
+ GeometryShader,
+ FragmentShader,
+ ComputeShader,
+ };
+
+ enum class CodeGenTarget
+ {
+ Unknown = SLANG_TARGET_UNKNOWN,
+ GLSL = SLANG_GLSL,
+ GLSL_Vulkan = SLANG_GLSL_VULKAN,
+ GLSL_Vulkan_OneDesc = SLANG_GLSL_VULKAN_ONE_DESC,
+ HLSL = SLANG_HLSL,
+ SPIRV = SLANG_SPIRV,
+ SPIRVAssembly = SLANG_SPIRV_ASM,
+ DXBytecode = SLANG_DXBC,
+ DXBytecodeAssembly = SLANG_DXBC_ASM,
+ ReflectionJSON = SLANG_REFLECTION_JSON,
+ };
+
+ // Describes an entry point that we've been requested to compile
+ struct EntryPointOption
+ {
+ String name;
+ Profile profile;
+ };
+
+ enum class PassThroughMode : SlangPassThrough
+ {
+ None = SLANG_PASS_THROUGH_NONE, // don't pass through: use Slang compiler
+ HLSL = SLANG_PASS_THROUGH_FXC, // pass through HLSL to `D3DCompile` API
+// GLSL, // pass through GLSL to `glslang` library
+ };
+
+ // Represents a single source file (either an on-disk file, or a
+ // "virtual" file passed in as a string)
+ class SourceFile : public RefObject
+ {
+ public:
+ // The file path for a real file, or the nominal path for a virtual file
+ String path;
+
+ // The actual contents of the file
+ String content;
+ };
+
+ // Options for a single translation unit being requested by the user
+ class TranslationUnitOptions
+ {
+ public:
+ SourceLanguage sourceLanguage = SourceLanguage::Unknown;
+
+ // All entry points we've been asked to compile for this translation unit
+ List<EntryPointOption> entryPoints;
+
+ // The source file(s) that will be compiled to form this translation unit
+ List<RefPtr<SourceFile> > sourceFiles;
+ };
+
+ class CompileOptions
+ {
+ public:
+ CompilerMode Mode = CompilerMode::ProduceShader;
+ CodeGenTarget Target = CodeGenTarget::Unknown;
+ StageTarget stage = StageTarget::Unknown;
+ EnumerableDictionary<String, String> BackendArguments;
+
+ String SymbolToCompile;
+ String outputName;
+ List<String> TemplateShaderArguments;
+ List<String> SearchDirectories;
+ Dictionary<String, String> PreprocessorDefinitions;
+
+ List<TranslationUnitOptions> translationUnits;
+
+ // the code generation profile we've been asked to use
+ Profile profile;
+
+ // should we just pass the input to another compiler?
+ PassThroughMode passThrough = PassThroughMode::None;
+
+ // Flags supplied through the API
+ SlangCompileFlags flags = 0;
+ };
+
+ // This is the representation of a given translation unit
+ class CompileUnit
+ {
+ public:
+ TranslationUnitOptions options;
+ RefPtr<ProgramSyntaxNode> SyntaxNode;
+ };
+
+ // TODO: pick an appropriate name for this...
+ class CollectionOfTranslationUnits : public RefObject
+ {
+ public:
+ List<CompileUnit> translationUnits;
+
+ // TODO: this is more output-oriented, but maybe okay to have here...
+ RefPtr<ProgramLayout> layout;
+ };
+
+ class ShaderCompiler : public CoreLib::Basic::Object
+ {
+ public:
+ virtual void Compile(
+ CompileResult& result,
+ CollectionOfTranslationUnits* collectionOfTranslationUnits,
+ const CompileOptions& options) = 0;
+
+ virtual TranslationUnitResult PassThrough(
+ String const& sourceText,
+ String const& sourcePath,
+ const CompileOptions & options,
+ TranslationUnitOptions const& translationUnitOptions) = 0;
+
+ };
+
+ ShaderCompiler * CreateShaderCompiler();
+ }
+}
+
+#endif \ No newline at end of file
diff --git a/source/slang/diagnostic-defs.h b/source/slang/diagnostic-defs.h
new file mode 100644
index 000000000..3f690a5da
--- /dev/null
+++ b/source/slang/diagnostic-defs.h
@@ -0,0 +1,338 @@
+//
+
+// The file is meant to be included multiple times, to produce different
+// pieces of declaration/definition code related to diagnostic messages
+//
+// Each diagnostic is declared here with:
+//
+// DIAGNOSTIC(id, severity, name, messageFormat)
+//
+// Where `id` is the unique diagnostic ID, `severity` is the default
+// severity (from the `Severity` enum), `name` is a name used to refer
+// to this diagnostic from code, and `messageFormat` is the default
+// (non-localized) message for the diagnostic, with placeholders
+// for any arguments.
+
+#ifndef DIAGNOSTIC
+#error Need to #define DIAGNOSTIC(...) before including "DiagnosticDefs.h"
+#define DIAGNOSTIC(id, severity, name, messageFormat) /* */
+#endif
+
+//
+// -1 - Notes that decorate another diagnostic.
+//
+
+DIAGNOSTIC(-1, Note, alsoSeePipelineDefinition, "also see pipeline definition");
+DIAGNOSTIC(-1, Note, implicitParameterMatchingFailedBecauseNameNotAccessible, "implicit parameter matching failed because the component of the same name is not accessible from '$0'.\ncheck if you have declared necessary requirements and properly used the 'public' qualifier.")
+DIAGNOSTIC(-1, Note, implicitParameterMatchingFailedBecauseShaderDoesNotDefineComponent, "implicit parameter matching failed because shader '$0' does not define component '$1'.")
+DIAGNOSTIC(-1, Note, implicitParameterMatchingFailedBecauseTypeMismatch, "implicit parameter matching failed because the component of the same name does not match parameter type '$0'.")
+DIAGNOSTIC(-1, Note, noteShaderIsTargetingPipeine, "shader '$0' is targeting pipeline '$1'")
+DIAGNOSTIC(-1, Note, seeDefinitionOf, "see definition of '$0'")
+DIAGNOSTIC(-1, Note, seeInterfaceDefinitionOf, "see interface definition of '$0'")
+DIAGNOSTIC(-1, Note, seeUsingOf, "see using of '$0'")
+DIAGNOSTIC(-1, Note, seeDefinitionOfShader, "see definition of shader '$0'")
+DIAGNOSTIC(-1, Note, seeInclusionOf, "see inclusion of '$0'")
+DIAGNOSTIC(-1, Note, seeModuleBeingUsedIn, "see module '$0' being used in '$1'")
+DIAGNOSTIC(-1, Note, seePipelineRequirementDefinition, "see pipeline requirement definition")
+DIAGNOSTIC(-1, Note, seePotentialDefinitionOfComponent, "see potential definition of component '$0'")
+DIAGNOSTIC(-1, Note, seePreviousDefinition, "see previous definition")
+DIAGNOSTIC(-1, Note, seePreviousDefinitionOf, "see previous definition of '$0'")
+DIAGNOSTIC(-1, Note, seeRequirementDeclaration, "see requirement declaration")
+DIAGNOSTIC(-1, Note, doYouForgetToMakeComponentAccessible, "do you forget to make component '$0' acessible from '$1' (missing public qualifier)?")
+//
+// 0xxxx - Command line and interaction with host platform APIs.
+//
+
+DIAGNOSTIC( 1, Error, cannotOpenFile, "cannot open file '$0'.")
+DIAGNOSTIC( 2, Error, cannotFindFile, "cannot find file '$0'.")
+DIAGNOSTIC( 2, Error, unsupportedCompilerMode, "unsupported compiler mode.")
+DIAGNOSTIC( 4, Error, cannotWriteOutputFile, "cannot write output file '$0'.")
+
+//
+// 1xxxx - Lexical anaylsis
+//
+
+DIAGNOSTIC(10000, Error, illegalCharacterPrint, "illegal character '$0'");
+DIAGNOSTIC(10000, Error, illegalCharacterHex, "illegal character (0x$0)");
+DIAGNOSTIC(10001, Error, illegalCharacterLiteral, "illegal character literal");
+
+DIAGNOSTIC(10002, Warning, octalLiteral, "'0' prefix indicates octal literal")
+DIAGNOSTIC(10003, Error, invalidDigitForBase, "invalid digit for base-$1 literal: '$0'")
+
+DIAGNOSTIC(10004, Error, endOfFileInLiteral, "end of file in literal");
+DIAGNOSTIC(10005, Error, newlineInLiteral, "newline in literal");
+
+//
+// 15xxx - Preprocessing
+//
+
+// 150xx - conditionals
+DIAGNOSTIC(15000, Error, endOfFileInPreprocessorConditional, "end of file encountered during preprocessor conditional")
+DIAGNOSTIC(15001, Error, directiveWithoutIf, "'$0' directive without '#if'")
+DIAGNOSTIC(15002, Error, directiveAfterElse , "'$0' directive without '#if'")
+
+DIAGNOSTIC(-1, Note, seeDirective, "see '$0' directive")
+
+// 151xx - directive parsing
+DIAGNOSTIC(15100, Error, expectedPreprocessorDirectiveName, "expected preprocessor directive name")
+DIAGNOSTIC(15101, Error, unknownPreprocessorDirective, "unknown preprocessor directive '$0'")
+DIAGNOSTIC(15102, Error, expectedTokenInPreprocessorDirective, "expected '$0' in '$1' directive")
+DIAGNOSTIC(15102, Error, expected2TokensInPreprocessorDirective, "expected '$0' or '$1' in '$2' directive")
+DIAGNOSTIC(15103, Error, unexpectedTokensAfterDirective, "unexpected tokens following '$0' directive")
+
+
+// 152xx - preprocessor expressions
+DIAGNOSTIC(15200, Error, expectedTokenInPreprocessorExpression, "expected '$0' in preprocessor expression");
+DIAGNOSTIC(15201, Error, syntaxErrorInPreprocessorExpression, "syntax error in preprocessor expression");
+DIAGNOSTIC(15202, Error, divideByZeroInPreprocessorExpression, "division by zero in preprocessor expression");
+DIAGNOSTIC(15203, Error, expectedTokenInDefinedExpression, "expected '$0' in 'defined' expression");
+DIAGNOSTIC(15204, Warning, directiveExpectsExpression, "'$0' directive requires an expression");
+
+DIAGNOSTIC(-1, Note, seeOpeningToken, "see opening '$0'")
+
+// 153xx - #include
+DIAGNOSTIC(15300, Error, includeFailed, "failed to find include file '$0'")
+DIAGNOSTIC(-1, Error, noIncludeHandlerSpecified, "no `#include` handler was specified")
+
+// 154xx - macro definition
+DIAGNOSTIC(15400, Warning, macroRedefinition, "redefinition of macro '$0'")
+DIAGNOSTIC(15401, Warning, macroNotDefined, "macro '$0' is not defined")
+DIAGNOSTIC(15403, Error, expectedTokenInMacroParameters, "expected '$0' in macro parameters")
+
+// 155xx - macro expansion
+DIAGNOSTIC(15500, Warning, expectedTokenInMacroArguments, "expected '$0' in macro invocation")
+
+// 159xx - user-defined error/warning
+DIAGNOSTIC(15900, Error, userDefinedError, "#error: $0")
+DIAGNOSTIC(15901, Warning, userDefinedWarning, "#warning: $0")
+
+//
+// 2xxxx - Parsing
+//
+
+DIAGNOSTIC(20003, Error, unexpectedToken, "unexpected $0");
+DIAGNOSTIC(20001, Error, unexpectedTokenExpectedTokenType, "unexpected $0, expected $1");
+DIAGNOSTIC(20001, Error, unexpectedTokenExpectedTokenName, "unexpected $0, expected '$1'");
+
+DIAGNOSTIC(0, Error, tokenNameExpectedButEOF, "\"$0\" expected but end of file encountered.");
+DIAGNOSTIC(0, Error, tokenTypeExpectedButEOF, "$0 expected but end of file encountered.");
+DIAGNOSTIC(20001, Error, tokenNameExpected, "\"$0\" expected");
+DIAGNOSTIC(20001, Error, tokenNameExpectedButEOF2, "\"$0\" expected but end of file encountered.");
+DIAGNOSTIC(20001, Error, tokenTypeExpected, "$0 expected");
+DIAGNOSTIC(20001, Error, tokenTypeExpectedButEOF2, "$0 expected but end of file encountered.");
+DIAGNOSTIC(20001, Error, typeNameExpectedBut, "unexpected $0, expected type name");
+DIAGNOSTIC(20001, Error, typeNameExpectedButEOF, "type name expected but end of file encountered.");
+DIAGNOSTIC(20001, Error, unexpectedEOF, " Unexpected end of file.");
+DIAGNOSTIC(20002, Error, syntaxError, "syntax error.");
+DIAGNOSTIC(20004, Error, unexpectedTokenExpectedComponentDefinition, "unexpected token '$0', only component definitions are allowed in a shader scope.")
+DIAGNOSTIC(20008, Error, invalidOperator, "invalid operator '$0'.");
+DIAGNOSTIC(20011, Error, unexpectedColon, "unexpected ':'.")
+
+//
+// 3xxxx - Semantic analysis
+//
+
+DIAGNOSTIC(30001, Error, functionRedefinitionWithArgList, "'$0$1': function redefinition.")
+DIAGNOSTIC(30002, Error, parameterAlreadyDefined, "parameter '$0' already defined.")
+DIAGNOSTIC(30003, Error, breakOutsideLoop, "'break' must appear inside loop constructs.")
+DIAGNOSTIC(30004, Error, continueOutsideLoop, "'continue' must appear inside loop constructs.")
+DIAGNOSTIC(30005, Error, whilePredicateTypeError, "'while': expression must evaluate to int.")
+DIAGNOSTIC(30006, Error, ifPredicateTypeError, "'if': expression must evaluate to int.")
+DIAGNOSTIC(30006, Error, returnNeedsExpression, "'return' should have an expression.")
+DIAGNOSTIC(30007, Error, componentReturnTypeMismatch, "expression type '$0' does not match component's type '$1'")
+DIAGNOSTIC(30007, Error, functionReturnTypeMismatch, "expression type '$0' does not match function's return type '$1'")
+DIAGNOSTIC(30008, Error, variableNameAlreadyDefined, "variable $0 already defined.")
+DIAGNOSTIC(30009, Error, invalidTypeVoid, "invalid type 'void'.")
+DIAGNOSTIC(30010, Error, whilePredicateTypeError2, "'while': expression must evaluate to int.")
+DIAGNOSTIC(30011, Error, assignNonLValue, "left of '=' is not an l-value.")
+DIAGNOSTIC(30012, Error, noApplicationUnaryOperator, "no overload found for operator $0 ($1).")
+DIAGNOSTIC(30012, Error, noOverloadFoundForBinOperatorOnTypes, "no overload found for operator $0 ($1, $2).")
+DIAGNOSTIC(30013, Error, subscriptNonArray, "no subscript operation found for type '$0'")
+DIAGNOSTIC(30014, Error, subscriptIndexNonInteger, "index expression must evaluate to int.")
+DIAGNOSTIC(30015, Error, undefinedIdentifier, "'$0': undefined identifier.")
+DIAGNOSTIC(30015, Error, undefinedIdentifier2, "undefined identifier '$0'.")
+DIAGNOSTIC(30016, Error, parameterCannotBeVoid, "'void' can not be parameter type.")
+DIAGNOSTIC(30017, Error, componentNotAccessibleFromShader, "component '$0' is not accessible from shader '$1'.")
+DIAGNOSTIC(30019, Error, typeMismatch, "expected an expression of type '$0', got '$1'")
+DIAGNOSTIC(30020, Error, importOperatorReturnTypeMismatch, "import operator should return '$1', but the expression has type '$0''. do you forget 'project'?")
+DIAGNOSTIC(30021, Error, noApplicationFunction, "$0: no overload takes arguments ($1)")
+DIAGNOSTIC(30022, Error, invalidTypeCast, "invalid type cast between \"$0\" and \"$1\".")
+DIAGNOSTIC(30023, Error, typeHasNoPublicMemberOfName, "\"$0\" does not have public member \"$1\".");
+DIAGNOSTIC(30025, Error, invalidArraySize, "array size must be larger than zero.")
+DIAGNOSTIC(30026, Error, returnInComponentMustComeLast, "'return' can only appear as the last statement in component definition.")
+DIAGNOSTIC(30027, Error, noMemberOfNameInType, "'$0' is not a member of '$1'.");
+DIAGNOSTIC(30028, Error, forPredicateTypeError, "'for': predicate expression must evaluate to bool.")
+DIAGNOSTIC(30030, Error, projectionOutsideImportOperator, "'project': invalid use outside import operator.")
+DIAGNOSTIC(30031, Error, projectTypeMismatch, "'project': expression must evaluate to record type '$0'.")
+DIAGNOSTIC(30033, Error, invalidTypeForLocalVariable, "cannot declare a local variable of this type.")
+DIAGNOSTIC(30035, Error, componentOverloadTypeMismatch, "'$0': type of overloaded component mismatches previous definition.")
+DIAGNOSTIC(30041, Error, bitOperationNonIntegral, "bit operation: operand must be integral type.")
+DIAGNOSTIC(30047, Error, argumentExpectedLValue, "argument passed to parameter '$0' must be l-value.")
+DIAGNOSTIC(30051, Error, invalidValueForArgument, "invalid value for argument '$0'")
+DIAGNOSTIC(30052, Error, ordinaryFunctionAsModuleArgument, "ordinary functions not allowed as argument to function-typed module parameter.")
+DIAGNOSTIC(30079, Error, selectPrdicateTypeMismatch, "selector must evaluate to bool.");
+DIAGNOSTIC(30080, Error, selectValuesTypeMismatch, "the two value expressions in a select clause must have same type.");
+DIAGNOSTIC(31040, Error, undefinedTypeName, "undefined type name: '$0'.")
+DIAGNOSTIC(32013, Error, circularReferenceNotAllowed, "'$0': circular reference is not allowed.");
+DIAGNOSTIC(32014, Error, shaderDoesProvideRequirement, "shader '$0' does not provide '$1' as required by '$2'.")
+DIAGNOSTIC(32015, Error, argumentNotAvilableInWorld, "argument '$0' is not available in world '$1' as required by '$2'.")
+DIAGNOSTIC(32015, Error, componentNotAvilableInWorld, "component '$0' is not available in world '$1' as required by '$2'.")
+DIAGNOSTIC(32047, Error, firstArgumentToImportNotComponent, "first argument of an import operator call does not resolve to a component.");
+DIAGNOSTIC(32051, Error, componentTypeNotWhatPipelineRequires, "component '$0' has type '$1', but pipeline '$2' requires it to be '$3'.")
+DIAGNOSTIC(32052, Error, shaderDoesNotDefineComponentAsRequiredByPipeline, "shader '$0' does not define '$1' as required by pipeline '$2''.")
+DIAGNOSTIC(33001, Error, worldNameAlreadyDefined, "world '$0' is already defined.")
+DIAGNOSTIC(33002, Error, explicitPipelineSpecificationRequiredForShader, "explicit pipeline specification required for shader '$0' because multiple pipelines are defined in current context.")
+DIAGNOSTIC(33003, Error, cannotDefineComponentsInAPipeline, "cannot define components in a pipeline.")
+DIAGNOSTIC(33004, Error, undefinedWorldName, "undefined world name '$0'.")
+DIAGNOSTIC(33005, Error, abstractWorldAsTargetOfImport, "abstract world cannot appear as target as an import operator.")
+
+// Note(tfoley): This is a duplicate of 33004 above.
+DIAGNOSTIC(33006, Error, undefinedWorldName2, "undefined world name '$0'.")
+
+DIAGNOSTIC(33007, Error, importOperatorCircularity, "import operator '$0' creates a circular dependency between world '$1' and '$2'")
+DIAGNOSTIC(33009, Error, parametersOnlyAllowedInModules, "parameters can only be defined in modules.")
+DIAGNOSTIC(33010, Error, undefinedPipelineName, "pipeline '$0' is undefined.")
+DIAGNOSTIC(33011, Error, shaderCircularity, "shader '$0' involves circular reference.")
+DIAGNOSTIC(33012, Error, worldIsNotDefinedInPipeline, "'$0' is not a defined world in '$1'.")
+DIAGNOSTIC(33013, Error, abstractWorldCannotAppearWithOthers, "abstract world cannot appear with other worlds.")
+DIAGNOSTIC(33014, Error, nonAbstractComponentMustHaveImplementation, "non-abstract component must have an implementation.")
+DIAGNOSTIC(33016, Error, usingInComponentDefinition, "'using': importing not allowed in component definition.")
+DIAGNOSTIC(33018, Error, nameAlreadyDefined, "'$0' is already defined.")
+DIAGNOSTIC(33018, Error, shaderAlreadyDefined, "shader '$0' has already been defined.")
+DIAGNOSTIC(33019, Error, componentMarkedExportMustHaveWorld, "component '$0': definition marked as 'export' must have an explicitly specified world.")
+DIAGNOSTIC(33020, Error, componentIsAlreadyDefined, "'$0' is already defined.")
+DIAGNOSTIC(33020, Error, componentIsAlreadyDefinedInThatWorld, "'$0' is already defined at '$1'.")
+DIAGNOSTIC(33021, Error, inconsistentSignatureForComponent, "'$0': inconsistent signature.")
+DIAGNOSTIC(33022, Error, nameAlreadyDefinedInCurrentScope, "'$0' is already defined in current scope.")
+DIAGNOSTIC(33022, Error, parameterNameConflictsWithExistingDefinition, "'$0': parameter name conflicts with existing definition.")
+DIAGNOSTIC(33023, Error, parameterOfModuleIsUnassigned, "parameter '$0' of module '$1' is unassigned.")
+DIAGNOSTIC(33027, Error, argumentTypeDoesNotMatchParameterType, "argument type ($0) does not match parameter type ($1)")
+DIAGNOSTIC(33028, Error, nameIsNotAParameterOfCallee, "'$0' is not a parameter of '$1'.")
+DIAGNOSTIC(33029, Error, requirementsClashWithPreviousDef, "'$0': requirement clash with previous definition.")
+DIAGNOSTIC(33030, Error, positionArgumentAfterNamed, "positional argument cannot appear after a named argument.")
+DIAGNOSTIC(33032, Error, functionRedefinition, "'$0': function redefinition.")
+DIAGNOSTIC(33034, Error, recordTypeVariableInImportOperator, "cannot declare a record-typed variable in an import operator.")
+DIAGNOSTIC(33037, Error, componetMarkedExportCannotHaveParameters, "component '$0': definition marked as 'export' cannot have parameters.")
+DIAGNOSTIC(33039, Error, componentInInputWorldCantHaveCode, "'$0': no code allowed for component defined in input world.")
+DIAGNOSTIC(33040, Error, requireWithComputation, "'require': cannot define computation on component requirements.")
+DIAGNOSTIC(33042, Error, paramWithComputation, "'param': cannot define computation on parameters.")
+DIAGNOSTIC(33041, Error, pipelineOfModuleIncompatibleWithPipelineOfShader, "pipeline '$0' targeted by module '$1' is incompatible with pipeline '$2' targeted by shader '$3'.")
+DIAGNOSTIC(33070, Error, expectedFunction, "expression preceding parenthesis of apparent call must have function type.")
+DIAGNOSTIC(33071, Error, importOperatorCalledFromAutoPlacedComponent, "cannot call an import operator from an auto-placed component '$0'. try qualify the component with explicit worlds.")
+DIAGNOSTIC(33072, Error, noApplicableImportOperator, "'$0' is an import operator defined in pipeline '$1', but none of the import operator overloads converting to world '$2' matches argument list ($3).")
+DIAGNOSTIC(33073, Error, importOperatorCalledFromMultiWorldComponent, "cannot call an import operator from a multi-world component definition. consider qualify the component with only one explicit world.")
+DIAGNOSTIC(33080, Error, componentTypeDoesNotMatchInterface, "'$0': component type does not match definition in interface '$1'.")
+DIAGNOSTIC(33081, Error, shaderDidNotDefineComponentFunction, "shader '$0' did not define component function $1 as required by interface '$2'.")
+DIAGNOSTIC(33082, Error, shaderDidNotDefineComponent, "shader '$0' did not define component '$1' as required by interface '$2'.")
+DIAGNOSTIC(33083, Error, interfaceImplMustBePublic, "'$0': component fulfilling interface '$1' must be declared as 'public'.")
+DIAGNOSTIC(33084, Error, defaultParamNotAllowedInInterface, "'$0': default parameter value not allowed in interface definition.")
+
+DIAGNOSTIC(33100, Error, componentCantBeComputedAtWorldBecauseDependentNotAvailable, "'$0' cannot be computed at '$1' because the dependent component '$2' is not accessible.")
+DIAGNOSTIC(33101, Warning, worldIsNotAValidChoiceForKey, "'$0' is not a valid choice for '$1'.")
+DIAGNOSTIC(33102, Error, componentDefinitionCircularity, "component definition '$0' involves circular reference.")
+DIAGNOSTIC(34024, Error, componentAlreadyDefinedWhenCompiling, "component named '$0' is already defined when compiling '$1'.")
+DIAGNOSTIC(34025, Error, globalComponentConflictWithPreviousDeclaration, "'$0': global component conflicts with previous declaration.")
+DIAGNOSTIC(34026, Warning, componentIsAlreadyDefinedUseRequire, "'$0': component is already defined when compiling shader '$1'. use 'require' to declare it as a parameter.")
+DIAGNOSTIC(34062, Error, cylicReference, "cyclic reference: $0");
+DIAGNOSTIC(34064, Error, noApplicableImplicitImportOperator, "cannot find import operator to import component '$0' to world '$1' when compiling '$2'.")
+DIAGNOSTIC(34065, Error, resourceTypeMustBeParamOrRequire, "'$0': resource typed component must be declared as 'param' or 'require'.");
+DIAGNOSTIC(34066, Error, cannotDefineComputationOnResourceType, "'$0': cannot define computation on resource typed component.");
+
+DIAGNOSTIC(35001, Error, fragDepthAttributeCanOnlyApplyToOutput, "FragDepth attribute can only apply to an output component.");
+DIAGNOSTIC(35002, Error, fragDepthAttributeCanOnlyApplyToFloatComponent, "FragDepth attribute can only apply to a float component.");
+
+
+DIAGNOSTIC(36001, Error, insufficientTemplateShaderArguments, "instantiating template shader '$0': insufficient arguments.");
+DIAGNOSTIC(36002, Error, tooManyTemplateShaderArguments, "instantiating template shader '$0': too many arguments.");
+DIAGNOSTIC(36003, Error, templateShaderArgumentIsNotDefined, "'$0' provided as template shader argument to '$1' is not a defined module.");
+DIAGNOSTIC(36004, Error, templateShaderArgumentDidNotImplementRequiredInterface, "module '$0' provided as template shader argument to '$1' did not implement required interface '$2'.");
+
+// TODO: need to assign numbers to all these extra diagnostics...
+
+DIAGNOSTIC(39999, Error, expectedIntegerConstantWrongType, "expected integer constant (found: '$0')")
+DIAGNOSTIC(39999, Error, expectedIntegerConstantNotConstant, "expression does not evaluate to a compile-time constant")
+DIAGNOSTIC(39999, Error, expectedIntegerConstantNotLiteral, "could not extract value from integer constant")
+
+DIAGNOSTIC(39999, Error, noApplicableOverloadForNameWithArgs, "no overload for '$0' applicable to arguments of type $1")
+DIAGNOSTIC(39999, Error, noApplicableWithArgs, "no overload applicable to arguments of type $0")
+
+DIAGNOSTIC(39999, Error, ambiguousOverloadForNameWithArgs, "ambiguous call to '$0' operation with arguments of type $1")
+DIAGNOSTIC(39999, Error, ambiguousOverloadWithArgs, "ambiguous call to overloaded operation with arguments of type $0")
+
+DIAGNOSTIC(39999, Note, overloadCandidate, "candidate: $0")
+DIAGNOSTIC(39999, Note, moreOverloadCandidates, "$0 more overload candidates")
+
+DIAGNOSTIC(39999, Error, caseOutsideSwitch, "'case' not allowed outside of a 'switch' statement")
+DIAGNOSTIC(39999, Error, defaultOutsideSwitch, "'default' not allowed outside of a 'switch' statement")
+
+DIAGNOSTIC(39999, Error, expectedAGeneric, "expected a generic when using '<...>' (found: '$0')")
+
+DIAGNOSTIC(39999, Error, genericArgumentInferenceFailed, "could not specialize generic for arguments of type $0")
+DIAGNOSTIC(39999, Note, genericSignatureTried, "see declaration of $0")
+
+DIAGNOSTIC(39999, Error, expectedATraitGot, "expected a trait, got '$0'")
+
+DIAGNOSTIC(39999, Error, ambiguousReference, "amiguous reference to '$0'");
+
+DIAGNOSTIC(39999, Error, declarationDidntDeclareAnything, "declaration does not declare anything");
+
+
+DIAGNOSTIC(39999, Error, expectedPrefixOperator, "function called as prefix operator was not declared `__prefix`")
+DIAGNOSTIC(39999, Error, expectedPostfixOperator, "function called as postfix operator was not declared `__postfix`")
+
+DIAGNOSTIC(39999, Error, notEnoughArguments, "not enough arguments to call (got $0, expected $1)")
+DIAGNOSTIC(39999, Error, tooManyArguments, "too many arguments to call (got $0, expected $1)")
+
+//
+// 4xxxx - IL code generation.
+//
+DIAGNOSTIC(40001, Error, bindingAlreadyOccupiedByComponent, "resource binding location '$0' is already occupied by component '$1'.")
+DIAGNOSTIC(40002, Error, invalidBindingValue, "binding location '$0' is out of valid range.")
+DIAGNOSTIC(40003, Error, bindingExceedsLimit, "binding location '$0' assigned to component '$1' exceeds maximum limit.")
+DIAGNOSTIC(40004, Error, bindingAlreadyOccupiedByModule, "DescriptorSet ID '$0' is already occupied by module instance '$1'.")
+DIAGNOSTIC(40005, Error, topLevelModuleUsedWithoutSpecifyingBinding, "top level module '$0' is being used without specifying binding location. Use [Binding: \"index\"] attribute to provide a binding location.")
+//
+// 5xxxx - Target code generation.
+//
+
+DIAGNOSTIC(50020, Error, unknownStageType, "Unknown stage type '$0'.")
+DIAGNOSTIC(50020, Error, invalidTessCoordType, "TessCoord must have vec2 or vec3 type.")
+DIAGNOSTIC(50020, Error, invalidFragCoordType, "FragCoord must be a vec4.")
+DIAGNOSTIC(50020, Error, invalidInvocationIdType, "InvocationId must have int type.")
+DIAGNOSTIC(50020, Error, invalidThreadIdType, "ThreadId must have int type.")
+DIAGNOSTIC(50020, Error, invalidPrimitiveIdType, "PrimitiveId must have int type.")
+DIAGNOSTIC(50020, Error, invalidPatchVertexCountType, "PatchVertexCount must have int type.")
+DIAGNOSTIC(50022, Error, worldIsNotDefined, "world '$0' is not defined.");
+DIAGNOSTIC(50023, Error, stageShouldProvideWorldAttribute, "'$0' should provide 'World' attribute.");
+DIAGNOSTIC(50040, Error, componentHasInvalidTypeForPositionOutput, "'$0': component used as 'Position' output must be of vec4 type.")
+DIAGNOSTIC(50041, Error, componentNotDefined, "'$0': component not defined.")
+
+DIAGNOSTIC(50052, Error, domainShaderRequiresControlPointCount, "'DomainShader' requires attribute 'ControlPointCount'.");
+DIAGNOSTIC(50052, Error, hullShaderRequiresControlPointCount, "'HullShader' requires attribute 'ControlPointCount'.")
+DIAGNOSTIC(50052, Error, hullShaderRequiresControlPointWorld, "'HullShader' requires attribute 'ControlPointWorld'.");
+DIAGNOSTIC(50052, Error, hullShaderRequiresCornerPointWorld, "'HullShader' requires attribute 'CornerPointWorld'.");
+DIAGNOSTIC(50052, Error, hullShaderRequiresDomain, "'HullShader' requires attribute 'Domain'.");
+DIAGNOSTIC(50052, Error, hullShaderRequiresInputControlPointCount, "'HullShader' requires attribute 'InputControlPointCount'.")
+DIAGNOSTIC(50052, Error, hullShaderRequiresOutputTopology, "'HullShader' requires attribute 'OutputTopology'.")
+DIAGNOSTIC(50052, Error, hullShaderRequiresPartitioning, "'HullShader' requires attribute 'Partitioning'.")
+DIAGNOSTIC(50052, Error, hullShaderRequiresPatchWorld, "'HullShader' requires attribute 'PatchWorld'.");
+DIAGNOSTIC(50052, Error, hullShaderRequiresTessLevelInner, "'HullShader' requires attribute 'TessLevelInner'.")
+DIAGNOSTIC(50052, Error, hullShaderRequiresTessLevelOuter, "'HullShader' requires attribute 'TessLevelOuter'.")
+
+DIAGNOSTIC(50053, Error, invalidTessellationDomian, "'Domain' should be either 'triangles' or 'quads'.");
+DIAGNOSTIC(50053, Error, invalidTessellationOutputTopology, "'OutputTopology' must be one of: 'point', 'line', 'triangle_cw', or 'triangle_ccw'.");
+DIAGNOSTIC(50053, Error, invalidTessellationPartitioning, "'Partitioning' must be one of: 'integer', 'pow2', 'fractional_even', or 'fractional_odd'.")
+DIAGNOSTIC(50053, Error, invalidTessellationDomain, "'Domain' should be either 'triangles' or 'quads'.")
+
+DIAGNOSTIC(50082, Error, importingFromPackedBufferUnsupported, "importing type '$0' from PackedBuffer is not supported by the GLSL backend.")
+DIAGNOSTIC(51090, Error, cannotGenerateCodeForExternComponentType, "cannot generate code for extern component type '$0'.")
+DIAGNOSTIC(51091, Error, typeCannotBePlacedInATexture, "type '$0' cannot be placed in a texture.")
+DIAGNOSTIC(51092, Error, stageDoesntHaveInputWorld, "'$0' doesn't appear to have any input world");
+
+
+// 99999 - Internal compiler errors, and not-yet-classified diagnostics.
+
+DIAGNOSTIC(99999, Internal, internalCompilerError, "internal compiler error")
+DIAGNOSTIC(99999, Internal, unimplemented, "unimplemented feature: $0")
+
+#undef DIAGNOSTIC
diff --git a/source/slang/diagnostics.cpp b/source/slang/diagnostics.cpp
new file mode 100644
index 000000000..d8527466b
--- /dev/null
+++ b/source/slang/diagnostics.cpp
@@ -0,0 +1,204 @@
+// Diagnostics.cpp
+#include "Diagnostics.h"
+
+#include "Syntax.h"
+
+#include <assert.h>
+
+#ifdef _WIN32
+#define WIN32_LEAN_AND_MEAN
+#define NOMINMAX
+#include <Windows.h>
+#undef WIN32_LEAN_AND_MEAN
+#undef NOMINMAX
+#include <d3dcompiler.h>
+#endif
+
+namespace Slang {
+namespace Compiler {
+
+void printDiagnosticArg(StringBuilder& sb, char const* str)
+{
+ sb << str;
+}
+
+void printDiagnosticArg(StringBuilder& sb, int str)
+{
+ sb << str;
+}
+
+void printDiagnosticArg(StringBuilder& sb, CoreLib::Basic::String const& str)
+{
+ sb << str;
+}
+
+void printDiagnosticArg(StringBuilder& sb, Decl* decl)
+{
+ sb << decl->Name.Content;
+}
+
+void printDiagnosticArg(StringBuilder& sb, Type* type)
+{
+ sb << type->DataType->ToString();
+}
+
+void printDiagnosticArg(StringBuilder& sb, ExpressionType* type)
+{
+ sb << type->ToString();
+}
+
+void printDiagnosticArg(StringBuilder& sb, TypeExp const& type)
+{
+ sb << type.type->ToString();
+}
+
+void printDiagnosticArg(StringBuilder& sb, QualType const& type)
+{
+ sb << type.type->ToString();
+}
+
+void printDiagnosticArg(StringBuilder& sb, TokenType tokenType)
+{
+ sb << TokenTypeToString(tokenType);
+}
+
+void printDiagnosticArg(StringBuilder& sb, Token const& token)
+{
+ sb << token.Content;
+}
+
+CodePosition const& getDiagnosticPos(SyntaxNode const* syntax)
+{
+ return syntax->Position;
+}
+
+CodePosition const& getDiagnosticPos(Token const& token)
+{
+ return token.Position;
+}
+
+CodePosition const& getDiagnosticPos(TypeExp const& typeExp)
+{
+ return typeExp.exp->Position;
+}
+
+// Take the format string for a diagnostic message, along with its arguments, and turn it into a
+static void formatDiagnosticMessage(StringBuilder& sb, char const* format, int argCount, DiagnosticArg const* const* args)
+{
+ char const* spanBegin = format;
+ for(;;)
+ {
+ char const* spanEnd = spanBegin;
+ while (int c = *spanEnd)
+ {
+ if (c == '$')
+ break;
+ spanEnd++;
+ }
+
+ sb.Append(spanBegin, int(spanEnd - spanBegin));
+ if (!*spanEnd)
+ return;
+
+ assert(*spanEnd == '$');
+ spanEnd++;
+ int d = *spanEnd++;
+ switch (d)
+ {
+ // A double dollar sign `$$` is used to emit a single `$`
+ case '$':
+ sb.Append('$');
+ break;
+
+ // A single digit means to emit the corresponding argument.
+ // TODO: support more than 10 arguments, and add options
+ // to control formatting, etc.
+ case '0': case '1': case '2': case '3': case '4':
+ case '5': case '6': case '7': case '8': case '9':
+ {
+ int index = d - '0';
+ if (index >= argCount)
+ {
+ // TODO(tfoley): figure out what a good policy will be for "panic" situations like this
+ throw InvalidOperationException("too few arguments for diagnostic message");
+ }
+ else
+ {
+ DiagnosticArg const* arg = args[index];
+ arg->printFunc(sb, arg->data);
+ }
+ }
+ break;
+
+ default:
+ throw InvalidOperationException("invalid diagnostic message format");
+ break;
+ }
+
+ spanBegin = spanEnd;
+ }
+}
+
+static void formatDiagnostic(
+ StringBuilder& sb,
+ Diagnostic const& diagnostic)
+{
+ sb << diagnostic.Position.FileName;
+ sb << "(";
+ sb << diagnostic.Position.Line;
+ sb << "): ";
+ sb << getSeverityName(diagnostic.severity);
+ sb << " ";
+ sb << diagnostic.ErrorID;
+ sb << ": ";
+ sb << diagnostic.Message;
+ sb << "\n";
+}
+
+void DiagnosticSink::diagnoseImpl(CodePosition const& pos, DiagnosticInfo const& info, int argCount, DiagnosticArg const* const* args)
+{
+ StringBuilder sb;
+ formatDiagnosticMessage(sb, info.messageFormat, argCount, args);
+
+ Diagnostic diagnostic;
+ diagnostic.ErrorID = info.id;
+ diagnostic.Message = sb.ProduceString();
+ diagnostic.Position = pos;
+ diagnostic.severity = info.severity;
+
+ if (diagnostic.severity >= Severity::Error)
+ {
+ errorCount++;
+ }
+
+ // Did the client supply a callback for us to use?
+ if( callback )
+ {
+ // If so, pass the error string along to them
+ StringBuilder sb;
+ formatDiagnostic(sb, diagnostic);
+
+ callback(sb.ProduceString().begin(), callbackUserData);
+ }
+ else
+ {
+ // If the user doesn't have a callback, then just
+ // collect our diagnostic messages into a buffer
+ formatDiagnostic(outputBuffer, diagnostic);
+ }
+
+ if (diagnostic.severity >= Severity::Fatal)
+ {
+ // TODO: figure out a better policy for aborting compilation
+ throw InvalidOperationException();
+ }
+}
+
+namespace Diagnostics
+{
+#define DIAGNOSTIC(id, severity, name, messageFormat) const DiagnosticInfo name = { id, Severity::severity, messageFormat };
+#include "diagnostic-defs.h"
+}
+
+
+}} // namespace Slang::Compiler
diff --git a/source/slang/diagnostics.h b/source/slang/diagnostics.h
new file mode 100644
index 000000000..c1559df5d
--- /dev/null
+++ b/source/slang/diagnostics.h
@@ -0,0 +1,218 @@
+#ifndef RASTER_RENDERER_COMPILE_ERROR_H
+#define RASTER_RENDERER_COMPILE_ERROR_H
+
+#include "../core/basic.h"
+
+#include "source-loc.h"
+#include "token.h"
+
+#include "../../slang.h"
+
+namespace Slang
+{
+ namespace Compiler
+ {
+ using namespace CoreLib::Basic;
+
+ enum class Severity
+ {
+ Note,
+ Warning,
+ Error,
+ Fatal,
+ Internal,
+ };
+
+ // TODO(tfoley): move this into a source file...
+ inline const char* getSeverityName(Severity severity)
+ {
+ switch (severity)
+ {
+ case Severity::Note: return "note";
+ case Severity::Warning: return "warning";
+ case Severity::Error: return "error";
+ case Severity::Fatal: return "fatal error";
+ case Severity::Internal: return "internal error";
+ default: return "unknown error";
+ }
+ }
+
+ // A structure to be used in static data describing different
+ // diagnostic messages.
+ struct DiagnosticInfo
+ {
+ int id;
+ Severity severity;
+ char const* messageFormat;
+ };
+
+ class Diagnostic
+ {
+ public:
+ String Message;
+ CodePosition Position;
+ int ErrorID;
+ Severity severity;
+
+ Diagnostic()
+ {
+ ErrorID = -1;
+ }
+ Diagnostic(
+ const String & msg,
+ int id,
+ const CodePosition & pos,
+ Severity severity)
+ : severity(severity)
+ {
+ Message = msg;
+ ErrorID = id;
+ Position = pos;
+ }
+ };
+
+ class Decl;
+ class Type;
+ class ExpressionType;
+ class ILType;
+ class StageAttribute;
+ struct TypeExp;
+ struct QualType;
+
+ void printDiagnosticArg(StringBuilder& sb, char const* str);
+ void printDiagnosticArg(StringBuilder& sb, int val);
+ void printDiagnosticArg(StringBuilder& sb, CoreLib::Basic::String const& str);
+ void printDiagnosticArg(StringBuilder& sb, Decl* decl);
+ void printDiagnosticArg(StringBuilder& sb, Type* type);
+ void printDiagnosticArg(StringBuilder& sb, ExpressionType* type);
+ void printDiagnosticArg(StringBuilder& sb, TypeExp const& type);
+ void printDiagnosticArg(StringBuilder& sb, QualType const& type);
+ void printDiagnosticArg(StringBuilder& sb, TokenType tokenType);
+ void printDiagnosticArg(StringBuilder& sb, Token const& token);
+
+ template<typename T>
+ void printDiagnosticArg(StringBuilder& sb, RefPtr<T> ptr)
+ {
+ printDiagnosticArg(sb, ptr.Ptr());
+ }
+
+ inline CodePosition const& getDiagnosticPos(CodePosition const& pos) { return pos; }
+
+ class SyntaxNode;
+ class ShaderClosure;
+ CodePosition const& getDiagnosticPos(SyntaxNode const* syntax);
+ CodePosition const& getDiagnosticPos(Token const& token);
+ CodePosition const& getDiagnosticPos(TypeExp const& typeExp);
+
+ template<typename T>
+ CodePosition getDiagnosticPos(RefPtr<T> const& ptr)
+ {
+ return getDiagnosticPos(ptr.Ptr());
+ }
+
+ struct DiagnosticArg
+ {
+ void* data;
+ void (*printFunc)(StringBuilder&, void*);
+
+ template<typename T>
+ struct Helper
+ {
+ static void printFunc(StringBuilder& sb, void* data) { printDiagnosticArg(sb, *(T*)data); }
+ };
+
+ template<typename T>
+ DiagnosticArg(T const& arg)
+ : data((void*)&arg)
+ , printFunc(&Helper<T>::printFunc)
+ {}
+ };
+
+ class DiagnosticSink
+ {
+ public:
+ StringBuilder outputBuffer;
+// List<Diagnostic> diagnostics;
+ int errorCount = 0;
+
+ SlangDiagnosticCallback callback = nullptr;
+ void* callbackUserData = nullptr;
+
+/*
+ void Error(int id, const String & msg, const CodePosition & pos)
+ {
+ diagnostics.Add(Diagnostic(msg, id, pos, Severity::Error));
+ errorCount++;
+ }
+
+ void Warning(int id, const String & msg, const CodePosition & pos)
+ {
+ diagnostics.Add(Diagnostic(msg, id, pos, Severity::Warning));
+ }
+*/
+ int GetErrorCount() { return errorCount; }
+
+ void diagnoseDispatch(CodePosition const& pos, DiagnosticInfo const& info)
+ {
+ diagnoseImpl(pos, info, 0, NULL);
+ }
+
+ void diagnoseDispatch(CodePosition const& pos, DiagnosticInfo const& info, DiagnosticArg const& arg0)
+ {
+ DiagnosticArg const* args[] = { &arg0 };
+ diagnoseImpl(pos, info, 1, args);
+ }
+
+ void diagnoseDispatch(CodePosition const& pos, DiagnosticInfo const& info, DiagnosticArg const& arg0, DiagnosticArg const& arg1)
+ {
+ DiagnosticArg const* args[] = { &arg0, &arg1 };
+ diagnoseImpl(pos, info, 2, args);
+ }
+
+ void diagnoseDispatch(CodePosition const& pos, DiagnosticInfo const& info, DiagnosticArg const& arg0, DiagnosticArg const& arg1, DiagnosticArg const& arg2)
+ {
+ DiagnosticArg const* args[] = { &arg0, &arg1, &arg2 };
+ diagnoseImpl(pos, info, 3, args);
+ }
+
+ void diagnoseDispatch(CodePosition const& pos, DiagnosticInfo const& info, DiagnosticArg const& arg0, DiagnosticArg const& arg1, DiagnosticArg const& arg2, DiagnosticArg const& arg3)
+ {
+ DiagnosticArg const* args[] = { &arg0, &arg1, &arg2, &arg3 };
+ diagnoseImpl(pos, info, 4, args);
+ }
+
+ template<typename P, typename... Args>
+ void diagnose(P const& pos, DiagnosticInfo const& info, Args const&... args )
+ {
+ diagnoseDispatch(getDiagnosticPos(pos), info, args...);
+ }
+
+ void diagnoseImpl(CodePosition const& pos, DiagnosticInfo const& info, int argCount, DiagnosticArg const* const* args);
+ };
+
+ namespace Diagnostics
+ {
+#define DIAGNOSTIC(id, severity, name, messageFormat) extern const DiagnosticInfo name;
+#include "diagnostic-defs.h"
+ }
+ }
+}
+
+#ifdef _DEBUG
+#define SLANG_INTERNAL_ERROR(sink, pos) \
+ (sink)->diagnose(Slang::Compiler::CodePosition(__LINE__, 0, 0, __FILE__), Slang::Compiler::Diagnostics::internalCompilerError)
+#define SLANG_UNIMPLEMENTED(sink, pos, what) \
+ (sink)->diagnose(Slang::Compiler::CodePosition(__LINE__, 0, 0, __FILE__), Slang::Compiler::Diagnostics::unimplemented, what)
+
+#define SLANG_UNREACHABLE(msg) do { assert(!"ureachable code:" msg); exit(1); } while(0)
+#else
+#define SLANG_INTERNAL_ERROR(sink, pos) \
+ (sink)->diagnose(pos, Slang::Compiler::Diagnostics::internalCompilerError)
+#define SLANG_UNIMPLEMENTED(sink, pos, what) \
+ (sink)->diagnose(pos, Slang::Compiler::Diagnostics::unimplemented, what)
+
+// TODO: find something that will perform better
+#define SLANG_UNREACHABLE(msg) exit(1)
+#endif
+
+#endif
diff --git a/source/slang/emit.cpp b/source/slang/emit.cpp
new file mode 100644
index 000000000..6b488a68f
--- /dev/null
+++ b/source/slang/emit.cpp
@@ -0,0 +1,2537 @@
+// emit.cpp
+#include "emit.h"
+
+#include "syntax.h"
+#include "type-layout.h"
+
+#include <assert.h>
+
+#ifdef _WIN32
+#include <d3dcompiler.h>
+#pragma warning(disable:4996)
+#endif
+
+namespace Slang { namespace Compiler {
+
+struct EmitContext
+{
+ StringBuilder sb;
+
+ // Current source position for tracking purposes...
+ CodePosition loc;
+
+ // The target language we want to generate code for
+ CodeGenTarget target;
+
+ // A set of words reserved by the target
+ Dictionary<String, String> reservedWords;
+};
+
+//
+
+static void EmitDecl(EmitContext* context, RefPtr<Decl> decl);
+static void EmitDecl(EmitContext* context, RefPtr<DeclBase> declBase);
+static void EmitDeclUsingLayout(EmitContext* context, RefPtr<Decl> decl, RefPtr<VarLayout> layout);
+static void EmitType(EmitContext* context, RefPtr<ExpressionType> type, String const& name);
+static void EmitType(EmitContext* context, RefPtr<ExpressionType> type);
+static void EmitExpr(EmitContext* context, RefPtr<ExpressionSyntaxNode> expr);
+static void EmitStmt(EmitContext* context, RefPtr<StatementSyntaxNode> stmt);
+static void EmitDeclRef(EmitContext* context, DeclRef declRef);
+
+// Low-level emit logic
+
+static void emitRawTextSpan(EmitContext* context, char const* textBegin, char const* textEnd)
+{
+ // TODO(tfoley): Need to make "corelib" not use `int` for pointer-sized things...
+ auto len = int(textEnd - textBegin);
+
+ context->sb.Append(textBegin, len);
+}
+
+static void emitRawText(EmitContext* context, char const* text)
+{
+ emitRawTextSpan(context, text, text + strlen(text));
+}
+
+static void emitTextSpan(EmitContext* context, char const* textBegin, char const* textEnd)
+{
+ // Emit the raw text
+ emitRawTextSpan(context, textBegin, textEnd);
+
+ // Update our logical position
+ // TODO(tfoley): Need to make "corelib" not use `int` for pointer-sized things...
+ auto len = int(textEnd - textBegin);
+ context->loc.Col += len;
+}
+
+static void Emit(EmitContext* context, char const* textBegin, char const* textEnd)
+{
+ char const* spanBegin = textBegin;
+
+ char const* spanEnd = spanBegin;
+ for(;;)
+ {
+ if(spanEnd == textEnd)
+ {
+ // We have a whole range of text waiting to be flushed
+ emitTextSpan(context, spanBegin, spanEnd);
+ return;
+ }
+
+ auto c = *spanEnd++;
+
+ if( c == '\n' )
+ {
+ // At the end of a line, we need to update our tracking
+ // information on code positions
+ emitTextSpan(context, spanBegin, spanEnd);
+ context->loc.Line++;
+ context->loc.Col = 1;
+
+ // Start a new span for emit purposes
+ spanBegin = spanEnd;
+ }
+ }
+}
+
+static void Emit(EmitContext* context, char const* text)
+{
+ Emit(context, text, text + strlen(text));
+}
+
+static void emit(EmitContext* context, String const& text)
+{
+ Emit(context, text.begin(), text.end());
+}
+
+static bool isReservedWord(EmitContext* context, String const& name)
+{
+ return context->reservedWords.TryGetValue(name) != nullptr;
+}
+
+static void emitName(EmitContext* context, String const& inName)
+{
+ String name = inName;
+
+ // By default, we would like to emit a name in the generated
+ // code exactly as it appeared in the soriginal program.
+ // When that isn't possible, we'd like to emit a name as
+ // close to the original as possible (to ensure that existing
+ // debugging tools still work reasonably well).
+ //
+ // One reason why a name might not be allowed as-is is that
+ // it could collide with a reserved word in the target language.
+ // Another reason is that it might not follow a naming convention
+ // imposed by the target (e.g., in GLSL names starting with
+ // `gl_` or containing `__` are reserved).
+ //
+ // Given a name that should not be allowed, we want to
+ // change it to a name that *is* allowed. e.g., by adding
+ // `_` to the end of a reserved word.
+ //
+ // The next problem this creates is that the modified name
+ // could not collide with an existing use of the same
+ // (valid) name.
+ //
+ // For now we are going to solve this problem in a simple
+ // and ad hoc fashion, but longer term we'll want to do
+ // something sytematic.
+
+ if (isReservedWord(context, name))
+ {
+ name = name + "_";
+ }
+
+ emit(context, name);
+}
+
+static void Emit(EmitContext* context, UInt value)
+{
+ char buffer[32];
+ sprintf(buffer, "%llu", (unsigned long long)(value));
+ Emit(context, buffer);
+}
+
+static void Emit(EmitContext* context, int value)
+{
+ char buffer[16];
+ sprintf(buffer, "%d", value);
+ Emit(context, buffer);
+}
+
+static void Emit(EmitContext* context, double value)
+{
+ // TODO(tfoley): need to print things in a way that can round-trip
+ char buffer[128];
+ sprintf(buffer, "%.20ff", value);
+ Emit(context, buffer);
+}
+
+// Expressions
+
+// Determine if an expression should not be emitted when it is the base of
+// a member reference expression.
+static bool IsBaseExpressionImplicit(EmitContext* /*context*/, RefPtr<ExpressionSyntaxNode> expr)
+{
+ // HACK(tfoley): For now, anything with a constant-buffer type should be
+ // left implicit.
+
+ // Look through any dereferencing that took place
+ RefPtr<ExpressionSyntaxNode> e = expr;
+ while (auto derefExpr = e.As<DerefExpr>())
+ {
+ e = derefExpr->base;
+ }
+ // Is the expression referencing a constant buffer?
+ if (auto cbufferType = e->Type->As<ConstantBufferType>())
+ {
+ return true;
+ }
+
+ return false;
+}
+
+enum
+{
+ kPrecedence_None,
+ kPrecedence_Comma,
+
+ kPrecedence_Assign,
+ kPrecedence_AddAssign = kPrecedence_Assign,
+ kPrecedence_SubAssign = kPrecedence_Assign,
+ kPrecedence_MulAssign = kPrecedence_Assign,
+ kPrecedence_DivAssign = kPrecedence_Assign,
+ kPrecedence_ModAssign = kPrecedence_Assign,
+ kPrecedence_LshAssign = kPrecedence_Assign,
+ kPrecedence_RshAssign = kPrecedence_Assign,
+ kPrecedence_OrAssign = kPrecedence_Assign,
+ kPrecedence_AndAssign = kPrecedence_Assign,
+ kPrecedence_XorAssign = kPrecedence_Assign,
+
+ kPrecedence_General = kPrecedence_Assign,
+
+ kPrecedence_Conditional, // "ternary"
+ kPrecedence_Or,
+ kPrecedence_And,
+ kPrecedence_BitOr,
+ kPrecedence_BitXor,
+ kPrecedence_BitAnd,
+
+ kPrecedence_Eql,
+ kPrecedence_Neq = kPrecedence_Eql,
+
+ kPrecedence_Less,
+ kPrecedence_Greater = kPrecedence_Less,
+ kPrecedence_Leq = kPrecedence_Less,
+ kPrecedence_Geq = kPrecedence_Less,
+
+ kPrecedence_Lsh,
+ kPrecedence_Rsh = kPrecedence_Lsh,
+
+ kPrecedence_Add,
+ kPrecedence_Sub = kPrecedence_Add,
+
+ kPrecedence_Mul,
+ kPrecedence_Div = kPrecedence_Mul,
+ kPrecedence_Mod = kPrecedence_Mul,
+
+ kPrecedence_Prefix,
+ kPrecedence_Postfix,
+ kPrecedence_Atomic = kPrecedence_Postfix
+};
+
+static void EmitExprWithPrecedence(EmitContext* context, RefPtr<ExpressionSyntaxNode> expr, int outerPrec);
+
+static void EmitPostfixExpr(EmitContext* context, RefPtr<ExpressionSyntaxNode> expr)
+{
+ EmitExprWithPrecedence(context, expr, kPrecedence_Postfix);
+}
+
+static void EmitExpr(EmitContext* context, RefPtr<ExpressionSyntaxNode> expr)
+{
+ EmitExprWithPrecedence(context, expr, kPrecedence_General);
+}
+
+static bool MaybeEmitParens(EmitContext* context, int outerPrec, int prec)
+{
+ if (prec <= outerPrec)
+ {
+ Emit(context, "(");
+ return true;
+ }
+ return false;
+}
+
+// When we are going to emit an expression in an l-value context,
+// we may need to ignore certain constructs that the type-checker
+// might have introduced, but which interfere with our ability
+// to use it effectively in the target language
+static RefPtr<ExpressionSyntaxNode> prepareLValueExpr(
+ EmitContext* context,
+ RefPtr<ExpressionSyntaxNode> expr)
+{
+ for(;;)
+ {
+ if(auto typeCastExpr = expr.As<TypeCastExpressionSyntaxNode>())
+ {
+ expr = typeCastExpr->Expression;
+ }
+ // TODO: any other cases?
+ else
+ {
+ return expr;
+ }
+ }
+
+}
+
+static void emitInfixExprImpl(
+ EmitContext* context,
+ int outerPrec,
+ int prec,
+ char const* op,
+ RefPtr<InvokeExpressionSyntaxNode> binExpr,
+ bool isAssign)
+{
+ bool needsClose = MaybeEmitParens(context, outerPrec, prec);
+
+ auto left = binExpr->Arguments[0];
+ if(isAssign)
+ {
+ left = prepareLValueExpr(context, left);
+ }
+
+ EmitExprWithPrecedence(context, left, prec);
+ Emit(context, " ");
+ Emit(context, op);
+ Emit(context, " ");
+ EmitExprWithPrecedence(context, binExpr->Arguments[1], prec);
+ if (needsClose)
+ {
+ Emit(context, ")");
+ }
+}
+
+static void EmitBinExpr(EmitContext* context, int outerPrec, int prec, char const* op, RefPtr<InvokeExpressionSyntaxNode> binExpr)
+{
+ emitInfixExprImpl(context, outerPrec, prec, op, binExpr, false);
+}
+
+static void EmitBinAssignExpr(EmitContext* context, int outerPrec, int prec, char const* op, RefPtr<InvokeExpressionSyntaxNode> binExpr)
+{
+ emitInfixExprImpl(context, outerPrec, prec, op, binExpr, true);
+}
+
+static void emitUnaryExprImpl(
+ EmitContext* context,
+ int outerPrec,
+ int prec,
+ char const* preOp,
+ char const* postOp,
+ RefPtr<InvokeExpressionSyntaxNode> expr,
+ bool isAssign)
+{
+ bool needsClose = MaybeEmitParens(context, outerPrec, prec);
+ Emit(context, preOp);
+
+ auto arg = expr->Arguments[0];
+ if(isAssign)
+ {
+ arg = prepareLValueExpr(context, arg);
+ }
+
+ EmitExprWithPrecedence(context, arg, prec);
+ Emit(context, postOp);
+ if (needsClose)
+ {
+ Emit(context, ")");
+ }
+}
+
+static void EmitUnaryExpr(
+ EmitContext* context,
+ int outerPrec,
+ int prec,
+ char const* preOp,
+ char const* postOp,
+ RefPtr<InvokeExpressionSyntaxNode> expr)
+{
+ emitUnaryExprImpl(context, outerPrec, prec, preOp, postOp, expr, false);
+}
+
+static void EmitUnaryAssignExpr(
+ EmitContext* context,
+ int outerPrec,
+ int prec,
+ char const* preOp,
+ char const* postOp,
+ RefPtr<InvokeExpressionSyntaxNode> expr)
+{
+ emitUnaryExprImpl(context, outerPrec, prec, preOp, postOp, expr, true);
+}
+
+static void emitCallExpr(
+ EmitContext* context,
+ RefPtr<InvokeExpressionSyntaxNode> callExpr,
+ int outerPrec)
+{
+ auto funcExpr = callExpr->FunctionExpr;
+ if (auto funcDeclRefExpr = funcExpr.As<DeclRefExpr>())
+ {
+ auto funcDeclRef = funcDeclRefExpr->declRef;
+ auto funcDecl = funcDeclRef.GetDecl();
+ if (auto intrinsicModifier = funcDecl->FindModifier<IntrinsicModifier>())
+ {
+ switch (intrinsicModifier->op)
+ {
+#define CASE(NAME, OP) case IntrinsicOp::NAME: EmitBinExpr(context, outerPrec, kPrecedence_##NAME, #OP, callExpr); return
+ CASE(Mul, *);
+ CASE(Div, / );
+ CASE(Mod, %);
+ CASE(Add, +);
+ CASE(Sub, -);
+ CASE(Lsh, << );
+ CASE(Rsh, >> );
+ CASE(Eql, == );
+ CASE(Neq, != );
+ CASE(Greater, >);
+ CASE(Less, <);
+ CASE(Geq, >= );
+ CASE(Leq, <= );
+ CASE(BitAnd, &);
+ CASE(BitXor, ^);
+ CASE(BitOr, | );
+ CASE(And, &&);
+ CASE(Or, || );
+#undef CASE
+
+#define CASE(NAME, OP) case IntrinsicOp::NAME: EmitBinAssignExpr(context, outerPrec, kPrecedence_##NAME, #OP, callExpr); return
+ CASE(Assign, =);
+ CASE(AddAssign, +=);
+ CASE(SubAssign, -=);
+ CASE(MulAssign, *=);
+ CASE(DivAssign, /=);
+ CASE(ModAssign, %=);
+ CASE(LshAssign, <<=);
+ CASE(RshAssign, >>=);
+ CASE(OrAssign, |=);
+ CASE(AndAssign, &=);
+ CASE(XorAssign, ^=);
+#undef CASE
+
+ case IntrinsicOp::Sequence: EmitBinExpr(context, outerPrec, kPrecedence_Comma, ",", callExpr); return;
+
+#define CASE(NAME, OP) case IntrinsicOp::NAME: EmitUnaryExpr(context, outerPrec, kPrecedence_Prefix, #OP, "", callExpr); return
+ CASE(Neg, -);
+ CASE(Not, !);
+ CASE(BitNot, ~);
+#undef CASE
+
+#define CASE(NAME, OP) case IntrinsicOp::NAME: EmitUnaryAssignExpr(context, outerPrec, kPrecedence_Prefix, #OP, "", callExpr); return
+ CASE(PreInc, ++);
+ CASE(PreDec, --);
+#undef CASE
+
+#define CASE(NAME, OP) case IntrinsicOp::NAME: EmitUnaryAssignExpr(context, outerPrec, kPrecedence_Postfix, "", #OP, callExpr); return
+ CASE(PostInc, ++);
+ CASE(PostDec, --);
+#undef CASE
+
+ case IntrinsicOp::InnerProduct_Vector_Vector:
+ // HLSL allows `mul()` to be used as a synonym for `dot()`,
+ // so we need to translate to `dot` for GLSL
+ if (context->target == CodeGenTarget::GLSL)
+ {
+ Emit(context, "dot(");
+ EmitExpr(context, callExpr->Arguments[0]);
+ Emit(context, ", ");
+ EmitExpr(context, callExpr->Arguments[1]);
+ Emit(context, ")");
+ return;
+ }
+ break;
+
+ case IntrinsicOp::InnerProduct_Matrix_Matrix:
+ case IntrinsicOp::InnerProduct_Matrix_Vector:
+ case IntrinsicOp::InnerProduct_Vector_Matrix:
+ // HLSL exposes these with the `mul()` function, while GLSL uses ordinary
+ // `operator*`.
+ //
+ // The other critical detail here is that the way we handle matrix
+ // conventions requires that the operands to the product be swapped.
+ if (context->target == CodeGenTarget::GLSL)
+ {
+ Emit(context, "((");
+ EmitExpr(context, callExpr->Arguments[1]);
+ Emit(context, ") * (");
+ EmitExpr(context, callExpr->Arguments[0]);
+ Emit(context, "))");
+ return;
+ }
+ break;
+
+ default:
+ break;
+ }
+
+
+ // We might be calling an intrinsic subscript operation,
+ // and should desugar it accordingly
+ if(auto subscriptDeclRef = funcDeclRef.As<SubscriptDeclRef>())
+ {
+ // We expect any subscript operation to be invoked as a member,
+ // so the function expression had better be in the correct form.
+ if(auto memberExpr = funcExpr.As<MemberExpressionSyntaxNode>())
+ {
+
+ Emit(context, "(");
+ EmitExpr(context, memberExpr->BaseExpression);
+ Emit(context, ")[");
+ int argCount = callExpr->Arguments.Count();
+ for (int aa = 0; aa < argCount; ++aa)
+ {
+ if (aa != 0) Emit(context, ", ");
+ EmitExpr(context, callExpr->Arguments[aa]);
+ }
+ Emit(context, "]");
+ return;
+ }
+ }
+ }
+ }
+
+ // Fall through to default handling...
+
+ bool needClose = MaybeEmitParens(context, outerPrec, kPrecedence_Postfix);
+
+ if (auto funcDeclRefExpr = funcExpr.As<DeclRefExpr>())
+ {
+ auto declRef = funcDeclRefExpr->declRef;
+ if (auto ctorDeclRef = declRef.As<ConstructorDeclRef>())
+ {
+ // We really want to emit a reference to the type begin constructed
+ EmitType(context, callExpr->Type);
+ }
+ else
+ {
+ // default case: just emit the decl ref
+ EmitExpr(context, funcExpr);
+ }
+ }
+ else
+ {
+ // default case: just emit the expression
+ EmitPostfixExpr(context, funcExpr);
+ }
+
+ Emit(context, "(");
+ int argCount = callExpr->Arguments.Count();
+ for (int aa = 0; aa < argCount; ++aa)
+ {
+ if (aa != 0) Emit(context, ", ");
+ EmitExpr(context, callExpr->Arguments[aa]);
+ }
+ Emit(context, ")");
+
+ if (needClose)
+ {
+ Emit(context, ")");
+ }
+}
+
+static void EmitExprWithPrecedence(EmitContext* context, RefPtr<ExpressionSyntaxNode> expr, int outerPrec)
+{
+ bool needClose = false;
+ if (auto selectExpr = expr.As<SelectExpressionSyntaxNode>())
+ {
+ needClose = MaybeEmitParens(context, outerPrec, kPrecedence_Conditional);
+
+ EmitExprWithPrecedence(context, selectExpr->Arguments[0], kPrecedence_Conditional);
+ Emit(context, " ? ");
+ EmitExprWithPrecedence(context, selectExpr->Arguments[1], kPrecedence_Conditional);
+ Emit(context, " : ");
+ EmitExprWithPrecedence(context, selectExpr->Arguments[2], kPrecedence_Conditional);
+ }
+ else if (auto callExpr = expr.As<InvokeExpressionSyntaxNode>())
+ {
+ emitCallExpr(context, callExpr, outerPrec);
+ }
+ else if (auto memberExpr = expr.As<MemberExpressionSyntaxNode>())
+ {
+ needClose = MaybeEmitParens(context, outerPrec, kPrecedence_Postfix);
+
+ // TODO(tfoley): figure out a good way to reference
+ // declarations that might be generic and/or might
+ // not be generated as lexically nested declarations...
+
+ // TODO(tfoley): also, probably need to special case
+ // this for places where we are using a built-in...
+
+ auto base = memberExpr->BaseExpression;
+ if (IsBaseExpressionImplicit(context, base))
+ {
+ // don't emit the base expression
+ }
+ else
+ {
+ EmitExprWithPrecedence(context, memberExpr->BaseExpression, kPrecedence_Postfix);
+ Emit(context, ".");
+ }
+
+ emitName(context, memberExpr->declRef.GetName());
+ }
+ else if (auto swizExpr = expr.As<SwizzleExpr>())
+ {
+ needClose = MaybeEmitParens(context, outerPrec, kPrecedence_Postfix);
+
+ EmitExprWithPrecedence(context, swizExpr->base, kPrecedence_Postfix);
+ Emit(context, ".");
+ static const char* kComponentNames[] = { "x", "y", "z", "w" };
+ int elementCount = swizExpr->elementCount;
+ for (int ee = 0; ee < elementCount; ++ee)
+ {
+ Emit(context, kComponentNames[swizExpr->elementIndices[ee]]);
+ }
+ }
+ else if (auto indexExpr = expr.As<IndexExpressionSyntaxNode>())
+ {
+ needClose = MaybeEmitParens(context, outerPrec, kPrecedence_Postfix);
+
+ EmitExprWithPrecedence(context, indexExpr->BaseExpression, kPrecedence_Postfix);
+ Emit(context, "[");
+ EmitExpr(context, indexExpr->IndexExpression);
+ Emit(context, "]");
+ }
+ else if (auto varExpr = expr.As<VarExpressionSyntaxNode>())
+ {
+ needClose = MaybeEmitParens(context, outerPrec, kPrecedence_Atomic);
+
+ // Because of the "rewriter" use case, it is possible that we will
+ // be trying to emit an expression that hasn't been wired up to
+ // any associated declaration. In that case, we will just emit
+ // the variable name.
+ //
+ // TODO: A better long-term solution here is to have a distinct
+ // case for an "unchecked" `NameExpr` that doesn't include
+ // a declaration reference.
+
+ if(varExpr->declRef)
+ {
+ EmitDeclRef(context, varExpr->declRef);
+ }
+ else
+ {
+ emitName(context, varExpr->Variable);
+ }
+ }
+ else if (auto derefExpr = expr.As<DerefExpr>())
+ {
+ // TODO(tfoley): dereference shouldn't always be implicit
+ EmitExprWithPrecedence(context, derefExpr->base, outerPrec);
+ }
+ else if (auto litExpr = expr.As<ConstantExpressionSyntaxNode>())
+ {
+ needClose = MaybeEmitParens(context, outerPrec, kPrecedence_Atomic);
+
+ switch (litExpr->ConstType)
+ {
+ case ConstantExpressionSyntaxNode::ConstantType::Int:
+ Emit(context, litExpr->IntValue);
+ break;
+ case ConstantExpressionSyntaxNode::ConstantType::Float:
+ Emit(context, litExpr->FloatValue);
+ break;
+ case ConstantExpressionSyntaxNode::ConstantType::Bool:
+ Emit(context, litExpr->IntValue ? "true" : "false");
+ break;
+ default:
+ assert(!"unreachable");
+ break;
+ }
+ }
+ else if (auto castExpr = expr.As<TypeCastExpressionSyntaxNode>())
+ {
+ switch(context->target)
+ {
+ case CodeGenTarget::GLSL:
+ // GLSL requires constructor syntax for all conversions
+ EmitType(context, castExpr->Type);
+ Emit(context, "(");
+ EmitExpr(context, castExpr->Expression);
+ Emit(context, ")");
+ break;
+
+ default:
+ // HLSL (and C/C++) prefer cast syntax
+ // (In fact, HLSL doesn't allow constructor syntax for some conversions it allows as a cast)
+ needClose = MaybeEmitParens(context, outerPrec, kPrecedence_Prefix);
+
+ Emit(context, "(");
+ EmitType(context, castExpr->Type);
+ Emit(context, ")(");
+ EmitExpr(context, castExpr->Expression);
+ Emit(context, ")");
+ break;
+ }
+
+ }
+ else if(auto initExpr = expr.As<InitializerListExpr>())
+ {
+ Emit(context, "{ ");
+ for(auto& arg : initExpr->args)
+ {
+ EmitExpr(context, arg);
+ Emit(context, ", ");
+ }
+ Emit(context, "}");
+ }
+ else
+ {
+ throw "unimplemented";
+ }
+ if (needClose)
+ {
+ Emit(context, ")");
+ }
+}
+
+// Types
+
+void Emit(EmitContext* context, RefPtr<IntVal> val)
+{
+ if(auto constantIntVal = val.As<ConstantIntVal>())
+ {
+ Emit(context, constantIntVal->value);
+ }
+ else if(auto varRefVal = val.As<GenericParamIntVal>())
+ {
+ EmitDeclRef(context, varRefVal->declRef);
+ }
+ else
+ {
+ assert(!"unimplemented");
+ }
+}
+
+// represents a declarator for use in emitting types
+struct EDeclarator
+{
+ enum class Flavor
+ {
+ Name,
+ Array,
+ UnsizedArray,
+ };
+ Flavor flavor;
+ EDeclarator* next = nullptr;
+
+ // Used for `Flavor::Name`
+ String name;
+
+ // Used for `Flavor::Array`
+ IntVal* elementCount;
+};
+
+static void EmitDeclarator(EmitContext* context, EDeclarator* declarator)
+{
+ if (!declarator) return;
+
+ Emit(context, " ");
+
+ switch (declarator->flavor)
+ {
+ case EDeclarator::Flavor::Name:
+ emitName(context, declarator->name);
+ break;
+
+ case EDeclarator::Flavor::Array:
+ EmitDeclarator(context, declarator->next);
+ Emit(context, "[");
+ if(auto elementCount = declarator->elementCount)
+ {
+ Emit(context, elementCount);
+ }
+ Emit(context, "]");
+ break;
+
+ case EDeclarator::Flavor::UnsizedArray:
+ EmitDeclarator(context, declarator->next);
+ Emit(context, "[]");
+ break;
+
+ default:
+ assert(!"unreachable");
+ break;
+ }
+}
+
+static void emitGLSLTypePrefix(
+ EmitContext* context,
+ RefPtr<ExpressionType> type)
+{
+ if(auto basicElementType = type->As<BasicExpressionType>())
+ {
+ switch (basicElementType->BaseType)
+ {
+ case BaseType::Float:
+ // no prefix
+ break;
+
+ case BaseType::Int: Emit(context, "i"); break;
+ case BaseType::UInt: Emit(context, "u"); break;
+ case BaseType::Bool: Emit(context, "b"); break;
+ default:
+ assert(!"unreachable");
+ break;
+ }
+ }
+ else if(auto vectorType = type->As<VectorExpressionType>())
+ {
+ emitGLSLTypePrefix(context, vectorType->elementType);
+ }
+ else if(auto matrixType = type->As<MatrixExpressionType>())
+ {
+ emitGLSLTypePrefix(context, matrixType->getElementType());
+ }
+ else
+ {
+ assert(!"unreachable");
+ }
+}
+
+static void emitHLSLTextureType(
+ EmitContext* context,
+ RefPtr<TextureTypeBase> texType)
+{
+ switch(texType->getAccess())
+ {
+ case SLANG_RESOURCE_ACCESS_READ:
+ break;
+
+ case SLANG_RESOURCE_ACCESS_READ_WRITE:
+ Emit(context, "RW");
+ break;
+
+ case SLANG_RESOURCE_ACCESS_RASTER_ORDERED:
+ Emit(context, "RasterizerOrdered");
+ break;
+
+ case SLANG_RESOURCE_ACCESS_APPEND:
+ Emit(context, "Append");
+ break;
+
+ case SLANG_RESOURCE_ACCESS_CONSUME:
+ Emit(context, "Consume");
+ break;
+
+ default:
+ assert(!"unreachable");
+ break;
+ }
+
+ switch (texType->GetBaseShape())
+ {
+ case TextureType::Shape1D: Emit(context, "Texture1D"); break;
+ case TextureType::Shape2D: Emit(context, "Texture2D"); break;
+ case TextureType::Shape3D: Emit(context, "Texture3D"); break;
+ case TextureType::ShapeCube: Emit(context, "TextureCube"); break;
+ default:
+ assert(!"unreachable");
+ break;
+ }
+
+ if (texType->isMultisample())
+ {
+ Emit(context, "MS");
+ }
+ if (texType->isArray())
+ {
+ Emit(context, "Array");
+ }
+ Emit(context, "<");
+ EmitType(context, texType->elementType);
+ Emit(context, ">");
+}
+
+static void emitGLSLTextureOrTextureSamplerType(
+ EmitContext* context,
+ RefPtr<TextureTypeBase> type,
+ char const* baseName)
+{
+ emitGLSLTypePrefix(context, type->elementType);
+
+ Emit(context, baseName);
+ switch (type->GetBaseShape())
+ {
+ case TextureType::Shape1D: Emit(context, "1D"); break;
+ case TextureType::Shape2D: Emit(context, "2D"); break;
+ case TextureType::Shape3D: Emit(context, "3D"); break;
+ case TextureType::ShapeCube: Emit(context, "Cube"); break;
+ default:
+ assert(!"unreachable");
+ break;
+ }
+
+ if (type->isMultisample())
+ {
+ Emit(context, "MS");
+ }
+ if (type->isArray())
+ {
+ Emit(context, "Array");
+ }
+}
+
+static void emitGLSLTextureType(
+ EmitContext* context,
+ RefPtr<TextureType> texType)
+{
+ emitGLSLTextureOrTextureSamplerType(context, texType, "texture");
+}
+
+static void emitGLSLTextureSamplerType(
+ EmitContext* context,
+ RefPtr<TextureSamplerType> type)
+{
+ emitGLSLTextureOrTextureSamplerType(context, type, "sampler");
+}
+
+static void emitGLSLImageType(
+ EmitContext* context,
+ RefPtr<GLSLImageType> type)
+{
+ emitGLSLTextureOrTextureSamplerType(context, type, "image");
+}
+
+static void emitTextureType(
+ EmitContext* context,
+ RefPtr<TextureType> texType)
+{
+ switch(context->target)
+ {
+ case CodeGenTarget::HLSL:
+ emitHLSLTextureType(context, texType);
+ break;
+
+ case CodeGenTarget::GLSL:
+ emitGLSLTextureType(context, texType);
+ break;
+
+ default:
+ assert(!"unreachable");
+ break;
+ }
+}
+
+static void emitTextureSamplerType(
+ EmitContext* context,
+ RefPtr<TextureSamplerType> type)
+{
+ switch(context->target)
+ {
+ case CodeGenTarget::GLSL:
+ emitGLSLTextureSamplerType(context, type);
+ break;
+
+ default:
+ assert(!"unreachable");
+ break;
+ }
+}
+
+static void emitImageType(
+ EmitContext* context,
+ RefPtr<GLSLImageType> type)
+{
+ switch(context->target)
+ {
+ case CodeGenTarget::HLSL:
+ emitHLSLTextureType(context, type);
+ break;
+
+ case CodeGenTarget::GLSL:
+ emitGLSLImageType(context, type);
+ break;
+
+ default:
+ assert(!"unreachable");
+ break;
+ }
+}
+
+static void EmitType(EmitContext* context, RefPtr<ExpressionType> type, EDeclarator* declarator)
+{
+ if (auto basicType = type->As<BasicExpressionType>())
+ {
+ switch (basicType->BaseType)
+ {
+ case BaseType::Void: Emit(context, "void"); break;
+ case BaseType::Int: Emit(context, "int"); break;
+ case BaseType::Float: Emit(context, "float"); break;
+ case BaseType::UInt: Emit(context, "uint"); break;
+ case BaseType::Bool: Emit(context, "bool"); break;
+ default:
+ assert(!"unreachable");
+ break;
+ }
+
+ EmitDeclarator(context, declarator);
+ return;
+ }
+ else if (auto vecType = type->As<VectorExpressionType>())
+ {
+ switch(context->target)
+ {
+ case CodeGenTarget::GLSL:
+ case CodeGenTarget::GLSL_Vulkan:
+ case CodeGenTarget::GLSL_Vulkan_OneDesc:
+ {
+ emitGLSLTypePrefix(context, vecType->elementType);
+ Emit(context, "vec");
+ Emit(context, vecType->elementCount);
+ }
+ break;
+
+ case CodeGenTarget::HLSL:
+ // TODO(tfoley): should really emit these with sugar
+ Emit(context, "vector<");
+ EmitType(context, vecType->elementType);
+ Emit(context, ",");
+ Emit(context, vecType->elementCount);
+ Emit(context, ">");
+ break;
+
+ default:
+ assert(!"unreachable");
+ break;
+ }
+
+ Emit(context, " ");
+ EmitDeclarator(context, declarator);
+ return;
+ }
+ else if (auto matType = type->As<MatrixExpressionType>())
+ {
+ switch(context->target)
+ {
+ case CodeGenTarget::GLSL:
+ case CodeGenTarget::GLSL_Vulkan:
+ case CodeGenTarget::GLSL_Vulkan_OneDesc:
+ {
+ emitGLSLTypePrefix(context, matType->getElementType());
+ Emit(context, "mat");
+ Emit(context, matType->getRowCount());
+ // TODO(tfoley): only emit the next bit
+ // for non-square matrix
+ Emit(context, "x");
+ Emit(context, matType->getColumnCount());
+ }
+ break;
+
+ case CodeGenTarget::HLSL:
+ // TODO(tfoley): should really emit these with sugar
+ Emit(context, "matrix<");
+ EmitType(context, matType->getElementType());
+ Emit(context, ",");
+ Emit(context, matType->getRowCount());
+ Emit(context, ",");
+ Emit(context, matType->getColumnCount());
+ Emit(context, "> ");
+ break;
+
+ default:
+ assert(!"unreachable");
+ break;
+ }
+
+ Emit(context, " ");
+ EmitDeclarator(context, declarator);
+ return;
+ }
+ else if (auto texType = type->As<TextureType>())
+ {
+ emitTextureType(context, texType);
+ Emit(context, " ");
+ EmitDeclarator(context, declarator);
+ return;
+ }
+ else if (auto textureSamplerType = type->As<TextureSamplerType>())
+ {
+ emitTextureSamplerType(context, textureSamplerType);
+ Emit(context, " ");
+ EmitDeclarator(context, declarator);
+ return;
+ }
+ else if (auto imageType = type->As<GLSLImageType>())
+ {
+ emitImageType(context, imageType);
+ Emit(context, " ");
+ EmitDeclarator(context, declarator);
+ return;
+ }
+ else if (auto samplerStateType = type->As<SamplerStateType>())
+ {
+ switch(context->target)
+ {
+ case CodeGenTarget::HLSL:
+ default:
+ switch (samplerStateType->flavor)
+ {
+ case SamplerStateType::Flavor::SamplerState: Emit(context, "SamplerState"); break;
+ case SamplerStateType::Flavor::SamplerComparisonState: Emit(context, "SamplerComparisonState"); break;
+ default:
+ assert(!"unreachable");
+ break;
+ }
+ break;
+
+ case CodeGenTarget::GLSL:
+ Emit(context, "sampler");
+ break;
+ }
+
+
+ EmitDeclarator(context, declarator);
+ return;
+ }
+ else if (auto declRefType = type->As<DeclRefType>())
+ {
+ EmitDeclRef(context, declRefType->declRef);
+
+ EmitDeclarator(context, declarator);
+ return;
+ }
+ else if( auto arrayType = type->As<ArrayExpressionType>() )
+ {
+ EDeclarator arrayDeclarator;
+ arrayDeclarator.next = declarator;
+
+ if(arrayType->ArrayLength)
+ {
+ arrayDeclarator.flavor = EDeclarator::Flavor::Array;
+ arrayDeclarator.elementCount = arrayType->ArrayLength.Ptr();
+ }
+ else
+ {
+ arrayDeclarator.flavor = EDeclarator::Flavor::UnsizedArray;
+ }
+
+
+ EmitType(context, arrayType->BaseType, &arrayDeclarator);
+ return;
+ }
+
+ throw "unimplemented";
+}
+
+static void EmitType(EmitContext* context, RefPtr<ExpressionType> type, String const& name)
+{
+ EDeclarator nameDeclarator;
+ nameDeclarator.flavor = EDeclarator::Flavor::Name;
+ nameDeclarator.name = name;
+ EmitType(context, type, &nameDeclarator);
+}
+
+static void EmitType(EmitContext* context, RefPtr<ExpressionType> type)
+{
+ EmitType(context, type, nullptr);
+}
+
+// Statements
+
+// Emit a statement as a `{}`-enclosed block statement, but avoid adding redundant
+// curly braces if the statement is itself a block statement.
+static void EmitBlockStmt(EmitContext* context, RefPtr<StatementSyntaxNode> stmt)
+{
+ // TODO(tfoley): support indenting
+ Emit(context, "{\n");
+ if( auto blockStmt = stmt.As<BlockStatementSyntaxNode>() )
+ {
+ for (auto s : blockStmt->Statements)
+ {
+ EmitStmt(context, s);
+ }
+ }
+ else
+ {
+ EmitStmt(context, stmt);
+ }
+ Emit(context, "}\n");
+}
+
+static void EmitLoopAttributes(EmitContext* context, RefPtr<StatementSyntaxNode> decl)
+{
+ // TODO(tfoley): There really ought to be a semantic checking step for attributes,
+ // that turns abstract syntax into a concrete hierarchy of attribute types (e.g.,
+ // a specific `LoopModifier` or `UnrollModifier`).
+
+ for(auto attr : decl->GetModifiersOfType<HLSLUncheckedAttribute>())
+ {
+ if(attr->nameToken.Content == "loop")
+ {
+ Emit(context, "[loop]");
+ }
+ else if(attr->nameToken.Content == "unroll")
+ {
+ Emit(context, "[unroll]");
+ }
+ }
+}
+
+static void advanceToSourceLocation(
+ EmitContext* context,
+ CodePosition const& sourceLocation)
+{
+ // If we are currently emitting code at a source location with
+ // a differnet file or line, *or* if the source location is
+ // somehow later on the line than what we want to emit,
+ // then we need to emit a new `#line` directive.
+ if(sourceLocation.FileName != context->loc.FileName
+ || sourceLocation.Line != context->loc.Line
+ || sourceLocation.Col < context->loc.Col)
+ {
+ emitRawText(context, "\n#line ");
+
+ char buffer[16];
+ sprintf(buffer, "%d", sourceLocation.Line);
+ emitRawText(context, buffer);
+
+ emitRawText(context, "\"");
+ for(auto c : sourceLocation.FileName)
+ {
+ char charBuffer[] = { c, 0 };
+ switch(c)
+ {
+ default:
+ emitRawText(context, charBuffer);
+ break;
+
+ // TODO: should probably canonicalize paths to not use backslash somewhere else
+ // in the compilation pipeline...
+ case '\\':
+ emitRawText(context, "/");
+ break;
+ }
+ }
+ emitRawText(context, "\"\n");
+
+ context->loc.FileName = sourceLocation.FileName;
+ context->loc.Line = sourceLocation.Line;
+ context->loc.Col = 1;
+ }
+
+ // Now indent up to the appropriate column, so that error messages
+ // that reference columns will be correct.
+ //
+ // TODO: This logic does not take into account whether indentation
+ // came in as spaces or tabs, so there is necessarily going to be
+ // coupling between how the downstream compiler counts columns,
+ // and how we do.
+ if(sourceLocation.Col > context->loc.Col)
+ {
+ int delta = sourceLocation.Col - context->loc.Col;
+ for( int ii = 0; ii < delta; ++ii )
+ {
+ emitRawText(context, " ");
+ }
+ context->loc.Col = sourceLocation.Col;
+ }
+}
+
+static void emitTokenWithLocation(EmitContext* context, Token const& token)
+{
+ if( token.Position.FileName.Length() != 0 )
+ {
+ advanceToSourceLocation(context, token.Position);
+ }
+ else
+ {
+ // If we don't have the original position info, we need to play
+ // it safe and emit whitespace to line things up nicely
+
+ if(token.flags & TokenFlag::AtStartOfLine)
+ Emit(context, "\n");
+ // TODO(tfoley): macro expansion can currently lead to whitespace getting dropped,
+ // so we will just insert it aggressively, to play it safe.
+ else // if(token.flags & TokenFlag::AfterWhitespace)
+ Emit(context, " ");
+ }
+
+ // Emit the raw textual content of the token
+ emit(context, token.Content);
+}
+
+static void EmitUnparsedStmt(EmitContext* context, RefPtr<UnparsedStmt> stmt)
+{
+ // TODO: actually emit the tokens that made up the statement...
+ Emit(context, "{\n");
+ for( auto& token : stmt->tokens )
+ {
+ emitTokenWithLocation(context, token);
+ }
+ Emit(context, "}\n");
+}
+
+static void EmitStmt(EmitContext* context, RefPtr<StatementSyntaxNode> stmt)
+{
+ if (auto blockStmt = stmt.As<BlockStatementSyntaxNode>())
+ {
+ EmitBlockStmt(context, blockStmt);
+ return;
+ }
+ else if( auto unparsedStmt = stmt.As<UnparsedStmt>() )
+ {
+ EmitUnparsedStmt(context, unparsedStmt);
+ return;
+ }
+ else if (auto exprStmt = stmt.As<ExpressionStatementSyntaxNode>())
+ {
+ EmitExpr(context, exprStmt->Expression);
+ Emit(context, ";\n");
+ return;
+ }
+ else if (auto returnStmt = stmt.As<ReturnStatementSyntaxNode>())
+ {
+ Emit(context, "return");
+ if (auto expr = returnStmt->Expression)
+ {
+ Emit(context, " ");
+ EmitExpr(context, expr);
+ }
+ Emit(context, ";\n");
+ return;
+ }
+ else if (auto declStmt = stmt.As<VarDeclrStatementSyntaxNode>())
+ {
+ EmitDecl(context, declStmt->decl);
+ return;
+ }
+ else if (auto ifStmt = stmt.As<IfStatementSyntaxNode>())
+ {
+ Emit(context, "if(");
+ EmitExpr(context, ifStmt->Predicate);
+ Emit(context, ")\n");
+ EmitBlockStmt(context, ifStmt->PositiveStatement);
+ if(auto elseStmt = ifStmt->NegativeStatement)
+ {
+ Emit(context, "\nelse\n");
+ EmitBlockStmt(context, elseStmt);
+ }
+ return;
+ }
+ else if (auto forStmt = stmt.As<ForStatementSyntaxNode>())
+ {
+ EmitLoopAttributes(context, forStmt);
+
+ Emit(context, "for(");
+ if (auto initStmt = forStmt->InitialStatement)
+ {
+ EmitStmt(context, initStmt);
+ }
+ else
+ {
+ Emit(context, ";");
+ }
+ if (auto testExp = forStmt->PredicateExpression)
+ {
+ EmitExpr(context, testExp);
+ }
+ Emit(context, ";");
+ if (auto incrExpr = forStmt->SideEffectExpression)
+ {
+ EmitExpr(context, incrExpr);
+ }
+ Emit(context, ")\n");
+ EmitBlockStmt(context, forStmt->Statement);
+ return;
+ }
+ else if (auto discardStmt = stmt.As<DiscardStatementSyntaxNode>())
+ {
+ Emit(context, "discard;\n");
+ return;
+ }
+ else if (auto emptyStmt = stmt.As<EmptyStatementSyntaxNode>())
+ {
+ return;
+ }
+ else if (auto switchStmt = stmt.As<SwitchStmt>())
+ {
+ Emit(context, "switch(");
+ EmitExpr(context, switchStmt->condition);
+ Emit(context, ")\n");
+ EmitBlockStmt(context, switchStmt->body);
+ return;
+ }
+ else if (auto caseStmt = stmt.As<CaseStmt>())
+ {
+ Emit(context, "case ");
+ EmitExpr(context, caseStmt->expr);
+ Emit(context, ":\n");
+ return;
+ }
+ else if (auto defaultStmt = stmt.As<DefaultStmt>())
+ {
+ Emit(context, "default:{}\n");
+ return;
+ }
+ else if (auto breakStmt = stmt.As<BreakStatementSyntaxNode>())
+ {
+ Emit(context, "break;\n");
+ return;
+ }
+ else if (auto continueStmt = stmt.As<ContinueStatementSyntaxNode>())
+ {
+ Emit(context, "continue;\n");
+ return;
+ }
+
+ throw "unimplemented";
+
+}
+
+// Declaration References
+
+static void EmitVal(EmitContext* context, RefPtr<Val> val)
+{
+ if (auto type = val.As<ExpressionType>())
+ {
+ EmitType(context, type);
+ }
+ else if (auto intVal = val.As<IntVal>())
+ {
+ Emit(context, intVal);
+ }
+ else
+ {
+ // Note(tfoley): ignore unhandled cases for semantics for now...
+// assert(!"unimplemented");
+ }
+}
+
+static void EmitDeclRef(EmitContext* context, DeclRef declRef)
+{
+ // TODO: need to qualify a declaration name based on parent scopes/declarations
+
+ // Emit the name for the declaration itself
+ emitName(context, declRef.GetName());
+
+ // If the declaration is nested directly in a generic, then
+ // we need to output the generic arguments here
+ auto parentDeclRef = declRef.GetParent();
+ if (auto genericDeclRef = parentDeclRef.As<GenericDeclRef>())
+ {
+ // Only do this for declarations of appropriate flavors
+ if(auto funcDeclRef = declRef.As<FuncDeclBaseRef>())
+ {
+ // Don't emit generic arguments for functions, because HLSL doesn't allow them
+ return;
+ }
+
+ Substitutions* subst = declRef.substitutions.Ptr();
+ Emit(context, "<");
+ int argCount = subst->args.Count();
+ for (int aa = 0; aa < argCount; ++aa)
+ {
+ if (aa != 0) Emit(context, ",");
+ EmitVal(context, subst->args[aa]);
+ }
+ Emit(context, ">");
+ }
+
+}
+
+// Declarations
+
+// Emit any modifiers that should go in front of a declaration
+static void EmitModifiers(EmitContext* context, RefPtr<Decl> decl)
+{
+ // Emit any GLSL `layout` modifiers first
+ bool anyLayout = false;
+ for( auto mod : decl->GetModifiersOfType<GLSLUnparsedLayoutModifier>())
+ {
+ if(!anyLayout)
+ {
+ Emit(context, "layout(");
+ anyLayout = true;
+ }
+ else
+ {
+ Emit(context, ", ");
+ }
+
+ emit(context, mod->nameToken.Content);
+ if(mod->valToken.Type != TokenType::Unknown)
+ {
+ Emit(context, " = ");
+ emit(context, mod->valToken.Content);
+ }
+ }
+ if(anyLayout)
+ {
+ Emit(context, ")\n");
+ }
+
+ for (auto mod = decl->modifiers.first; mod; mod = mod->next)
+ {
+ if (0) {}
+
+ #define CASE(TYPE, KEYWORD) \
+ else if(auto mod_##TYPE = mod.As<TYPE>()) Emit(context, #KEYWORD " ")
+
+ CASE(RowMajorLayoutModifier, row_major);
+ CASE(ColumnMajorLayoutModifier, column_major);
+ CASE(HLSLNoInterpolationModifier, nointerpolation);
+ CASE(HLSLPreciseModifier, precise);
+ CASE(HLSLEffectSharedModifier, shared);
+ CASE(HLSLGroupSharedModifier, groupshared);
+ CASE(HLSLStaticModifier, static);
+ CASE(HLSLUniformModifier, uniform);
+ CASE(HLSLVolatileModifier, volatile);
+
+ CASE(InOutModifier, inout);
+ CASE(InModifier, in);
+ CASE(OutModifier, out);
+
+ CASE(HLSLPointModifier, point);
+ CASE(HLSLLineModifier, line);
+ CASE(HLSLTriangleModifier, triangle);
+ CASE(HLSLLineAdjModifier, lineadj);
+ CASE(HLSLTriangleAdjModifier, triangleadj);
+
+ CASE(HLSLLinearModifier, linear);
+ CASE(HLSLSampleModifier, sample);
+ CASE(HLSLCentroidModifier, centroid);
+
+ CASE(ConstModifier, const);
+
+ #undef CASE
+
+ // TODO: eventually we should be checked these modifiers, but for
+ // now we can emit them unchecked, I guess
+ else if (auto uncheckedAttr = mod.As<HLSLAttribute>())
+ {
+ Emit(context, "[");
+ emit(context, uncheckedAttr->nameToken.Content);
+ auto& args = uncheckedAttr->args;
+ auto argCount = args.Count();
+ if (argCount != 0)
+ {
+ Emit(context, "(");
+ for (int aa = 0; aa < argCount; ++aa)
+ {
+ if (aa != 0) Emit(context, ", ");
+ EmitExpr(context, args[aa]);
+ }
+ Emit(context, ")");
+ }
+ Emit(context, "]");
+ }
+
+ else if(auto simpleModifier = mod.As<SimpleModifier>())
+ {
+ emit(context, simpleModifier->nameToken.Content);
+ Emit(context, " ");
+ }
+
+ else
+ {
+ // skip any extra modifiers
+ }
+ }
+}
+
+
+typedef unsigned int ESemanticMask;
+enum
+{
+ kESemanticMask_None = 0,
+
+ kESemanticMask_NoPackOffset = 1 << 0,
+
+ kESemanticMask_Default = kESemanticMask_NoPackOffset,
+};
+
+static void EmitSemantic(EmitContext* context, RefPtr<HLSLSemantic> semantic, ESemanticMask /*mask*/)
+{
+ if (auto simple = semantic.As<HLSLSimpleSemantic>())
+ {
+ Emit(context, ": ");
+ emit(context, simple->name.Content);
+ }
+ else if(auto registerSemantic = semantic.As<HLSLRegisterSemantic>())
+ {
+ // Don't print out semantic from the user, since we are going to print the same thing our own way...
+#if 0
+ Emit(context, ": register(");
+ Emit(context, registerSemantic->registerName.Content);
+ if(registerSemantic->componentMask.Type != TokenType::Unknown)
+ {
+ Emit(context, ".");
+ Emit(context, registerSemantic->componentMask.Content);
+ }
+ Emit(context, ")");
+#endif
+ }
+ else if(auto packOffsetSemantic = semantic.As<HLSLPackOffsetSemantic>())
+ {
+ // Don't print out semantic from the user, since we are going to print the same thing our own way...
+#if 0
+ if(mask & kESemanticMask_NoPackOffset)
+ return;
+
+ Emit(context, ": packoffset(");
+ Emit(context, packOffsetSemantic->registerName.Content);
+ if(packOffsetSemantic->componentMask.Type != TokenType::Unknown)
+ {
+ Emit(context, ".");
+ Emit(context, packOffsetSemantic->componentMask.Content);
+ }
+ Emit(context, ")");
+#endif
+ }
+ else
+ {
+ assert(!"unimplemented");
+ }
+}
+
+
+static void EmitSemantics(EmitContext* context, RefPtr<Decl> decl, ESemanticMask mask = kESemanticMask_Default )
+{
+ // Don't emit semantics if we aren't translating down to HLSL
+ switch (context->target)
+ {
+ case CodeGenTarget::HLSL:
+ break;
+
+ default:
+ return;
+ }
+
+ for (auto mod = decl->modifiers.first; mod; mod = mod->next)
+ {
+ auto semantic = mod.As<HLSLSemantic>();
+ if (!semantic)
+ continue;
+
+ EmitSemantic(context, semantic, mask);
+ }
+}
+
+static void EmitDeclsInContainer(EmitContext* context, RefPtr<ContainerDecl> container)
+{
+ for (auto member : container->Members)
+ {
+ EmitDecl(context, member);
+ }
+}
+
+static void EmitDeclsInContainerUsingLayout(
+ EmitContext* context,
+ RefPtr<ContainerDecl> container,
+ RefPtr<StructTypeLayout> containerLayout)
+{
+ for (auto member : container->Members)
+ {
+ RefPtr<VarLayout> memberLayout;
+ if( containerLayout->mapVarToLayout.TryGetValue(member.Ptr(), memberLayout) )
+ {
+ EmitDeclUsingLayout(context, member, memberLayout);
+ }
+ else
+ {
+ // No layout for this decl
+ EmitDecl(context, member);
+ }
+ }
+}
+
+static void EmitTypeDefDecl(EmitContext* context, RefPtr<TypeDefDecl> decl)
+{
+ // TODO(tfoley): check if current compilation target even supports typedefs
+
+ Emit(context, "typedef ");
+ EmitType(context, decl->Type, decl->Name.Content);
+ Emit(context, ";\n");
+}
+
+static void EmitStructDecl(EmitContext* context, RefPtr<StructSyntaxNode> decl)
+{
+ // Don't emit a declaration that was only generated implicitly, for
+ // the purposes of semantic checking.
+ if(decl->HasModifier<ImplicitParameterBlockElementTypeModifier>())
+ return;
+
+ Emit(context, "struct ");
+ emitName(context, decl->Name.Content);
+ Emit(context, "\n{\n");
+
+ // TODO(tfoley): Need to hoist members functions, etc. out to global scope
+ EmitDeclsInContainer(context, decl);
+
+ Emit(context, "};\n");
+}
+
+// Shared emit logic for variable declarations (used for parameters, locals, globals, fields)
+static void EmitVarDeclCommon(EmitContext* context, VarDeclBaseRef declRef)
+{
+ EmitModifiers(context, declRef.GetDecl());
+
+ EmitType(context, declRef.GetType(), declRef.GetName());
+
+ EmitSemantics(context, declRef.GetDecl());
+
+ // TODO(tfoley): technically have to apply substitution here too...
+ if (auto initExpr = declRef.GetDecl()->Expr)
+ {
+ Emit(context, " = ");
+ EmitExpr(context, initExpr);
+ }
+}
+
+// Shared emit logic for variable declarations (used for parameters, locals, globals, fields)
+static void EmitVarDeclCommon(EmitContext* context, RefPtr<VarDeclBase> decl)
+{
+ EmitVarDeclCommon(context, DeclRef(decl.Ptr(), nullptr).As<VarDeclBaseRef>());
+}
+
+// Emit a single `regsiter` semantic, as appropriate for a given resource-type-specific layout info
+static void emitHLSLRegisterSemantic(
+ EmitContext* context,
+ VarLayout::ResourceInfo const& info)
+{
+ if( info.kind == LayoutResourceKind::Uniform )
+ {
+ size_t offset = info.index;
+
+ // The HLSL `c` register space is logically grouped in 16-byte registers,
+ // while we try to traffic in byte offsets. That means we need to pick
+ // a register number, based on the starting offset in 16-byte register
+ // units, and then a "component" within that register, based on 4-byte
+ // offsets from there. We cannot support more fine-grained offsets than that.
+
+ Emit(context, ": packoffset(c");
+
+ // Size of a logical `c` register in bytes
+ auto registerSize = 16;
+
+ // Size of each component of a logical `c` register, in bytes
+ auto componentSize = 4;
+
+ size_t startRegister = offset / registerSize;
+ Emit(context, int(startRegister));
+
+ size_t byteOffsetInRegister = offset % registerSize;
+
+ // If this field doesn't start on an even register boundary,
+ // then we need to emit additional information to pick the
+ // right component to start from
+ if (byteOffsetInRegister != 0)
+ {
+ // The value had better occupy a whole number of components.
+ assert(byteOffsetInRegister % componentSize == 0);
+
+ size_t startComponent = byteOffsetInRegister / componentSize;
+
+ static const char* kComponentNames[] = {"x", "y", "z", "w"};
+ Emit(context, ".");
+ Emit(context, kComponentNames[startComponent]);
+ }
+ Emit(context, ")");
+ }
+ else
+ {
+ Emit(context, ": register(");
+ switch( info.kind )
+ {
+ case LayoutResourceKind::ConstantBuffer:
+ Emit(context, "b");
+ break;
+ case LayoutResourceKind::ShaderResource:
+ Emit(context, "t");
+ break;
+ case LayoutResourceKind::UnorderedAccess:
+ Emit(context, "u");
+ break;
+ case LayoutResourceKind::SamplerState:
+ Emit(context, "s");
+ break;
+ default:
+ assert(!"unexpected");
+ break;
+ }
+ Emit(context, info.index);
+ if(info.space)
+ {
+ Emit(context, ", space");
+ Emit(context, info.space);
+ }
+ Emit(context, ")");
+ }
+}
+
+// Emit all the `register` semantics that are appropriate for a particular variable layout
+static void emitHLSLRegisterSemantics(
+ EmitContext* context,
+ RefPtr<VarLayout> layout)
+{
+ if (!layout) return;
+
+ switch( context->target )
+ {
+ default:
+ return;
+
+ case CodeGenTarget::HLSL:
+ break;
+ }
+
+ for( auto rr : layout->resourceInfos )
+ {
+ emitHLSLRegisterSemantic(context, rr);
+ }
+}
+
+static void emitHLSLParameterBlockDecl(
+ EmitContext* context,
+ RefPtr<VarDeclBase> varDecl,
+ RefPtr<ParameterBlockType> parameterBlockType,
+ RefPtr<VarLayout> layout)
+{
+ // The data type that describes where stuff in the constant buffer should go
+ RefPtr<ExpressionType> dataType = parameterBlockType->elementType;
+
+ // We expect/require the data type to be a user-defined `struct` type
+ auto declRefType = dataType->As<DeclRefType>();
+ assert(declRefType);
+
+ // We expect to always have layout information
+ assert(layout);
+
+ // We expect the layout to be for a structured type...
+ RefPtr<ParameterBlockTypeLayout> bufferLayout = layout->typeLayout.As<ParameterBlockTypeLayout>();
+ assert(bufferLayout);
+
+ RefPtr<StructTypeLayout> structTypeLayout = bufferLayout->elementTypeLayout.As<StructTypeLayout>();
+ assert(structTypeLayout);
+
+ if( auto constantBufferType = parameterBlockType->As<ConstantBufferType>() )
+ {
+ Emit(context, "cbuffer ");
+ }
+ else if( auto textureBufferType = parameterBlockType->As<TextureBufferType>() )
+ {
+ Emit(context, "tbuffer ");
+ }
+
+ if( auto reflectionNameModifier = varDecl->FindModifier<ParameterBlockReflectionName>() )
+ {
+ Emit(context, " ");
+ emitName(context, reflectionNameModifier->nameToken.Content);
+ }
+
+ EmitSemantics(context, varDecl, kESemanticMask_None);
+
+ auto info = layout->FindResourceInfo(LayoutResourceKind::ConstantBuffer);
+ assert(info);
+ emitHLSLRegisterSemantic(context, *info);
+
+ Emit(context, "\n{\n");
+ if (auto structRef = declRefType->declRef.As<StructDeclRef>())
+ {
+ for (auto field : structRef.GetMembersOfType<FieldDeclRef>())
+ {
+ EmitVarDeclCommon(context, field);
+
+ RefPtr<VarLayout> fieldLayout;
+ structTypeLayout->mapVarToLayout.TryGetValue(field.GetDecl(), fieldLayout);
+ assert(fieldLayout);
+
+ // Emit explicit layout annotations for every field
+ for( auto rr : fieldLayout->resourceInfos )
+ {
+ auto kind = rr.kind;
+
+ auto offsetResource = rr;
+
+ if(kind != LayoutResourceKind::Uniform)
+ {
+ // Add the base index from the cbuffer into the index of the field
+ //
+ // TODO(tfoley): consider maybe not doing this, since it actually
+ // complicates logic around constant buffers...
+
+ // If the member of the cbuffer uses a resource, it had better
+ // appear as part of the cubffer layout as well.
+ auto cbufferResource = layout->FindResourceInfo(kind);
+ assert(cbufferResource);
+
+ offsetResource.index += cbufferResource->index;
+ offsetResource.space += cbufferResource->space;
+ }
+
+ emitHLSLRegisterSemantic(context, offsetResource);
+ }
+
+ Emit(context, ";\n");
+ }
+ }
+ Emit(context, "}\n");
+}
+
+static void
+emitGLSLLayoutQualifier(
+ EmitContext* context,
+ VarLayout::ResourceInfo const& info)
+{
+ switch(info.kind)
+ {
+ case LayoutResourceKind::Uniform:
+ Emit(context, "layout(offset = ");
+ Emit(context, info.index);
+ Emit(context, ")\n");
+ break;
+
+ case LayoutResourceKind::VertexInput:
+ case LayoutResourceKind::FragmentOutput:
+ Emit(context, "layout(location = ");
+ Emit(context, info.index);
+ Emit(context, ")\n");
+ break;
+
+ case LayoutResourceKind::SpecializationConstant:
+ Emit(context, "layout(constant_id = ");
+ Emit(context, info.index);
+ Emit(context, ")\n");
+ break;
+
+ case LayoutResourceKind::ConstantBuffer:
+ case LayoutResourceKind::ShaderResource:
+ case LayoutResourceKind::UnorderedAccess:
+ case LayoutResourceKind::SamplerState:
+ case LayoutResourceKind::DescriptorTableSlot:
+ Emit(context, "layout(binding = ");
+ Emit(context, info.index);
+ if(info.space)
+ {
+ Emit(context, ", set = ");
+ Emit(context, info.space);
+ }
+ Emit(context, ")\n");
+ break;
+ }
+}
+
+static void
+emitGLSLLayoutQualifiers(
+ EmitContext* context,
+ RefPtr<VarLayout> layout)
+{
+ if(!layout) return;
+
+ switch( context->target )
+ {
+ default:
+ return;
+
+ case CodeGenTarget::GLSL:
+ break;
+ }
+
+ for( auto info : layout->resourceInfos )
+ {
+ emitGLSLLayoutQualifier(context, info);
+ }
+}
+
+static void emitGLSLParameterBlockDecl(
+ EmitContext* context,
+ RefPtr<VarDeclBase> varDecl,
+ RefPtr<ParameterBlockType> parameterBlockType,
+ RefPtr<VarLayout> layout)
+{
+ // The data type that describes where stuff in the constant buffer should go
+ RefPtr<ExpressionType> dataType = parameterBlockType->elementType;
+
+ // We expect/require the data type to be a user-defined `struct` type
+ auto declRefType = dataType->As<DeclRefType>();
+ assert(declRefType);
+
+ // We expect to always have layout information
+ assert(layout);
+
+ // We expect the layout to be for a structured type...
+ RefPtr<ParameterBlockTypeLayout> bufferLayout = layout->typeLayout.As<ParameterBlockTypeLayout>();
+ assert(bufferLayout);
+
+ RefPtr<StructTypeLayout> structTypeLayout = bufferLayout->elementTypeLayout.As<StructTypeLayout>();
+ assert(structTypeLayout);
+
+ emitGLSLLayoutQualifiers(context, layout);
+
+ EmitModifiers(context, varDecl);
+
+ // Emit an apprpriate declaration keyword based on the kind of block
+ if (parameterBlockType->As<ConstantBufferType>())
+ {
+ Emit(context, "uniform");
+ }
+ else if (parameterBlockType->As<GLSLInputParameterBlockType>())
+ {
+ Emit(context, "in");
+ }
+ else if (parameterBlockType->As<GLSLOutputParameterBlockType>())
+ {
+ Emit(context, "out");
+ }
+ else if (parameterBlockType->As<GLSLShaderStorageBufferType>())
+ {
+ Emit(context, "buffer");
+ }
+ else
+ {
+ assert(!"unexpected");
+ Emit(context, "uniform");
+ }
+
+ if( auto reflectionNameModifier = varDecl->FindModifier<ParameterBlockReflectionName>() )
+ {
+ Emit(context, " ");
+ emitName(context, reflectionNameModifier->nameToken.Content);
+ }
+
+ Emit(context, "\n{\n");
+ if (auto structRef = declRefType->declRef.As<StructDeclRef>())
+ {
+ for (auto field : structRef.GetMembersOfType<FieldDeclRef>())
+ {
+ RefPtr<VarLayout> fieldLayout;
+ structTypeLayout->mapVarToLayout.TryGetValue(field.GetDecl(), fieldLayout);
+ assert(fieldLayout);
+
+ // TODO(tfoley): We may want to emit *some* of these,
+ // some of the time...
+// emitGLSLLayoutQualifiers(context, fieldLayout);
+
+ EmitVarDeclCommon(context, field);
+
+ Emit(context, ";\n");
+ }
+ }
+ Emit(context, "}");
+
+ if( varDecl->Name.Type != TokenType::Unknown )
+ {
+ Emit(context, " ");
+ emitName(context, varDecl->Name.Content);
+ }
+
+ Emit(context, ";\n");
+}
+
+static void emitParameterBlockDecl(
+ EmitContext* context,
+ RefPtr<VarDeclBase> varDecl,
+ RefPtr<ParameterBlockType> parameterBlockType,
+ RefPtr<VarLayout> layout)
+{
+ switch(context->target)
+ {
+ case CodeGenTarget::HLSL:
+ emitHLSLParameterBlockDecl(context, varDecl, parameterBlockType, layout);
+ break;
+
+ case CodeGenTarget::GLSL:
+ emitGLSLParameterBlockDecl(context, varDecl, parameterBlockType, layout);
+ break;
+
+ default:
+ assert(!"unexpected");
+ break;
+ }
+}
+
+static void EmitVarDecl(EmitContext* context, RefPtr<VarDeclBase> decl, RefPtr<VarLayout> layout)
+{
+ // As a special case, a variable using a parameter block type
+ // will be translated into a declaration using the more primitive
+ // language syntax.
+ //
+ // TODO(tfoley): Be sure to unwrap arrays here, in the GLSL case.
+ //
+ // TODO(tfoley): Detect cases where we need to fall back to
+ // ordinary variable declaration syntax in HLSL.
+ //
+ // TODO(tfoley): there might be a better way to detect this, e.g.,
+ // with an attribute that gets attached to the variable declaration.
+ if (auto parameterBlockType = decl->Type->As<ParameterBlockType>())
+ {
+ emitParameterBlockDecl(context, decl, parameterBlockType, layout);
+ return;
+ }
+
+ emitGLSLLayoutQualifiers(context, layout);
+
+ EmitVarDeclCommon(context, decl);
+
+ emitHLSLRegisterSemantics(context, layout);
+
+ Emit(context, ";\n");
+}
+
+static void EmitParamDecl(EmitContext* context, RefPtr<ParameterSyntaxNode> decl)
+{
+ EmitVarDeclCommon(context, decl);
+}
+
+static void EmitFuncDecl(EmitContext* context, RefPtr<FunctionSyntaxNode> decl)
+{
+ EmitModifiers(context, decl);
+
+ // TODO: if a function returns an array type, or something similar that
+ // isn't allowed by declarator syntax and/or language rules, we could
+ // hypothetically wrap things in a `typedef` and work around it.
+
+ EmitType(context, decl->ReturnType, decl->Name.Content);
+
+ Emit(context, "(");
+ bool first = true;
+ for (auto paramDecl : decl->GetMembersOfType<ParameterSyntaxNode>())
+ {
+ if (!first) Emit(context, ", ");
+ EmitParamDecl(context, paramDecl);
+ first = false;
+ }
+ Emit(context, ")");
+
+ EmitSemantics(context, decl);
+
+ if (auto bodyStmt = decl->Body)
+ {
+ EmitBlockStmt(context, bodyStmt);
+ }
+ else
+ {
+ Emit(context, ";\n");
+ }
+}
+
+static void emitGLSLPreprocessorDirectives(
+ EmitContext* context,
+ RefPtr<ProgramSyntaxNode> program)
+{
+ switch(context->target)
+ {
+ // Don't emit this stuff unless we are targetting GLSL
+ default:
+ return;
+
+ case CodeGenTarget::GLSL:
+ break;
+ }
+
+ if( auto versionDirective = program->FindModifier<GLSLVersionDirective>() )
+ {
+ // TODO(tfoley): Emit an appropriate `#line` directive...
+
+ Emit(context, "#version ");
+ emit(context, versionDirective->versionNumberToken.Content);
+ if(versionDirective->glslProfileToken.Type != TokenType::Unknown)
+ {
+ Emit(context, " ");
+ emit(context, versionDirective->glslProfileToken.Content);
+ }
+ Emit(context, "\n");
+ }
+ else
+ {
+ // No explicit version was given (probably because we are cross-compiling).
+ //
+ // We need to pick an appropriate version, ideally based on the features
+ // that the shader ends up using.
+ //
+ // For now we just fall back to a reasonably recent version.
+
+ Emit(context, "#version 420\n");
+ }
+
+ // TODO: when cross-compiling we may need to output additional `#extension` directives
+ // based on the features that we have used.
+
+ for( auto extensionDirective : program->GetModifiersOfType<GLSLExtensionDirective>() )
+ {
+ // TODO(tfoley): Emit an appropriate `#line` directive...
+
+ Emit(context, "#extension ");
+ emit(context, extensionDirective->extensionNameToken.Content);
+ Emit(context, " : ");
+ emit(context, extensionDirective->dispositionToken.Content);
+ Emit(context, "\n");
+ }
+
+ // TODO: handle other cases...
+}
+
+static void EmitProgram(
+ EmitContext* context,
+ RefPtr<ProgramSyntaxNode> program,
+ RefPtr<ProgramLayout> programLayout)
+{
+ // There may be global-scope modifiers that we should emit now
+ emitGLSLPreprocessorDirectives(context, program);
+
+ switch(context->target)
+ {
+ case CodeGenTarget::GLSL:
+ {
+ // TODO(tfoley): Need a plan for how to enable/disable these as needed...
+// Emit(context, "#extension GL_GOOGLE_cpp_style_line_directive : require\n");
+ }
+ break;
+
+ default:
+ break;
+ }
+
+
+ // Layout information for the global scope is either an ordinary
+ // `struct` in the common case, or a constant buffer in the case
+ // where there were global-scope uniforms.
+ auto globalScopeLayout = programLayout->globalScopeLayout;
+ if( auto globalStructLayout = globalScopeLayout.As<StructTypeLayout>() )
+ {
+ // The `struct` case is easy enough to handle: we just
+ // emit all the declarations directly, using their layout
+ // information as a guideline.
+ EmitDeclsInContainerUsingLayout(context, program, globalStructLayout);
+ }
+ else if(auto globalConstantBufferLayout = globalScopeLayout.As<ParameterBlockTypeLayout>())
+ {
+ // TODO: the `cbuffer` case really needs to be emitted very
+ // carefully, but that is beyond the scope of what a simple rewriter
+ // can easily do (without semantic analysis, etc.).
+ //
+ // The crux of the problem is that we need to collect all the
+ // global-scope uniforms (but not declarations that don't involve
+ // uniform storage...) and put them in a single `cbuffer` declaration,
+ // so that we can give it an explicit location. The fields in that
+ // declaration might use various type declarations, so we'd really
+ // need to emit all the type declarations first, and that involves
+ // some large scale reorderings.
+ //
+ // For now we will punt and just emit the declarations normally,
+ // and hope that the global-scope block (`$Globals`) gets auto-assigned
+ // the same location that we manually asigned it.
+
+ auto elementTypeLayout = globalConstantBufferLayout->elementTypeLayout;
+ auto elementTypeStructLayout = elementTypeLayout.As<StructTypeLayout>();
+
+ // We expect all constant buffers to contain `struct` types for now
+ assert(elementTypeStructLayout);
+
+ EmitDeclsInContainerUsingLayout(
+ context,
+ program,
+ elementTypeStructLayout);
+ }
+ else
+ {
+ assert(!"unexpected");
+ }
+}
+
+static void EmitDeclImpl(EmitContext* context, RefPtr<Decl> decl, RefPtr<VarLayout> layout)
+{
+ // Don't emit code for declarations that came from the stdlib.
+ //
+ // TODO(tfoley): We probably need to relax this eventually,
+ // since different targets might have different sets of builtins.
+ if (decl->HasModifier<FromStdLibModifier>())
+ return;
+
+ if (auto typeDefDecl = decl.As<TypeDefDecl>())
+ {
+ EmitTypeDefDecl(context, typeDefDecl);
+ return;
+ }
+ else if (auto structDecl = decl.As<StructSyntaxNode>())
+ {
+ EmitStructDecl(context, structDecl);
+ return;
+ }
+ else if (auto varDecl = decl.As<VarDeclBase>())
+ {
+ EmitVarDecl(context, varDecl, layout);
+ return;
+ }
+ else if (auto funcDecl = decl.As<FunctionSyntaxNode>())
+ {
+ EmitFuncDecl(context, funcDecl);
+ return;
+ }
+ else if (auto genericDecl = decl.As<GenericDecl>())
+ {
+ // Don't emit generic decls directly; we will only
+ // ever emit particular instantiations of them.
+ return;
+ }
+ else if (auto classDecl = decl.As<ClassSyntaxNode>())
+ {
+ return;
+ }
+ else if( auto emptyDecl = decl.As<EmptyDecl>() )
+ {
+ EmitModifiers(context, emptyDecl);
+ Emit(context, ";\n");
+ return;
+ }
+ throw "unimplemented";
+}
+
+static void EmitDecl(EmitContext* context, RefPtr<Decl> decl)
+{
+ EmitDeclImpl(context, decl, nullptr);
+}
+
+static void EmitDeclUsingLayout(EmitContext* context, RefPtr<Decl> decl, RefPtr<VarLayout> layout)
+{
+ EmitDeclImpl(context, decl, layout);
+}
+
+static void EmitDecl(EmitContext* context, RefPtr<DeclBase> declBase)
+{
+ if( auto decl = declBase.As<Decl>() )
+ {
+ EmitDecl(context, decl);
+ }
+ else if(auto declGroup = declBase.As<DeclGroup>())
+ {
+ for(auto d : declGroup->decls)
+ EmitDecl(context, d);
+ }
+ else
+ {
+ throw "unimplemented";
+ }
+}
+
+static void registerReservedWord(
+ EmitContext* context,
+ String const& name)
+{
+ context->reservedWords.Add(name, name);
+}
+
+static void registerReservedWords(
+ EmitContext* context)
+{
+#define WORD(NAME) registerReservedWord(context, #NAME)
+
+ switch (context->target)
+ {
+ case CodeGenTarget::GLSL:
+ WORD(attribute);
+ WORD(const);
+ WORD(uniform);
+ WORD(varying);
+ WORD(buffer);
+
+ WORD(shared);
+ WORD(coherent);
+ WORD(volatile);
+ WORD(restrict);
+ WORD(readonly);
+ WORD(writeonly);
+ WORD(atomic_unit);
+ WORD(layout);
+ WORD(centroid);
+ WORD(flat);
+ WORD(smooth);
+ WORD(noperspective);
+ WORD(patch);
+ WORD(sample);
+ WORD(break);
+ WORD(continue);
+ WORD(do);
+ WORD(for);
+ WORD(while);
+ WORD(switch);
+ WORD(case);
+ WORD(default);
+ WORD(if);
+ WORD(else);
+ WORD(subroutine);
+ WORD(in);
+ WORD(out);
+ WORD(inout);
+ WORD(float);
+ WORD(double);
+ WORD(int);
+ WORD(void);
+ WORD(bool);
+ WORD(true);
+ WORD(false);
+ WORD(invariant);
+ WORD(precise);
+ WORD(discard);
+ WORD(return);
+
+ WORD(lowp);
+ WORD(mediump);
+ WORD(highp);
+ WORD(precision);
+ WORD(struct);
+ WORD(uint);
+
+ WORD(common);
+ WORD(partition);
+ WORD(active);
+ WORD(asm);
+ WORD(class);
+ WORD(union);
+ WORD(enum);
+ WORD(typedef);
+ WORD(template);
+ WORD(this);
+ WORD(resource);
+
+ WORD(goto);
+ WORD(inline);
+ WORD(noinline);
+ WORD(public);
+ WORD(static);
+ WORD(extern);
+ WORD(external);
+ WORD(interface);
+ WORD(long);
+ WORD(short);
+ WORD(half);
+ WORD(fixed);
+ WORD(unsigned);
+ WORD(superp);
+ WORD(input);
+ WORD(output);
+ WORD(filter);
+ WORD(sizeof);
+ WORD(cast);
+ WORD(namespace);
+ WORD(using);
+
+#define CASE(NAME) \
+ WORD(NAME ## 2); WORD(NAME ## 3); WORD(NAME ## 4)
+
+ CASE(mat);
+ CASE(dmat);
+ CASE(mat2x);
+ CASE(mat3x);
+ CASE(mat4x);
+ CASE(dmat2x);
+ CASE(dmat3x);
+ CASE(dmat4x);
+ CASE(vec);
+ CASE(ivec);
+ CASE(bvec);
+ CASE(dvec);
+ CASE(uvec);
+ CASE(hvec);
+ CASE(fvec);
+
+#undef CASE
+
+#define CASE(NAME) \
+ WORD(NAME ## 1D); \
+ WORD(NAME ## 2D); \
+ WORD(NAME ## 3D); \
+ WORD(NAME ## Cube); \
+ WORD(NAME ## 1DArray); \
+ WORD(NAME ## 2DArray); \
+ WORD(NAME ## 3DArray); \
+ WORD(NAME ## CubeArray);\
+ WORD(NAME ## 2DMS); \
+ WORD(NAME ## 2DMSArray) \
+ /* end */
+
+#define CASE2(NAME) \
+ CASE(NAME); \
+ CASE(i ## NAME); \
+ CASE(u ## NAME) \
+ /* end */
+
+ CASE2(sampler);
+ CASE2(image);
+ CASE2(texture);
+
+#undef CASE2
+#undef CASE
+ break;
+
+ default:
+ break;
+ }
+}
+
+String emitProgram(
+ ProgramSyntaxNode* program,
+ ProgramLayout* programLayout,
+ CodeGenTarget target)
+{
+ // TODO(tfoley): only emit symbols on-demand, as needed by a particular entry point
+
+ EmitContext context;
+ context.target = target;
+
+ registerReservedWords(&context);
+
+ EmitProgram(&context, program, programLayout);
+
+ String code = context.sb.ProduceString();
+
+ return code;
+
+#if 0
+ // HACK(tfoley): Invoke the D3D HLSL compiler on the result, to validate it
+
+#ifdef _WIN32
+ {
+ HMODULE d3dCompiler = LoadLibraryA("d3dcompiler_47");
+ assert(d3dCompiler);
+
+ pD3DCompile D3DCompile_ = (pD3DCompile)GetProcAddress(d3dCompiler, "D3DCompile");
+ assert(D3DCompile_);
+
+ ID3DBlob* codeBlob;
+ ID3DBlob* diagnosticsBlob;
+ HRESULT hr = D3DCompile_(
+ code.begin(),
+ code.Length(),
+ "slang",
+ nullptr,
+ nullptr,
+ "main",
+ "ps_5_0",
+ 0,
+ 0,
+ &codeBlob,
+ &diagnosticsBlob);
+ if (codeBlob) codeBlob->Release();
+ if (diagnosticsBlob)
+ {
+ String diagnostics = (char const*) diagnosticsBlob->GetBufferPointer();
+ fprintf(stderr, "%s", diagnostics.begin());
+ OutputDebugStringA(diagnostics.begin());
+ diagnosticsBlob->Release();
+ }
+ if (FAILED(hr))
+ {
+ int f = 9;
+ }
+ }
+
+ #include <d3dcompiler.h>
+#endif
+#endif
+
+}
+
+
+}} // Slang::Compiler
diff --git a/source/slang/emit.h b/source/slang/emit.h
new file mode 100644
index 000000000..05ea1550f
--- /dev/null
+++ b/source/slang/emit.h
@@ -0,0 +1,24 @@
+// Emit.h
+#ifndef SLANG_EMIT_H_INCLUDED
+#define SLANG_EMIT_H_INCLUDED
+
+#include "../core/basic.h"
+
+#include "compiler.h"
+
+namespace Slang
+{
+ namespace Compiler
+ {
+ using namespace CoreLib::Basic;
+
+ class ProgramSyntaxNode;
+ class ProgramLayout;
+
+ String emitProgram(
+ ProgramSyntaxNode* program,
+ ProgramLayout* programLayout,
+ CodeGenTarget target);
+ }
+}
+#endif
diff --git a/source/slang/intrinsic-defs.h b/source/slang/intrinsic-defs.h
new file mode 100644
index 000000000..19a3899a3
--- /dev/null
+++ b/source/slang/intrinsic-defs.h
@@ -0,0 +1,94 @@
+// intrinsic-defs.h
+
+// The file is meant to be included multiple times, to produce different
+// pieces of code related to intrinsic operations
+//
+// Each intrinsic op is declared here with:
+//
+// INTRINSIC(name)
+//
+
+#ifndef INTRINSIC
+#error Need to define INTRINSIC(NAME) before including "intrinsic-defs.h"
+#endif
+
+INTRINSIC(Add)
+INTRINSIC(Sub)
+INTRINSIC(Mul)
+INTRINSIC(Div)
+INTRINSIC(Mod)
+
+INTRINSIC(Lsh)
+INTRINSIC(Rsh)
+
+INTRINSIC(Eql)
+INTRINSIC(Neq)
+INTRINSIC(Greater)
+INTRINSIC(Less)
+INTRINSIC(Geq)
+INTRINSIC(Leq)
+INTRINSIC(BitAnd)
+INTRINSIC(BitXor)
+INTRINSIC(BitOr)
+
+// TODO(tfoley): need to distinguish short-circuiting and not...
+INTRINSIC(And)
+INTRINSIC(Or)
+
+INTRINSIC(Assign)
+INTRINSIC(AddAssign)
+INTRINSIC(SubAssign)
+INTRINSIC(MulAssign)
+INTRINSIC(DivAssign)
+INTRINSIC(ModAssign)
+INTRINSIC(LshAssign)
+INTRINSIC(RshAssign)
+INTRINSIC(OrAssign)
+INTRINSIC(AndAssign)
+INTRINSIC(XorAssign)
+INTRINSIC(Neg)
+INTRINSIC(Not)
+INTRINSIC(BitNot)
+INTRINSIC(PreInc)
+INTRINSIC(PreDec)
+INTRINSIC(PostInc)
+INTRINSIC(PostDec)
+
+INTRINSIC(Sequence)
+INTRINSIC(Select)
+
+INTRINSIC(Mul_Scalar_Scalar)
+INTRINSIC(Mul_Vector_Scalar)
+INTRINSIC(Mul_Scalar_Vector)
+INTRINSIC(Mul_Matrix_Scalar)
+INTRINSIC(Mul_Scalar_Matrix)
+INTRINSIC(InnerProduct_Vector_Vector)
+INTRINSIC(InnerProduct_Vector_Matrix)
+INTRINSIC(InnerProduct_Matrix_Vector)
+INTRINSIC(InnerProduct_Matrix_Matrix)
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+// Un-deefine the macor here, so that the client does not have to.
+#undef INTRINSIC
diff --git a/source/slang/lexer.cpp b/source/slang/lexer.cpp
new file mode 100644
index 000000000..7234c4983
--- /dev/null
+++ b/source/slang/lexer.cpp
@@ -0,0 +1,1012 @@
+#include "Lexer.h"
+
+#include <assert.h>
+
+namespace Slang
+{
+ namespace Compiler
+ {
+ static Token GetEndOfFileToken()
+ {
+ return Token(TokenType::EndOfFile, "", 0, 0, 0, "");
+ }
+
+ Token* TokenList::begin() const
+ {
+ assert(mTokens.Count());
+ return &mTokens[0];
+ }
+
+ Token* TokenList::end() const
+ {
+ assert(mTokens.Count());
+ assert(mTokens[mTokens.Count()-1].Type == TokenType::EndOfFile);
+ return &mTokens[mTokens.Count() - 1];
+ }
+
+ TokenSpan::TokenSpan()
+ : mBegin(NULL)
+ , mEnd (NULL)
+ {}
+
+ TokenReader::TokenReader()
+ : mCursor(NULL)
+ , mEnd (NULL)
+ {}
+
+
+ Token TokenReader::PeekToken() const
+ {
+ if (!mCursor)
+ return GetEndOfFileToken();
+
+ Token token = *mCursor;
+ if (mCursor == mEnd)
+ token.Type = TokenType::EndOfFile;
+ return token;
+ }
+
+ TokenType TokenReader::PeekTokenType() const
+ {
+ if (mCursor == mEnd)
+ return TokenType::EndOfFile;
+ assert(mCursor);
+ return mCursor->Type;
+ }
+
+ CodePosition TokenReader::PeekLoc() const
+ {
+ if (!mCursor)
+ return CodePosition();
+ assert(mCursor);
+ return mCursor->Position;
+ }
+
+ Token TokenReader::AdvanceToken()
+ {
+ if (!mCursor)
+ return GetEndOfFileToken();
+
+ Token token = *mCursor;
+ if (mCursor == mEnd)
+ token.Type = TokenType::EndOfFile;
+ else
+ mCursor++;
+ return token;
+ }
+
+ // Lexer
+
+ Lexer::Lexer(
+ String const& path,
+ String const& content,
+ DiagnosticSink* sink)
+ : path(path)
+ , content(content)
+ , sink(sink)
+ {
+ cursor = content.begin();
+ end = content.end();
+
+ loc = CodePosition(1, 1, 0, path);
+ tokenFlags = TokenFlag::AtStartOfLine | TokenFlag::AfterWhitespace;
+ lexerFlags = 0;
+ }
+
+ Lexer::~Lexer()
+ {
+ }
+
+ enum { kEOF = -1 };
+
+ static int peek(Lexer* lexer)
+ {
+ if(lexer->cursor == lexer->end)
+ return kEOF;
+
+ return *lexer->cursor;
+ }
+
+ static int advance(Lexer* lexer)
+ {
+ if(lexer->cursor == lexer->end)
+ return kEOF;
+
+ lexer->loc.Col++;
+ lexer->loc.Pos++;
+
+ return *lexer->cursor++;
+ }
+
+ static void handleNewLine(Lexer* lexer)
+ {
+ int c = advance(lexer);
+ assert(c == '\n' || c == '\r');
+
+ int d = peek(lexer);
+ if( (c ^ d) == ('\n' ^ '\r') )
+ {
+ advance(lexer);
+ }
+
+ lexer->loc.Line++;
+ lexer->loc.Col = 1;
+ }
+
+ static void lexLineComment(Lexer* lexer)
+ {
+ for(;;)
+ {
+ switch(peek(lexer))
+ {
+ case '\n': case '\r': case kEOF:
+ return;
+
+ default:
+ advance(lexer);
+ continue;
+ }
+ }
+ }
+
+ static void lexBlockComment(Lexer* lexer)
+ {
+ for(;;)
+ {
+ switch(peek(lexer))
+ {
+ case kEOF:
+ // TODO(tfoley) diagnostic!
+ return;
+
+ case '\n': case '\r':
+ handleNewLine(lexer);
+ continue;
+
+ case '*':
+ advance(lexer);
+ switch( peek(lexer) )
+ {
+ case '/':
+ advance(lexer);
+ return;
+
+ default:
+ continue;
+ }
+
+ default:
+ advance(lexer);
+ continue;
+ }
+ }
+ }
+
+ static void lexHorizontalSpace(Lexer* lexer)
+ {
+ for(;;)
+ {
+ switch(peek(lexer))
+ {
+ case ' ': case '\t':
+ advance(lexer);
+ continue;
+
+ default:
+ return;
+ }
+ }
+ }
+
+ static void lexIdentifier(Lexer* lexer)
+ {
+ for(;;)
+ {
+ int c = peek(lexer);
+ if(('a' <= c ) && (c <= 'z')
+ || ('A' <= c) && (c <= 'Z')
+ || ('0' <= c) && (c <= '9')
+ || (c == '_'))
+ {
+ advance(lexer);
+ continue;
+ }
+
+ return;
+ }
+ }
+
+ static void lexDigits(Lexer* lexer, int base)
+ {
+ for(;;)
+ {
+ int c = peek(lexer);
+
+ int digitVal = 0;
+ switch(c)
+ {
+ case '0': case '1': case '2': case '3': case '4':
+ case '5': case '6': case '7': case '8': case '9':
+ digitVal = c - '0';
+ break;
+
+ case 'a': case 'b': case 'c': case 'd': case 'e': case 'f':
+ if(base <= 10) return;
+ digitVal = 10 + c - 'a';
+ break;
+
+ case 'A': case 'B': case 'C': case 'D': case 'E': case 'F':
+ if(base <= 10) return;
+ digitVal = 10 + c - 'A';
+ break;
+
+ default:
+ // Not more digits!
+ return;
+ }
+
+ if(digitVal >= base)
+ {
+ char buffer[] = { (char) c, 0 };
+ lexer->sink->diagnose(lexer->loc, Diagnostics::invalidDigitForBase, buffer, base);
+ }
+
+ advance(lexer);
+ }
+ }
+
+ static TokenType maybeLexNumberSuffix(Lexer* lexer, TokenType tokenType)
+ {
+ // First check for suffixes that
+ // indicate a floating-point number
+ switch(peek(lexer))
+ {
+ case 'f': case 'F':
+ advance(lexer);
+ return TokenType::DoubleLiterial;
+
+ default:
+ break;
+ }
+
+ // Once we've ruled out floating-point
+ // suffixes, we can check for the inter cases
+
+ // TODO: allow integer suffixes in any order...
+
+ // Leading `u` or `U` for unsigned
+ switch(peek(lexer))
+ {
+ default:
+ break;
+
+ case 'u': case 'U':
+ advance(lexer);
+ break;
+ }
+
+ // Optional `l`, `L`, `ll`, or `LL`
+ switch(peek(lexer))
+ {
+ default:
+ break;
+
+ case 'l': case 'L':
+ advance(lexer);
+ switch(peek(lexer))
+ {
+ default:
+ break;
+
+ case 'l': case 'L':
+ advance(lexer);
+ break;
+ }
+ break;
+ }
+
+ return tokenType;
+ }
+
+ static bool maybeLexNumberExponent(Lexer* lexer, int base)
+ {
+ switch( peek(lexer) )
+ {
+ default:
+ return false;
+
+ case 'e': case 'E':
+ if(base != 10) return false;
+ advance(lexer);
+ break;
+
+ case 'p': case 'P':
+ if(base != 16) return false;
+ advance(lexer);
+ break;
+ }
+
+ // we saw an exponent marker, so we must
+ switch( peek(lexer) )
+ {
+ case '+': case '-':
+ advance(lexer);
+ break;
+ }
+
+ // TODO(tfoley): it would be an error to not see digits here...
+
+ lexDigits(lexer, 10);
+
+ return true;
+ }
+
+ static TokenType lexNumberAfterDecimalPoint(Lexer* lexer, int base)
+ {
+ lexDigits(lexer, base);
+ maybeLexNumberExponent(lexer, base);
+
+ return maybeLexNumberSuffix(lexer, TokenType::DoubleLiterial);
+ }
+
+ static TokenType lexNumber(Lexer* lexer, int base)
+ {
+ // TODO(tfoley): Need to consider whehter to allow any kind of digit separator character.
+
+ TokenType tokenType = TokenType::IntLiterial;
+
+ // At the start of things, we just concern ourselves with digits
+ lexDigits(lexer, base);
+
+ if( peek(lexer) == '.' )
+ {
+ tokenType = TokenType::DoubleLiterial;
+
+ advance(lexer);
+ lexDigits(lexer, base);
+ }
+
+ if( maybeLexNumberExponent(lexer, base))
+ {
+ tokenType = TokenType::DoubleLiterial;
+ }
+
+ maybeLexNumberSuffix(lexer, tokenType);
+ return tokenType;
+ }
+
+ static void lexStringLiteralBody(Lexer* lexer, char quote)
+ {
+ for(;;)
+ {
+ int c = peek(lexer);
+ if(c == quote)
+ {
+ advance(lexer);
+ return;
+ }
+
+ switch(c)
+ {
+ case kEOF:
+ lexer->sink->diagnose(lexer->loc, Diagnostics::endOfFileInLiteral);
+ return;
+
+ case '\n': case '\r':
+ lexer->sink->diagnose(lexer->loc, Diagnostics::newlineInLiteral);
+ return;
+
+ case '\\':
+ // Need to handle various escape sequence cases
+ advance(lexer);
+ switch(peek(lexer))
+ {
+ case '\'':
+ case '\"':
+ case '\\':
+ case '?':
+ case 'a':
+ case 'b':
+ case 'f':
+ case 'n':
+ case 'r':
+ case 't':
+ case 'v':
+ advance(lexer);
+ break;
+
+ case '0': case '1': case '2': case '3': case '4':
+ case '5': case '6': case '7':
+ // octal escape: up to 3 characters
+ advance(lexer);
+ for(int ii = 0; ii < 3; ++ii)
+ {
+ int d = peek(lexer);
+ if(('0' <= d) && (d <= '7'))
+ {
+ advance(lexer);
+ continue;
+ }
+ else
+ {
+ break;
+ }
+ }
+ break;
+
+ case 'x':
+ // hexadecimal escape: any number of characters
+ advance(lexer);
+ for(;;)
+ {
+ int d = peek(lexer);
+ if(('0' <= d) && (d <= '9')
+ || ('a' <= d) && (d <= 'f')
+ || ('A' <= d) && (d <= 'F'))
+ {
+ advance(lexer);
+ continue;
+ }
+ else
+ {
+ break;
+ }
+ }
+ break;
+
+ // TODO: Unicode escape sequences
+
+ }
+ break;
+
+ default:
+ advance(lexer);
+ continue;
+ }
+ }
+ }
+
+ String getStringLiteralTokenValue(Token const& token)
+ {
+ assert(token.Type == TokenType::StringLiterial
+ || token.Type == TokenType::CharLiterial);
+
+ char const* cursor = token.Content.begin();
+ char const* end = token.Content.end();
+
+ auto quote = *cursor++;
+ assert(quote == '\'' || quote == '"');
+
+ StringBuilder valueBuilder;
+ for(;;)
+ {
+ assert(cursor != end);
+
+ auto c = *cursor++;
+
+ // If we see a closing quote, then we are at the end of the string literal
+ if(c == quote)
+ {
+ assert(cursor == end);
+ return valueBuilder.ProduceString();
+ }
+
+ // Charcters that don't being escape sequences are easy;
+ // just append them to the buffer and move on.
+ if(c != '\\')
+ {
+ valueBuilder.Append(c);
+ continue;
+ }
+
+ // Now we look at another character to figure out the kind of
+ // escape sequence we are dealing with:
+
+ int d = *cursor++;
+
+ switch(d)
+ {
+ // Simple characters that just needed to be escaped
+ case '\'':
+ case '\"':
+ case '\\':
+ case '?':
+ valueBuilder.Append(d);
+ continue;
+
+ // Traditional escape sequences for special characters
+ case 'a': valueBuilder.Append('\a'); continue;
+ case 'b': valueBuilder.Append('\b'); continue;
+ case 'f': valueBuilder.Append('\f'); continue;
+ case 'n': valueBuilder.Append('\n'); continue;
+ case 'r': valueBuilder.Append('\r'); continue;
+ case 't': valueBuilder.Append('\t'); continue;
+ case 'v': valueBuilder.Append('\v'); continue;
+
+ // Octal escape: up to 3 characterws
+ case '0': case '1': case '2': case '3': case '4':
+ case '5': case '6': case '7':
+ {
+ cursor--;
+ int value = 0;
+ for(int ii = 0; ii < 3; ++ii)
+ {
+ d = *cursor;
+ if(('0' <= d) && (d <= '7'))
+ {
+ value = value*8 + (d - '0');
+
+ cursor++;
+ continue;
+ }
+ else
+ {
+ break;
+ }
+ }
+
+ // TODO: add support for appending an arbitrary code point?
+ valueBuilder.Append((char) value);
+ }
+ continue;
+
+ // Hexadecimal escape: any number of characters
+ case 'x':
+ {
+ cursor--;
+ int value = 0;
+ for(;;)
+ {
+ d = *cursor++;
+ int digitValue = 0;
+ if(('0' <= d) && (d <= '9'))
+ {
+ digitValue = d - '0';
+ }
+ else if( ('a' <= d) && (d <= 'f') )
+ {
+ digitValue = d - 'a';
+ }
+ else if( ('A' <= d) && (d <= 'F') )
+ {
+ digitValue = d - 'A';
+ }
+ else
+ {
+ cursor--;
+ break;
+ }
+
+ value = value*16 + digitValue;
+ }
+
+ // TODO: add support for appending an arbitrary code point?
+ valueBuilder.Append((char) value);
+ }
+ continue;
+
+ // TODO: Unicode escape sequences
+
+ }
+ }
+ }
+
+ String getFileNameTokenValue(Token const& token)
+ {
+ // A file name usually doesn't process escape sequences
+ // (this is import on Windows, where `\\` is a valid
+ // path separator cahracter).
+
+ // Just trim off the first and last characters to remove the quotes
+ // (whether they were `""` or `<>`.
+ return token.Content.SubString(1, token.Content.Length()-2);
+ }
+
+
+
+ static TokenType lexTokenImpl(Lexer* lexer)
+ {
+ switch(peek(lexer))
+ {
+ default:
+ break;
+
+ case kEOF:
+ if((lexer->lexerFlags & kLexerFlag_InDirective) != 0)
+ return TokenType::EndOfDirective;
+ return TokenType::EndOfFile;
+
+ case '\r': case '\n':
+ if((lexer->lexerFlags & kLexerFlag_InDirective) != 0)
+ return TokenType::EndOfDirective;
+ handleNewLine(lexer);
+ return TokenType::NewLine;
+
+ case ' ': case '\t':
+ lexHorizontalSpace(lexer);
+ return TokenType::WhiteSpace;
+
+ case '.':
+ advance(lexer);
+ switch(peek(lexer))
+ {
+ case '0': case '1': case '2': case '3': case '4':
+ case '5': case '6': case '7': case '8': case '9':
+ return lexNumberAfterDecimalPoint(lexer, 10);
+
+ // TODO(tfoley): handle ellipsis (`...`)
+
+ default:
+ return TokenType::Dot;
+ }
+
+ case '1': case '2': case '3': case '4':
+ case '5': case '6': case '7': case '8': case '9':
+ return lexNumber(lexer, 10);
+
+ case '0':
+ {
+ auto loc = lexer->loc;
+ advance(lexer);
+ switch(peek(lexer))
+ {
+ default:
+ return TokenType::IntLiterial;
+
+ case '.':
+ advance(lexer);
+ return lexNumberAfterDecimalPoint(lexer, 10);
+
+ case 'x': case 'X':
+ advance(lexer);
+ return lexNumber(lexer, 16);
+
+ case 'b': case 'B':
+ advance(lexer);
+ return lexNumber(lexer, 2);
+
+ case '0': case '1': case '2': case '3': case '4':
+ case '5': case '6': case '7': case '8': case '9':
+ lexer->sink->diagnose(loc, Diagnostics::octalLiteral);
+ return lexNumber(lexer, 8);
+ }
+ }
+
+ case 'a': case 'b': case 'c': case 'd': case 'e':
+ case 'f': case 'g': case 'h': case 'i': case 'j':
+ case 'k': case 'l': case 'm': case 'n': case 'o':
+ case 'p': case 'q': case 'r': case 's': case 't':
+ case 'u': case 'v': case 'w': case 'x': case 'y':
+ case 'z':
+ case 'A': case 'B': case 'C': case 'D': case 'E':
+ case 'F': case 'G': case 'H': case 'I': case 'J':
+ case 'K': case 'L': case 'M': case 'N': case 'O':
+ case 'P': case 'Q': case 'R': case 'S': case 'T':
+ case 'U': case 'V': case 'W': case 'X': case 'Y':
+ case 'Z':
+ case '_':
+ lexIdentifier(lexer);
+ return TokenType::Identifier;
+
+ case '\"':
+ advance(lexer);
+ lexStringLiteralBody(lexer, '\"');
+ return TokenType::StringLiterial;
+
+ case '\'':
+ advance(lexer);
+ lexStringLiteralBody(lexer, '\'');
+ return TokenType::CharLiterial;
+
+ case '+':
+ advance(lexer);
+ switch(peek(lexer))
+ {
+ case '+': advance(lexer); return TokenType::OpInc;
+ case '=': advance(lexer); return TokenType::OpAddAssign;
+ default:
+ return TokenType::OpAdd;
+ }
+
+ case '-':
+ advance(lexer);
+ switch(peek(lexer))
+ {
+ case '-': advance(lexer); return TokenType::OpDec;
+ case '=': advance(lexer); return TokenType::OpSubAssign;
+ case '>': advance(lexer); return TokenType::RightArrow;
+ default:
+ return TokenType::OpSub;
+ }
+
+ case '*':
+ advance(lexer);
+ switch(peek(lexer))
+ {
+ case '=': advance(lexer); return TokenType::OpMulAssign;
+ default:
+ return TokenType::OpMul;
+ }
+
+ case '/':
+ advance(lexer);
+ switch(peek(lexer))
+ {
+ case '=': advance(lexer); return TokenType::OpDivAssign;
+ case '/': advance(lexer); lexLineComment(lexer); return TokenType::LineComment;
+ case '*': advance(lexer); lexBlockComment(lexer); return TokenType::BlockComment;
+ default:
+ return TokenType::OpDiv;
+ }
+
+ case '%':
+ advance(lexer);
+ switch(peek(lexer))
+ {
+ case '=': advance(lexer); return TokenType::OpModAssign;
+ default:
+ return TokenType::OpMod;
+ }
+
+ case '|':
+ advance(lexer);
+ switch(peek(lexer))
+ {
+ case '|': advance(lexer); return TokenType::OpOr;
+ case '=': advance(lexer); return TokenType::OpOrAssign;
+ default:
+ return TokenType::OpBitOr;
+ }
+
+ case '&':
+ advance(lexer);
+ switch(peek(lexer))
+ {
+ case '&': advance(lexer); return TokenType::OpAnd;
+ case '=': advance(lexer); return TokenType::OpAndAssign;
+ default:
+ return TokenType::OpBitAnd;
+ }
+
+ case '^':
+ advance(lexer);
+ switch(peek(lexer))
+ {
+ case '=': advance(lexer); return TokenType::OpXorAssign;
+ default:
+ return TokenType::OpBitXor;
+ }
+
+ case '>':
+ advance(lexer);
+ switch(peek(lexer))
+ {
+ case '>':
+ advance(lexer);
+ switch(peek(lexer))
+ {
+ case '=': advance(lexer); return TokenType::OpShrAssign;
+ default: return TokenType::OpRsh;
+ }
+ case '=': advance(lexer); return TokenType::OpGeq;
+ default:
+ return TokenType::OpGreater;
+ }
+
+ case '<':
+ advance(lexer);
+ switch(peek(lexer))
+ {
+ case '<':
+ advance(lexer);
+ switch(peek(lexer))
+ {
+ case '=': advance(lexer); return TokenType::OpShlAssign;
+ default: return TokenType::OpLsh;
+ }
+ case '=': advance(lexer); return TokenType::OpLeq;
+ default:
+ return TokenType::OpLess;
+ }
+
+ case '=':
+ advance(lexer);
+ switch(peek(lexer))
+ {
+ case '=': advance(lexer); return TokenType::OpEql;
+ default:
+ return TokenType::OpAssign;
+ }
+
+ case '!':
+ advance(lexer);
+ switch(peek(lexer))
+ {
+ case '=': advance(lexer); return TokenType::OpNeq;
+ default:
+ return TokenType::OpNot;
+ }
+
+ case '#':
+ advance(lexer);
+ switch(peek(lexer))
+ {
+ case '#': advance(lexer); return TokenType::PoundPound;
+ default:
+ return TokenType::Pound;
+ }
+
+ case '~': advance(lexer); return TokenType::OpBitNot;
+
+ case ':': advance(lexer); return TokenType::Colon;
+ case ';': advance(lexer); return TokenType::Semicolon;
+ case ',': advance(lexer); return TokenType::Comma;
+
+ case '{': advance(lexer); return TokenType::LBrace;
+ case '}': advance(lexer); return TokenType::RBrace;
+ case '[': advance(lexer); return TokenType::LBracket;
+ case ']': advance(lexer); return TokenType::RBracket;
+ case '(': advance(lexer); return TokenType::LParent;
+ case ')': advance(lexer); return TokenType::RParent;
+
+ case '?': advance(lexer); return TokenType::QuestionMark;
+ case '@': advance(lexer); return TokenType::At;
+ case '$': advance(lexer); return TokenType::Dollar;
+
+ }
+
+ // TODO(tfoley): If we ever wanted to support proper Unicode
+ // in identifiers, etc., then this would be the right place
+ // to perform a more expensive dispatch based on the actual
+ // code point (and not just the first byte).
+
+ {
+ // If none of the above cases matched, then we have an
+ // unexpected/invalid character.
+
+ auto loc = lexer->loc;
+ auto sink = lexer->sink;
+ int c = advance(lexer);
+ if(c >= 0x20 && c <= 0x7E)
+ {
+ char buffer[] = { (char) c, 0 };
+ sink->diagnose(loc, Diagnostics::illegalCharacterPrint, buffer);
+ }
+ else
+ {
+ // Fallback: print as hexadecimal
+ sink->diagnose(loc, Diagnostics::illegalCharacterHex, String((unsigned char)c, 16));
+ }
+
+ return TokenType::Invalid;
+ }
+ }
+
+ Token Lexer::lexToken()
+ {
+ auto flags = this->tokenFlags;
+ for(;;)
+ {
+ Token token;
+ token.Position = loc;
+
+ char const* textBegin = cursor;
+
+ auto tokenType = lexTokenImpl(this);
+
+ // The low-level lexer produces tokens for things we want
+ // to ignore, such as white space, so we skip them here.
+ switch(tokenType)
+ {
+ case TokenType::Invalid:
+ flags = 0;
+ continue;
+
+ case TokenType::NewLine:
+ flags = TokenFlag::AtStartOfLine | TokenFlag::AfterWhitespace;
+ continue;
+
+ case TokenType::WhiteSpace:
+ case TokenType::LineComment:
+ case TokenType::BlockComment:
+ flags |= TokenFlag::AfterWhitespace;
+ continue;
+
+ // We don't want to skip the end-of-file token, but we *do*
+ // want to make sure it has appropriate flags to make our life easier
+ case TokenType::EndOfFile:
+ flags = TokenFlag::AtStartOfLine | TokenFlag::AfterWhitespace;
+ break;
+
+ // We will also do some book-keeping around preprocessor directives here:
+ //
+ // If we see a `#` at the start of a line, then we are entering a
+ // preprocessor directive.
+ case TokenType::Pound:
+ if((flags & TokenFlag::AtStartOfLine) != 0)
+ lexerFlags |= kLexerFlag_InDirective;
+ break;
+ //
+ // And if we saw an end-of-line during a directive, then we are
+ // now leaving that directive.
+ //
+ case TokenType::EndOfDirective:
+ lexerFlags &= ~kLexerFlag_InDirective;
+ break;
+
+ default:
+ break;
+ }
+
+ token.Type = tokenType;
+
+ char const* textEnd = cursor;
+
+ // Note(tfoley): `StringBuilder::Append()` seems to crash when appending zero bytes
+ if(textEnd != textBegin)
+ {
+ StringBuilder valueBuilder;
+ valueBuilder.Append(textBegin, int(textEnd - textBegin));
+ token.Content = valueBuilder.ProduceString();
+ }
+
+ token.flags = flags;
+
+ this->tokenFlags = 0;
+
+ return token;
+ }
+ }
+
+ TokenList Lexer::lexAllTokens()
+ {
+ TokenList tokenList;
+ for(;;)
+ {
+ Token token = lexToken();
+ tokenList.mTokens.Add(token);
+
+ if(token.Type == TokenType::EndOfFile)
+ return tokenList;
+ }
+ }
+
+
+
+#if 0
+ TokenList Lexer::Parse(const String & fileName, const String & str, DiagnosticSink * sink)
+ {
+ TokenList tokenList;
+ tokenList.mTokens = TokenizeText(fileName, str, [&](TokenizeErrorType errType, CodePosition pos)
+ {
+ auto curChar = str[pos.Pos];
+ switch (errType)
+ {
+ case TokenizeErrorType::InvalidCharacter:
+ // Check if inside the ASCII "printable" range
+ if(curChar >= 0x20 && curChar <= 0x7E)
+ {
+ char buffer[] = { curChar, 0 };
+ sink->diagnose(pos, Diagnostics::illegalCharacterPrint, buffer);
+ }
+ else
+ {
+ // Fallback: print as hexadecimal
+ sink->diagnose(pos, Diagnostics::illegalCharacterHex, String((unsigned char)curChar, 16));
+ }
+ break;
+ case TokenizeErrorType::InvalidEscapeSequence:
+ sink->diagnose(pos, Diagnostics::illegalCharacterLiteral);
+ break;
+ default:
+ break;
+ }
+ });
+
+ // Add an end-of-file token so that we can reference it in diagnostic messages
+ tokenList.mTokens.Add(Token(TokenType::EndOfFile, "", 0, 0, 0, fileName, TokenFlag::AtStartOfLine | TokenFlag::AfterWhitespace));
+ return tokenList;
+ }
+#endif
+ }
+} \ No newline at end of file
diff --git a/source/slang/lexer.h b/source/slang/lexer.h
new file mode 100644
index 000000000..d11e92d84
--- /dev/null
+++ b/source/slang/lexer.h
@@ -0,0 +1,101 @@
+#ifndef RASTER_RENDERER_LEXER_H
+#define RASTER_RENDERER_LEXER_H
+
+#include "../core/basic.h"
+#include "diagnostics.h"
+
+namespace Slang
+{
+ namespace Compiler
+ {
+ using namespace CoreLib::Basic;
+
+ struct TokenList
+ {
+ Token* begin() const;
+ Token* end() const;
+
+ List<Token> mTokens;
+ };
+
+ struct TokenSpan
+ {
+ TokenSpan();
+ TokenSpan(
+ TokenList const& tokenList)
+ : mBegin(tokenList.begin())
+ , mEnd (tokenList.end ())
+ {}
+
+ Token* begin() const { return mBegin; }
+ Token* end () const { return mEnd ; }
+
+ int GetCount() { return (int)(mEnd - mBegin); }
+
+ Token* mBegin;
+ Token* mEnd;
+ };
+
+ struct TokenReader
+ {
+ TokenReader();
+ explicit TokenReader(TokenSpan const& tokens)
+ : mCursor(tokens.begin())
+ , mEnd (tokens.end ())
+ {}
+ explicit TokenReader(TokenList const& tokens)
+ : mCursor(tokens.begin())
+ , mEnd (tokens.end ())
+ {}
+
+ bool IsAtEnd() const { return mCursor == mEnd; }
+ Token PeekToken() const;
+ TokenType PeekTokenType() const;
+ CodePosition PeekLoc() const;
+
+ Token AdvanceToken();
+
+ int GetCount() { return (int)(mEnd - mCursor); }
+
+ Token* mCursor;
+ Token* mEnd;
+ };
+
+ typedef unsigned int LexerFlags;
+ enum
+ {
+ kLexerFlag_InDirective = 1 << 0,
+ kLexerFlag_ExpectFileName = 2 << 0,
+ };
+
+ struct Lexer
+ {
+ Lexer(
+ String const& path,
+ String const& content,
+ DiagnosticSink* sink);
+
+ ~Lexer();
+
+ Token lexToken();
+
+ TokenList lexAllTokens();
+
+ String path;
+ String content;
+ DiagnosticSink* sink;
+
+ char const* cursor;
+ char const* end;
+ CodePosition loc;
+ TokenFlags tokenFlags;
+ LexerFlags lexerFlags;
+ };
+
+ // Helper routines for extracting values from tokens
+ String getStringLiteralTokenValue(Token const& token);
+ String getFileNameTokenValue(Token const& token);
+ }
+}
+
+#endif \ No newline at end of file
diff --git a/source/slang/lookup.cpp b/source/slang/lookup.cpp
new file mode 100644
index 000000000..9731b1c8a
--- /dev/null
+++ b/source/slang/lookup.cpp
@@ -0,0 +1,311 @@
+// lookup.cpp
+#include "lookup.h"
+
+namespace Slang {
+namespace Compiler {
+
+//
+
+// Helper for constructing breadcrumb trails during lookup, without unnecessary heap allocaiton
+struct BreadcrumbInfo
+{
+ LookupResultItem::Breadcrumb::Kind kind;
+ DeclRef declRef;
+ BreadcrumbInfo* prev = nullptr;
+};
+
+void DoLocalLookupImpl(
+ String const& name,
+ ContainerDeclRef containerDeclRef,
+ LookupRequest const& request,
+ LookupResult& result,
+ BreadcrumbInfo* inBreadcrumbs);
+
+//
+
+void buildMemberDictionary(ContainerDecl* decl)
+{
+ // Don't rebuild if already built
+ if (decl->memberDictionaryIsValid)
+ return;
+
+ decl->memberDictionary.Clear();
+ decl->transparentMembers.Clear();
+
+ for (auto m : decl->Members)
+ {
+ auto name = m->Name.Content;
+
+ // Add any transparent members to a separate list for lookup
+ if (m->HasModifier<TransparentModifier>())
+ {
+ TransparentMemberInfo info;
+ info.decl = m.Ptr();
+ decl->transparentMembers.Add(info);
+ }
+
+ // Ignore members with an empty name
+ if (name.Length() == 0)
+ continue;
+
+ m->nextInContainerWithSameName = nullptr;
+
+ Decl* next = nullptr;
+ if (decl->memberDictionary.TryGetValue(name, next))
+ m->nextInContainerWithSameName = next;
+
+ decl->memberDictionary[name] = m.Ptr();
+
+ }
+ decl->memberDictionaryIsValid = true;
+}
+
+
+bool DeclPassesLookupMask(Decl* decl, LookupMask mask)
+{
+ // type declarations
+ if(auto aggTypeDecl = dynamic_cast<AggTypeDecl*>(decl))
+ {
+ return int(mask) & int(LookupMask::Type);
+ }
+ else if(auto simpleTypeDecl = dynamic_cast<SimpleTypeDecl*>(decl))
+ {
+ return int(mask) & int(LookupMask::Type);
+ }
+ // function declarations
+ else if(auto funcDecl = dynamic_cast<FunctionDeclBase*>(decl))
+ {
+ return (int(mask) & int(LookupMask::Function)) != 0;
+ }
+
+ // default behavior is to assume a value declaration
+ // (no overloading allowed)
+
+ return (int(mask) & int(LookupMask::Value)) != 0;
+}
+
+void AddToLookupResult(
+ LookupResult& result,
+ LookupResultItem item)
+{
+ if (!result.isValid())
+ {
+ // If we hadn't found a hit before, we have one now
+ result.item = item;
+ }
+ else if (!result.isOverloaded())
+ {
+ // We are about to make this overloaded
+ result.items.Add(result.item);
+ result.items.Add(item);
+ }
+ else
+ {
+ // The result was already overloaded, so we pile on
+ result.items.Add(item);
+ }
+}
+
+LookupResult refineLookup(LookupResult const& inResult, LookupMask mask)
+{
+ if (!inResult.isValid()) return inResult;
+ if (!inResult.isOverloaded()) return inResult;
+
+ LookupResult result;
+ for (auto item : inResult.items)
+ {
+ if (!DeclPassesLookupMask(item.declRef.GetDecl(), mask))
+ continue;
+
+ AddToLookupResult(result, item);
+ }
+ return result;
+}
+
+LookupResultItem CreateLookupResultItem(
+ DeclRef declRef,
+ BreadcrumbInfo* breadcrumbInfos)
+{
+ LookupResultItem item;
+ item.declRef = declRef;
+
+ // breadcrumbs were constructed "backwards" on the stack, so we
+ // reverse them here by building a linked list the other way
+ RefPtr<LookupResultItem::Breadcrumb> breadcrumbs;
+ for (auto bb = breadcrumbInfos; bb; bb = bb->prev)
+ {
+ breadcrumbs = new LookupResultItem::Breadcrumb(
+ bb->kind,
+ bb->declRef,
+ breadcrumbs);
+ }
+ item.breadcrumbs = breadcrumbs;
+ return item;
+}
+
+void DoMemberLookupImpl(
+ String const& name,
+ RefPtr<ExpressionType> baseType,
+ LookupRequest const& request,
+ LookupResult& ioResult,
+ BreadcrumbInfo* breadcrumbs)
+{
+ // If the type was pointer-like, then dereference it
+ // automatically here.
+ if (auto pointerLikeType = baseType->As<PointerLikeType>())
+ {
+ // Need to leave a breadcrumb to indicate that we
+ // did an implicit dereference here
+ BreadcrumbInfo derefBreacrumb;
+ derefBreacrumb.kind = LookupResultItem::Breadcrumb::Kind::Deref;
+ derefBreacrumb.prev = breadcrumbs;
+
+ // Recursively perform lookup on the result of deref
+ return DoMemberLookupImpl(name, pointerLikeType->elementType, request, ioResult, &derefBreacrumb);
+ }
+
+ // Default case: no dereference needed
+
+ if (auto baseDeclRefType = baseType->As<DeclRefType>())
+ {
+ if (auto baseAggTypeDeclRef = baseDeclRefType->declRef.As<AggTypeDeclRef>())
+ {
+ DoLocalLookupImpl(name, baseAggTypeDeclRef, request, ioResult, breadcrumbs);
+ }
+ }
+
+ // TODO(tfoley): any other cases to handle here?
+}
+
+void DoMemberLookupImpl(
+ String const& name,
+ DeclRef baseDeclRef,
+ LookupRequest const& request,
+ LookupResult& ioResult,
+ BreadcrumbInfo* breadcrumbs)
+{
+ auto baseType = getTypeForDeclRef(baseDeclRef);
+ return DoMemberLookupImpl(name, baseType, request, ioResult, breadcrumbs);
+}
+
+// Look for members of the given name in the given container for declarations
+void DoLocalLookupImpl(
+ String const& name,
+ ContainerDeclRef containerDeclRef,
+ LookupRequest const& request,
+ LookupResult& result,
+ BreadcrumbInfo* inBreadcrumbs)
+{
+ ContainerDecl* containerDecl = containerDeclRef.GetDecl();
+
+ // Ensure that the lookup dictionary in the container is up to date
+ if (!containerDecl->memberDictionaryIsValid)
+ {
+ buildMemberDictionary(containerDecl);
+ }
+
+ // Look up the declarations with the chosen name in the container.
+ Decl* firstDecl = nullptr;
+ containerDecl->memberDictionary.TryGetValue(name, firstDecl);
+
+ // Now iterate over those declarations (if any) and see if
+ // we find any that meet our filtering criteria.
+ // For example, we might be filtering so that we only consider
+ // type declarations.
+ for (auto m = firstDecl; m; m = m->nextInContainerWithSameName)
+ {
+ if (!DeclPassesLookupMask(m, request.mask))
+ continue;
+
+ // The declaration passed the test, so add it!
+ AddToLookupResult(result, CreateLookupResultItem(DeclRef(m, containerDeclRef.substitutions), inBreadcrumbs));
+ }
+
+
+ // TODO(tfoley): should we look up in the transparent decls
+ // if we already has a hit in the current container?
+
+ for(auto transparentInfo : containerDecl->transparentMembers)
+ {
+ // The reference to the transparent member should use whatever
+ // substitutions we used in referring to its outer container
+ DeclRef transparentMemberDeclRef(transparentInfo.decl, containerDeclRef.substitutions);
+
+ // We need to leave a breadcrumb so that we know that the result
+ // of lookup involves a member lookup step here
+
+ BreadcrumbInfo memberRefBreadcrumb;
+ memberRefBreadcrumb.kind = LookupResultItem::Breadcrumb::Kind::Member;
+ memberRefBreadcrumb.declRef = transparentMemberDeclRef;
+ memberRefBreadcrumb.prev = inBreadcrumbs;
+
+ DoMemberLookupImpl(name, transparentMemberDeclRef, request, result, &memberRefBreadcrumb);
+ }
+
+ // TODO(tfoley): need to consider lookup via extension here?
+}
+
+void DoLookupImpl(
+ String const& name,
+ LookupRequest const& request,
+ LookupResult& result)
+{
+ auto scope = request.scope;
+ auto endScope = request.endScope;
+ for (;scope != endScope; scope = scope->parent)
+ {
+ // Note that we consider all "peer" scopes together,
+ // so that a hit in one of them does not proclude
+ // also finding a hit in another
+ for(auto link = scope; link; link = link->nextSibling)
+ {
+ if(!link->containerDecl)
+ continue;
+
+ ContainerDeclRef containerRef = DeclRef(link->containerDecl, nullptr).As<ContainerDeclRef>();
+ DoLocalLookupImpl(name, containerRef, request, result, nullptr);
+ }
+
+ if (result.isValid())
+ {
+ // If we've found a result in this scope, then there
+ // is no reason to look further up (for now).
+ return;
+ }
+ }
+
+ // If we run out of scopes, then we are done.
+}
+
+LookupResult DoLookup(String const& name, LookupRequest const& request)
+{
+ LookupResult result;
+ DoLookupImpl(name, request, result);
+ return result;
+}
+
+LookupResult LookUp(String const& name, RefPtr<Scope> scope)
+{
+ LookupRequest request;
+ request.scope = scope;
+ return DoLookup(name, request);
+}
+
+// perform lookup within the context of a particular container declaration,
+// and do *not* look further up the chain
+LookupResult LookUpLocal(String const& name, ContainerDeclRef containerDeclRef)
+{
+ LookupRequest request;
+ LookupResult result;
+ DoLocalLookupImpl(name, containerDeclRef, request, result, nullptr);
+ return result;
+}
+
+LookupResult LookUpLocal(String const& name, ContainerDecl* containerDecl)
+{
+ ContainerDeclRef containerRef = DeclRef(containerDecl, nullptr).As<ContainerDeclRef>();
+ return LookUpLocal(name, containerRef);
+}
+
+
+}}
diff --git a/source/slang/lookup.h b/source/slang/lookup.h
new file mode 100644
index 000000000..25b62738f
--- /dev/null
+++ b/source/slang/lookup.h
@@ -0,0 +1,41 @@
+#ifndef SLANG_LOOKUP_H_INCLUDED
+#define SLANG_LOOKUP_H_INCLUDED
+
+#include "Syntax.h"
+
+namespace Slang {
+namespace Compiler {
+
+// Take an existing lookup result and refine it to only include
+// results that pass the given `LookupMask`.
+LookupResult refineLookup(LookupResult const& inResult, LookupMask mask);
+
+// Ensure that the dictionary for name-based member lookup has been
+// built for the given container declaration.
+void buildMemberDictionary(ContainerDecl* decl);
+
+// Look up a name in the given scope, proceeding up through
+// parent scopes as needed.
+LookupResult LookUp(String const& name, RefPtr<Scope> scope);
+
+// perform lookup within the context of a particular container declaration,
+// and do *not* look further up the chain
+LookupResult LookUpLocal(String const& name, ContainerDeclRef containerDeclRef);
+LookupResult LookUpLocal(String const& name, ContainerDecl* containerDecl);
+
+// TODO: this belongs somewhere else
+
+class SemanticsVisitor;
+QualType getTypeForDeclRef(
+ SemanticsVisitor* sema,
+ DiagnosticSink* sink,
+ DeclRef declRef,
+ RefPtr<ExpressionType>* outTypeResult);
+
+QualType getTypeForDeclRef(
+ DeclRef declRef);
+
+
+}}
+
+#endif \ No newline at end of file
diff --git a/source/slang/parameter-binding.cpp b/source/slang/parameter-binding.cpp
new file mode 100644
index 000000000..8bbb566af
--- /dev/null
+++ b/source/slang/parameter-binding.cpp
@@ -0,0 +1,1252 @@
+// parameter-binding.cpp
+#include "parameter-binding.h"
+
+#include "lookup.h"
+#include "compiler.h"
+#include "type-layout.h"
+
+#include "../../slang.h"
+
+#define SLANG_EXHAUSTIVE_SWITCH() default: assert(!"unexpected"); break;
+
+namespace Slang {
+namespace Compiler {
+
+// Information on ranges of registers already claimed/used
+struct UsedRange
+{
+ int begin;
+ int end;
+};
+bool operator<(UsedRange left, UsedRange right)
+{
+ if (left.begin != right.begin)
+ return left.begin < right.begin;
+ if (left.end != right.end)
+ return left.end < right.end;
+ return false;
+}
+
+struct UsedRanges
+{
+ List<UsedRange> ranges;
+
+ // Add a range to the set, either by extending
+ // an existing range, or by adding a new one...
+ void Add(UsedRange const& range)
+ {
+ for (auto& rr : ranges)
+ {
+ if (rr.begin == range.end)
+ {
+ rr.begin = range.begin;
+ return;
+ }
+ else if (rr.end == range.begin)
+ {
+ rr.end = range.end;
+ return;
+ }
+ }
+ ranges.Add(range);
+ ranges.Sort();
+ }
+
+ void Add(int begin, int end)
+ {
+ UsedRange range;
+ range.begin = begin;
+ range.end = end;
+ Add(range);
+ }
+
+
+ // Try to find space for `count` entries
+ int Allocate(int count)
+ {
+ int begin = 0;
+
+ int rangeCount = ranges.Count();
+ for (int rr = 0; rr < rangeCount; ++rr)
+ {
+ // try to fit in before this range...
+
+ int end = ranges[rr].begin;
+
+ // If there is enough space...
+ if (end >= begin + count)
+ {
+ // ... then claim it and be done
+ Add(begin, begin + count);
+ return begin;
+ }
+
+ // ... otherwise, we need to look at the
+ // space between this range and the next
+ begin = ranges[rr].end;
+ }
+
+ // We've run out of ranges to check, so we
+ // can safely go after the last one!
+ Add(begin, begin + count);
+ return begin;
+ }
+};
+
+struct ParameterBindingInfo
+{
+ size_t space;
+ size_t index;
+ size_t count;
+};
+
+enum
+{
+ kLayoutResourceKindCount = SLANG_PARAMETER_CATEGORY_MIXED,
+};
+
+// Information on a single parameter
+struct ParameterInfo : RefObject
+{
+ // Layout info for the concrete variables that will make up this parameter
+ List<RefPtr<VarLayout>> varLayouts;
+
+ ParameterBindingInfo bindingInfo[kLayoutResourceKindCount];
+
+ // The next parameter that has the same name...
+ ParameterInfo* nextOfSameName;
+
+ ParameterInfo()
+ {
+ // Make sure we aren't claiming any resources yet
+ for( int ii = 0; ii < kLayoutResourceKindCount; ++ii )
+ {
+ bindingInfo[ii].count = 0;
+ }
+ }
+};
+
+// State that is shared during parameter binding,
+// across all translation units
+struct SharedParameterBindingContext
+{
+ LayoutRulesFamilyImpl* defaultLayoutRules;
+
+ // All shader parameters we've discovered so far, and started to lay out...
+ List<RefPtr<ParameterInfo>> parameters;
+
+ // A dictionary to accellerate looking up parameters by name
+ Dictionary<String, ParameterInfo*> mapNameToParameterInfo;
+
+ // The program layout we are trying to construct
+ RefPtr<ProgramLayout> programLayout;
+
+ // The source language we are trying to use
+ SourceLanguage sourceLanguage;
+
+ // Information on what ranges of "registers" have already
+ // been claimed, for each resource type
+ UsedRanges usedResourceRanges[kLayoutResourceKindCount];
+};
+
+// State that might be specific to a single translation unit
+// or event to an entry point.
+struct ParameterBindingContext
+{
+ // All the shared state needs to be available
+ SharedParameterBindingContext* shared;
+
+ // The layout rules to use while computing usage...
+ LayoutRulesFamilyImpl* layoutRules;
+
+ // What stage (if any) are we compiling for?
+ Stage stage;
+};
+
+struct LayoutSemanticInfo
+{
+ LayoutResourceKind kind; // the register kind
+ int space;
+ int index;
+
+ // TODO: need to deal with component-granularity binding...
+};
+
+LayoutSemanticInfo ExtractLayoutSemanticInfo(
+ ParameterBindingContext* /*context*/,
+ HLSLLayoutSemantic* semantic)
+{
+ LayoutSemanticInfo info;
+ info.space = 0;
+ info.index = 0;
+ info.kind = LayoutResourceKind::None;
+
+ auto registerName = semantic->registerName.Content;
+ if (registerName.Length() == 0)
+ return info;
+
+ LayoutResourceKind kind = LayoutResourceKind::None;
+ switch (registerName[0])
+ {
+ case 'b':
+ kind = LayoutResourceKind::ConstantBuffer;
+ break;
+
+ case 't':
+ kind = LayoutResourceKind::ShaderResource;
+ break;
+
+ case 'u':
+ kind = LayoutResourceKind::UnorderedAccess;
+ break;
+
+ case 's':
+ kind = LayoutResourceKind::SamplerState;
+ break;
+
+ default:
+ // TODO: issue an error here!
+ return info;
+ }
+
+ // TODO: need to parse and handle `space` binding
+ int space = 0;
+
+ int index = 0;
+ for (int ii = 1; ii < registerName.Length(); ++ii)
+ {
+ int c = registerName[ii];
+ if (c >= '0' && c <= '9')
+ {
+ index = index * 10 + (c - '0');
+ }
+ else
+ {
+ // TODO: issue an error here!
+ return info;
+ }
+ }
+
+ // TODO: handle component mask part of things...
+
+ info.kind = kind;
+ info.index = index;
+ info.space = space;
+ return info;
+}
+
+static bool doesParameterMatch(
+ ParameterBindingContext* context,
+ RefPtr<VarLayout> varLayout,
+ ParameterInfo* parameterInfo)
+{
+ // TODO: need to implement this eventually
+ return true;
+}
+
+//
+
+// Given a GLSL `layout` modifier, we need to be able to check for
+// a particular sub-argument and extract its value if present.
+template<typename T>
+static bool findLayoutArg(
+ RefPtr<ModifiableSyntaxNode> syntax,
+ int* outVal)
+{
+ for( auto modifier : syntax->GetModifiersOfType<T>() )
+ {
+ *outVal = (int) strtol(modifier->valToken.Content.Buffer(), nullptr, 10);
+ return true;
+ }
+ return false;
+}
+
+template<typename T>
+static bool findLayoutArg(
+ DeclRef declRef,
+ int* outVal)
+{
+ return findLayoutArg<T>(declRef.GetDecl(), outVal);
+}
+
+//
+
+RefPtr<TypeLayout>
+getTypeLayoutForGlobalShaderParameter_GLSL(
+ ParameterBindingContext* context,
+ VarDeclBase* varDecl)
+{
+ auto rules = context->layoutRules;
+ auto type = varDecl->getType();
+
+ // A GLSL shader parameter will be marked with
+ // a qualifier to match the boundary it uses
+ //
+ // In the case of a parameter block, we will have
+ // consumed this qualifier as part of parsing,
+ // so that it won't be present on the declaration
+ // any more. As such we also inspect the type
+ // of the variable.
+
+ // TODO(tfoley): We have multiple variations of
+ // the `uniform` modifier right now, and that
+ // needs to get fixed...
+ if(varDecl->HasModifier<HLSLUniformModifier>() || type->As<ConstantBufferType>())
+ return CreateTypeLayout(type, rules->getConstantBufferRules());
+
+ if(varDecl->HasModifier<GLSLBufferModifier>() || type->As<GLSLShaderStorageBufferType>())
+ return CreateTypeLayout(type, rules->getShaderStorageBufferRules());
+
+ if( varDecl->HasModifier<InModifier>() || type->As<GLSLInputParameterBlockType>())
+ {
+ // Special case to handle "arrayed" shader inputs, as used
+ // for Geometry and Hull input
+ switch( context->stage )
+ {
+ case Stage::Geometry:
+ case Stage::Hull:
+ case Stage::Domain:
+ // Tessellation `patch` variables should stay as written
+ if( !varDecl->HasModifier<GLSLPatchModifier>() )
+ {
+ // Unwrap array type, if prsent
+ if( auto arrayType = type->As<ArrayExpressionType>() )
+ {
+ type = arrayType->BaseType.Ptr();
+ }
+ }
+ break;
+
+ default:
+ break;
+ }
+
+ return CreateTypeLayout(type, rules->getVaryingInputRules());
+ }
+
+ if( varDecl->HasModifier<OutModifier>() || type->As<GLSLOutputParameterBlockType>())
+ {
+ // Special case to handle "arrayed" shader outputs, as used
+ // for Hull Shader output
+ //
+ // Note(tfoley): there is unfortunate code duplication
+ // with the `in` case above.
+ switch( context->stage )
+ {
+ case Stage::Hull:
+ // Tessellation `patch` variables should stay as written
+ if( !varDecl->HasModifier<GLSLPatchModifier>() )
+ {
+ // Unwrap array type, if prsent
+ if( auto arrayType = type->As<ArrayExpressionType>() )
+ {
+ type = arrayType->BaseType.Ptr();
+ }
+ }
+ break;
+
+ default:
+ break;
+ }
+
+ return CreateTypeLayout(type, rules->getVaryingOutputRules());
+ }
+
+ // A `const` global with a `layout(constant_id = ...)` modifier
+ // is a declaration of a specialization constant.
+ if(varDecl->HasModifier<GLSLConstantIDLayoutModifier>())
+ return CreateTypeLayout(type, rules->getSpecializationConstantRules());
+
+ // GLSL says that an "ordinary" global variable
+ // is just a (thread local) global and not a
+ // parameter
+ return nullptr;
+}
+
+RefPtr<TypeLayout>
+getTypeLayoutForGlobalShaderParameter_HLSL(
+ ParameterBindingContext* context,
+ VarDeclBase* varDecl)
+{
+ auto rules = context->layoutRules;
+ auto type = varDecl->getType();
+
+ // HLSL `static` modifier indicates "thread local"
+ if(varDecl->HasModifier<HLSLStaticModifier>())
+ return nullptr;
+
+ // HLSL `groupshared` modifier indicates "thread-group local"
+ if(varDecl->HasModifier<HLSLGroupSharedModifier>())
+ return nullptr;
+
+ // TODO(tfoley): there may be other cases that we need to handle here
+
+ // An "ordinary" global variable is implicitly a uniform
+ // shader parameter.
+ return CreateTypeLayout(type, rules->getConstantBufferRules());
+}
+
+// Determine how to lay out a global variable that might be
+// a shader parameter.
+// Returns `nullptr` if the declaration does not represent
+// a shader parameter.
+
+RefPtr<TypeLayout>
+getTypeLayoutForGlobalShaderParameter(
+ ParameterBindingContext* context,
+ VarDeclBase* varDecl)
+{
+ auto rules = context->layoutRules;
+ switch( context->shared->sourceLanguage )
+ {
+ case SourceLanguage::Slang:
+ case SourceLanguage::HLSL:
+ return getTypeLayoutForGlobalShaderParameter_HLSL(context, varDecl);
+
+ case SourceLanguage::GLSL:
+ return getTypeLayoutForGlobalShaderParameter_GLSL(context, varDecl);
+
+ default:
+ assert(false);
+ return nullptr;
+ }
+}
+
+
+//
+
+
+
+// Collect a single declaration into our set of parameters
+static void collectGlobalScopeParameter(
+ ParameterBindingContext* context,
+ RefPtr<VarDeclBase> varDecl)
+{
+ // We use a single operation to both check whether the
+ // variable represents a shader parameter, and to compute
+ // the layout for that parameter's type.
+ auto typeLayout = getTypeLayoutForGlobalShaderParameter(
+ context,
+ varDecl.Ptr());
+
+ // If we did not find appropriate layout rules, then it
+ // must mean that this global variable is *not* a shader
+ // parameter.
+ if(!typeLayout)
+ return;
+
+ // Now create a variable layout that we can use
+ RefPtr<VarLayout> varLayout = new VarLayout();
+ varLayout->typeLayout = typeLayout;
+ varLayout->varDecl = DeclRef(varDecl.Ptr(), nullptr).As<VarDeclBaseRef>();
+
+ // This declaration may represent the same logical parameter
+ // as a declaration that came from a different translation unit.
+ // If that is the case, we want to re-use the same `VarLayout`
+ // across both parameters.
+ //
+ // First we look for an existing entry matching the name
+ // of this parameter:
+ auto parameterName = varDecl->Name.Content;
+ ParameterInfo* parameterInfo = nullptr;
+ if( context->shared->mapNameToParameterInfo.TryGetValue(parameterName, parameterInfo) )
+ {
+ // If the parameters have the same name, but don't "match" according to some reasonable rules,
+ // then we need to bail out.
+ if( !doesParameterMatch(context, varLayout, parameterInfo) )
+ {
+ parameterInfo = nullptr;
+ }
+ }
+
+ // If we didn't find a matching parameter, then we need to create one here
+ if( !parameterInfo )
+ {
+ parameterInfo = new ParameterInfo();
+ context->shared->parameters.Add(parameterInfo);
+ context->shared->mapNameToParameterInfo.Add(parameterName, parameterInfo);
+ }
+ else
+ {
+ varLayout->flags |= VarLayoutFlag::IsRedeclaration;
+ }
+
+ // Add this variable declaration to the list of declarations for the parameter
+ parameterInfo->varLayouts.Add(varLayout);
+}
+
+static void addExplicitParameterBinding(
+ ParameterBindingContext* context,
+ RefPtr<ParameterInfo> parameterInfo,
+ LayoutSemanticInfo const& semanticInfo,
+ int count)
+{
+ auto kind = semanticInfo.kind;
+
+ auto& bindingInfo = parameterInfo->bindingInfo[(int)kind];
+ if( bindingInfo.count != 0 )
+ {
+ // We already have a binding here, so we want to
+ // confirm that it matches the new one that is
+ // incoming...
+ if( bindingInfo.count != count
+ || bindingInfo.index != semanticInfo.index
+ || bindingInfo.space != semanticInfo.space )
+ {
+ // TODO: diagnose!
+ }
+
+ // TODO(tfoley): `register` semantics can technically be
+ // profile-specific (not sure if anybody uses that)...
+ }
+ else
+ {
+ bindingInfo.count = count;
+ bindingInfo.index = semanticInfo.index;
+ bindingInfo.space = semanticInfo.space;
+
+ // If things are bound in `space0` (the default), then we need
+ // to lay claim to the register range used, so that automatic
+ // assignment doesn't go and use the same registers.
+ if (semanticInfo.space == 0)
+ {
+ context->shared->usedResourceRanges[(int)semanticInfo.kind].Add(
+ semanticInfo.index,
+ semanticInfo.index + count);
+ }
+ }
+}
+
+static void addExplicitParameterBindings_HLSL(
+ ParameterBindingContext* context,
+ RefPtr<ParameterInfo> parameterInfo,
+ RefPtr<VarLayout> varLayout)
+{
+ auto typeLayout = varLayout->typeLayout;
+ auto varDecl = varLayout->varDecl;
+
+ // If the declaration has explicit binding modifiers, then
+ // here is where we want to extract and apply them...
+
+ // Look for HLSL `register` or `packoffset` semantics.
+ for (auto semantic : varDecl.GetDecl()->GetModifiersOfType<HLSLLayoutSemantic>())
+ {
+ // Need to extract the information encoded in the semantic
+ LayoutSemanticInfo semanticInfo = ExtractLayoutSemanticInfo(context, semantic);
+ auto kind = semanticInfo.kind;
+ if (kind == LayoutResourceKind::None)
+ continue;
+
+ // TODO: need to special-case when this is a `c` register binding...
+
+ // Find the appropriate resource-binding information
+ // inside the type, to see if we even use any resources
+ // of the given kind.
+
+ auto typeRes = typeLayout->FindResourceInfo(kind);
+ int count = 0;
+ if (typeRes)
+ {
+ count = (int) typeRes->count;
+ }
+ else
+ {
+ // TODO: warning here!
+ }
+
+ addExplicitParameterBinding(context, parameterInfo, semanticInfo, count);
+ }
+}
+
+static void addExplicitParameterBindings_GLSL(
+ ParameterBindingContext* context,
+ RefPtr<ParameterInfo> parameterInfo,
+ RefPtr<VarLayout> varLayout)
+{
+ auto typeLayout = varLayout->typeLayout;
+ auto varDecl = varLayout->varDecl;
+
+ // The catch in GLSL is that the expected resource type
+ // is implied by the parameter declaration itself, and
+ // the `layout` modifier is only allowed to adjust
+ // the index/offset/etc.
+ //
+
+ TypeLayout::ResourceInfo* resInfo = nullptr;
+ LayoutSemanticInfo semanticInfo;
+ semanticInfo.index = 0;
+ semanticInfo.space = 0;
+ if( (resInfo = typeLayout->FindResourceInfo(LayoutResourceKind::DescriptorTableSlot)) )
+ {
+ // Try to find `binding` and `set`
+ if(!findLayoutArg<GLSLBindingLayoutModifier>(varDecl, &semanticInfo.index))
+ return;
+
+ findLayoutArg<GLSLSetLayoutModifier>(varDecl, &semanticInfo.space);
+ }
+ else if( (resInfo = typeLayout->FindResourceInfo(LayoutResourceKind::VertexInput)) )
+ {
+ // Try to find `location` binding
+ if(!findLayoutArg<GLSLLocationLayoutModifier>(varDecl, &semanticInfo.index))
+ return;
+ }
+ else if( (resInfo = typeLayout->FindResourceInfo(LayoutResourceKind::FragmentOutput)) )
+ {
+ // Try to find `location` binding
+ if(!findLayoutArg<GLSLLocationLayoutModifier>(varDecl, &semanticInfo.index))
+ return;
+ }
+ else if( (resInfo = typeLayout->FindResourceInfo(LayoutResourceKind::SpecializationConstant)) )
+ {
+ // Try to find `constant_id` binding
+ if(!findLayoutArg<GLSLConstantIDLayoutModifier>(varDecl, &semanticInfo.index))
+ return;
+ }
+
+ // If we didn't find any matches, then bail
+ if(!resInfo)
+ return;
+
+ auto kind = resInfo->kind;
+ auto count = resInfo->count;
+ semanticInfo.kind = kind;
+
+ addExplicitParameterBinding(context, parameterInfo, semanticInfo, int(count));
+}
+
+// Given a single parameter, collect whatever information we have on
+// how it has been explicitly bound, which may come from multiple declarations
+void generateParameterBindings(
+ ParameterBindingContext* context,
+ RefPtr<ParameterInfo> parameterInfo)
+{
+ // There must be at least one declaration for the parameter.
+ assert(parameterInfo->varLayouts.Count() != 0);
+
+ // Iterate over all declarations looking for explicit binding information.
+ for( auto& varLayout : parameterInfo->varLayouts )
+ {
+ // Handle HLSL `register` and `packoffset` modifiers
+ addExplicitParameterBindings_HLSL(context, parameterInfo, varLayout);
+
+
+ // Handle GLSL `layout` modifiers
+ addExplicitParameterBindings_GLSL(context, parameterInfo, varLayout);
+ }
+}
+
+// Generate the binding information for a shader parameter.
+static void completeBindingsForParameter(
+ ParameterBindingContext* context,
+ RefPtr<ParameterInfo> parameterInfo)
+{
+ // For any resource kind used by the parameter
+ // we need to update its layout information
+ // to include a binding for that resource kind.
+ //
+ // We will use the first declaration of the parameter as
+ // a stand-in for all the declarations, so it is important
+ // that earlier code has validated that the declarations
+ // "match".
+
+ assert(parameterInfo->varLayouts.Count() != 0);
+ auto firstVarLayout = parameterInfo->varLayouts.First();
+ auto firstTypeLayout = firstVarLayout->typeLayout;
+
+ for(auto typeRes : firstTypeLayout->resourceInfos)
+ {
+ // Did we already apply some explicit binding information
+ // for this resource kind?
+ auto kind = typeRes.kind;
+ auto& bindingInfo = parameterInfo->bindingInfo[(int)kind];
+ if( bindingInfo.count != 0 )
+ {
+ // If things have already been bound, our work is done.
+ continue;
+ }
+
+ auto count = typeRes.count;
+ bindingInfo.count = count;
+ bindingInfo.index = context->shared->usedResourceRanges[(int)kind].Allocate((int) count);
+
+ // For now we only auto-generate bindings in space zero
+ bindingInfo.space = 0;
+ }
+
+ // At this point we should have explicit binding locations chosen for
+ // all the relevant resource kinds, so we can apply these to the
+ // declarations:
+
+ for(auto& varLayout : parameterInfo->varLayouts)
+ {
+ for(auto k = 0; k < kLayoutResourceKindCount; ++k)
+ {
+ auto kind = LayoutResourceKind(k);
+ auto& bindingInfo = parameterInfo->bindingInfo[k];
+
+ // skip resources we aren't consuming
+ if(bindingInfo.count == 0)
+ continue;
+
+ // Add a record to the variable layout
+ auto varRes = varLayout->AddResourceInfo(kind);
+ varRes->space = (int) bindingInfo.space;
+ varRes->index = (int) bindingInfo.index;
+ }
+ }
+}
+
+static void collectGlobalScopeParameters(
+ ParameterBindingContext* context,
+ ProgramSyntaxNode* program)
+{
+ // First enumerate parameters at global scope
+ for( auto decl : program->Members )
+ {
+ // A shader parameter is always a variable,
+ // so skip declarations that aren't variables.
+ auto varDecl = decl.As<VarDeclBase>();
+ if (!varDecl)
+ continue;
+
+ collectGlobalScopeParameter(context, varDecl);
+ }
+
+ // Next, we need to enumerate the parameters of
+ // each entry point (which requires knowing what the
+ // entry points *are*)
+
+ // TODO(tfoley): Entry point functions should be identified
+ // by looking for a generated modifier that is attached
+ // to global-scope function declarations.
+}
+
+struct SimpleSemanticInfo
+{
+ String name;
+ int index;
+};
+
+SimpleSemanticInfo decomposeSimpleSemantic(
+ HLSLSimpleSemantic* semantic)
+{
+ auto composedName = semantic->name.Content;
+
+ // look for a trailing sequence of decimal digits
+ // at the end of the composed name
+ int length = composedName.Length();
+ int indexLoc = length;
+ while( indexLoc > 0 )
+ {
+ auto c = composedName[indexLoc-1];
+ if( c >= '0' && c <= '9' )
+ {
+ indexLoc--;
+ continue;
+ }
+ else
+ {
+ break;
+ }
+ }
+
+ SimpleSemanticInfo info;
+
+ //
+ if( indexLoc == length )
+ {
+ // No index suffix
+ info.name = composedName;
+ info.index = 0;
+ }
+ else
+ {
+ // The name is everything before the digits
+ info.name = composedName.SubString(0, indexLoc);
+ info.index = strtol(composedName.SubString(indexLoc, length - indexLoc).begin(), nullptr, 10);
+ }
+ return info;
+}
+
+enum class EntryPointParameterDirection
+{
+ Input,
+ Output,
+};
+
+struct EntryPointParameterState
+{
+ String* optSemanticName;
+ int* ioSemanticIndex;
+ EntryPointParameterDirection direction;
+ int semanticSlotCount;
+};
+
+static void processSimpleEntryPointInput(
+ ParameterBindingContext* context,
+ RefPtr<ExpressionType> type,
+ EntryPointParameterState const& state)
+{
+ auto optSemanticName = state.optSemanticName;
+ auto semanticIndex = *state.ioSemanticIndex;
+ auto semanticSlotCount = state.semanticSlotCount;
+}
+
+static void processSimpleEntryPointOutput(
+ ParameterBindingContext* context,
+ RefPtr<ExpressionType> type,
+ EntryPointParameterState const& state)
+{
+ auto optSemanticName = state.optSemanticName;
+ auto semanticIndex = *state.ioSemanticIndex;
+ auto semanticSlotCount = state.semanticSlotCount;
+
+ if(!optSemanticName)
+ return;
+
+ auto semanticName = *optSemanticName;
+
+ // Note: I'm just doing something expedient here and detecting `SV_Target`
+ // outputs and claiming the appropriate register range right away.
+ //
+ // TODO: we should really be building up some representation of all of this,
+ // once we've gone to the trouble of looking it all up...
+ if( semanticName.ToLower() == "sv_target" )
+ {
+ context->shared->usedResourceRanges[int(LayoutResourceKind::UnorderedAccess)].Add(semanticIndex, semanticIndex + semanticSlotCount);
+ }
+}
+
+static void processSimpleEntryPointParameter(
+ ParameterBindingContext* context,
+ RefPtr<ExpressionType> type,
+ EntryPointParameterState const& inState,
+ int semanticSlotCount = 1)
+{
+ EntryPointParameterState state = inState;
+ state.semanticSlotCount = semanticSlotCount;
+
+ switch( state.direction )
+ {
+ case EntryPointParameterDirection::Input:
+ processSimpleEntryPointInput(context, type, state);
+ break;
+
+ case EntryPointParameterDirection::Output:
+ processSimpleEntryPointOutput(context, type, state);
+ break;
+
+ SLANG_EXHAUSTIVE_SWITCH()
+ }
+
+ *state.ioSemanticIndex += state.semanticSlotCount;
+}
+
+static void processEntryPointParameter(
+ ParameterBindingContext* context,
+ RefPtr<ExpressionType> type,
+ EntryPointParameterState const& state);
+
+static void processEntryPointParameterWithPossibleSemantic(
+ ParameterBindingContext* context,
+ Decl* declForSemantic,
+ RefPtr<ExpressionType> type,
+ EntryPointParameterState const& state)
+{
+ // If there is no explicit semantic already in effect, *and* we find an explicit
+ // semantic on the associated declaration, then we'll use it.
+ if( !state.optSemanticName )
+ {
+ if( auto semantic = declForSemantic->FindModifier<HLSLSimpleSemantic>() )
+ {
+ auto semanticInfo = decomposeSimpleSemantic(semantic);
+ int semanticIndex = semanticInfo.index;
+
+ EntryPointParameterState subState = state;
+ subState.optSemanticName = &semanticInfo.name;
+ subState.ioSemanticIndex = &semanticIndex;
+
+ processEntryPointParameter(context, type, subState);
+ }
+ }
+
+ // Default case: either there was an explicit semantic in effect already,
+ // *or* we couldn't find an explicit semantic to apply on the given
+ // declaration, so we will just recursive with whatever we have at
+ // the moment.
+ processEntryPointParameter(context, type, state);
+}
+
+
+static void processEntryPointParameter(
+ ParameterBindingContext* context,
+ RefPtr<ExpressionType> type,
+ EntryPointParameterState const& state)
+{
+ // Scalar and vector types are treated as outputs directly
+ if(auto basicType = type->As<BasicExpressionType>())
+ {
+ processSimpleEntryPointParameter(context, basicType, state);
+ }
+ else if(auto basicType = type->As<VectorExpressionType>())
+ {
+ processSimpleEntryPointParameter(context, basicType, state);
+ }
+ // A matrix is processed as if it was an array of rows
+ else if( auto matrixType = type->As<MatrixExpressionType>() )
+ {
+ auto rowCount = GetIntVal(matrixType->getRowCount());
+ processSimpleEntryPointParameter(context, basicType, state, rowCount);
+ }
+ else if( auto arrayType = type->As<ArrayExpressionType>() )
+ {
+ auto elementCount = GetIntVal(arrayType->ArrayLength);
+
+ for( int ii = 0; ii < elementCount; ++ii )
+ {
+ processEntryPointParameter(context, arrayType->BaseType, state);
+ }
+ }
+ // Ignore a bunch of types that don't make sense here...
+ else if(auto textureType = type->As<TextureType>()) {}
+ else if(auto samplerStateType = type->As<SamplerStateType>()) {}
+ else if(auto constantBufferType = type->As<ConstantBufferType>()) {}
+ // Catch declaration-reference types late in the sequence, since
+ // otherwise they will include all of the above cases...
+ else if( auto declRefType = type->As<DeclRefType>() )
+ {
+ auto declRef = declRefType->declRef;
+
+ if (auto structDeclRef = declRef.As<StructDeclRef>())
+ {
+ // Need to recursively walk the fields of the structure now...
+ for( auto field : structDeclRef.GetFields() )
+ {
+ processEntryPointParameterWithPossibleSemantic(
+ context,
+ field.GetDecl(),
+ field.GetType(),
+ state);
+ }
+ }
+ else
+ {
+ assert(!"unimplemented");
+ }
+ }
+ else
+ {
+ assert(!"unimplemented");
+ }
+}
+
+static void collectEntryPointParameters(
+ ParameterBindingContext* context,
+ EntryPointOption const& entryPoint,
+ ProgramSyntaxNode* translationUnitSyntax)
+{
+ // First, look for the entry point with the specified name
+
+ // Make sure we've got a query-able member dictionary
+ buildMemberDictionary(translationUnitSyntax);
+
+ Decl* entryPointDecl;
+ if( !translationUnitSyntax->memberDictionary.TryGetValue(entryPoint.name, entryPointDecl) )
+ {
+ // No such entry point!
+ return;
+ }
+ if( entryPointDecl->nextInContainerWithSameName )
+ {
+ // Not the only decl of that name!
+ return;
+ }
+
+ FunctionSyntaxNode* entryPointFuncDecl = dynamic_cast<FunctionSyntaxNode*>(entryPointDecl);
+ if( !entryPointFuncDecl )
+ {
+ // Not a function!
+ return;
+ }
+
+ // Create the layout object here
+ auto entryPointLayout = new EntryPointLayout();
+ entryPointLayout->profile = entryPoint.profile;
+ entryPointLayout->entryPoint = entryPointFuncDecl;
+
+
+ context->shared->programLayout->entryPoints.Add(entryPointLayout);
+
+ // Okay, we seemingly have an entry-point function, and now we need to collect info on its parameters too
+ //
+ // TODO: Long-term we probably want complete information on all inputs/outputs of an entry point,
+ // but for now we are really just trying to scrape information on fragment outputs, so lets do that:
+ //
+ // TODO: check whether we should enumerate the parameters before the return type, or vice versa
+
+ int defaultSemanticIndex = 0;
+
+ EntryPointParameterState state;
+ state.ioSemanticIndex = &defaultSemanticIndex;
+ state.optSemanticName = nullptr;
+ state.semanticSlotCount = 0;
+
+ for( auto m : entryPointFuncDecl->Members )
+ {
+ auto paramDecl = m.As<VarDeclBase>();
+ if(!paramDecl)
+ continue;
+
+ // We have an entry-point parameter, and need to figure out what to do with it.
+
+ // If it appears to be an input, process it as such.
+ if( paramDecl->HasModifier<InModifier>() || paramDecl->HasModifier<InOutModifier>() || !paramDecl->HasModifier<OutModifier>() )
+ {
+ state.direction = EntryPointParameterDirection::Input;
+
+ processEntryPointParameterWithPossibleSemantic(
+ context,
+ paramDecl.Ptr(),
+ paramDecl->Type.type,
+ state);
+ }
+
+ // If it appears to be an output, process it as such.
+ if(paramDecl->HasModifier<OutModifier>() || paramDecl->HasModifier<InOutModifier>())
+ {
+ state.direction = EntryPointParameterDirection::Output;
+
+ processEntryPointParameterWithPossibleSemantic(
+ context,
+ paramDecl.Ptr(),
+ paramDecl->Type.type,
+ state);
+ }
+ }
+
+ // If we can find an output type for the entry point, then process it as
+ // an output parameter.
+ if( auto resultType = entryPointFuncDecl->ReturnType.type )
+ {
+ state.direction = EntryPointParameterDirection::Output;
+
+ processEntryPointParameterWithPossibleSemantic(
+ context,
+ entryPointFuncDecl,
+ resultType,
+ state);
+ }
+}
+
+// When doing parameter binding for global-scope stuff in GLSL,
+// we may need to know what stage we are compiling for, so that
+// we can handle special cases appropriately (e.g., "arrayed"
+// inputs and outputs).
+static Stage
+inferStageForTranslationUnit(
+ CompileUnit const& translationUnit)
+{
+ // In the specific case where we are compiling GLSL input,
+ // and have only a single entry point, use the stage
+ // of the entry point.
+ //
+ // TODO: can we generalize this at all?
+ if( translationUnit.options.sourceLanguage == SourceLanguage::GLSL )
+ {
+ if( translationUnit.options.entryPoints.Count() == 1 )
+ {
+ return translationUnit.options.entryPoints[0].profile.GetStage();
+ }
+ }
+
+ return Stage::Unknown;
+}
+
+static void collectParameters(
+ ParameterBindingContext* inContext,
+ CollectionOfTranslationUnits* program)
+{
+ ParameterBindingContext contextData = *inContext;
+ auto context = &contextData;
+
+ for( auto& translationUnit : program->translationUnits )
+ {
+ context->stage = inferStageForTranslationUnit(translationUnit);
+
+ // First look at global-scope parameters
+ collectGlobalScopeParameters(context, translationUnit.SyntaxNode.Ptr());
+
+ // Next consider parameters for entry points
+ for( auto& entryPoint : translationUnit.options.entryPoints )
+ {
+ context->stage = entryPoint.profile.GetStage();
+ collectEntryPointParameters(context, entryPoint, translationUnit.SyntaxNode.Ptr());
+ }
+ }
+}
+
+void GenerateParameterBindings(
+ CollectionOfTranslationUnits* program)
+{
+ // TODO: infer a language or set of language rules to use based on the
+ // source files and entry points given
+ auto language = SourceLanguage::Unknown;
+ for( auto& translationUnit : program->translationUnits )
+ {
+ auto translationUnitLanguage = translationUnit.options.sourceLanguage;
+ if( language == SourceLanguage::Unknown )
+ {
+ language = translationUnitLanguage;
+ }
+ else if( language == translationUnitLanguage )
+ {
+ // same language: nothing to do...
+ }
+ else
+ {
+ // mismatch!
+ // TODO(tfoley): emit a diagnostic
+ }
+ }
+
+ // TODO(tfoley): We should really be picking layout rules
+ // based on the *target* language, and not the source...
+ auto rules = GetLayoutRulesFamilyImpl(language);
+ assert(rules);
+
+ RefPtr<ProgramLayout> programLayout = new ProgramLayout;
+
+ // Create a context to hold shared state during the process
+ // of generating parameter bindings
+ SharedParameterBindingContext sharedContext;
+ sharedContext.defaultLayoutRules = rules;
+ sharedContext.programLayout = programLayout;
+ sharedContext.sourceLanguage = language;
+
+ // Create a sub-context to collect parameters that get
+ // declared into the global scope
+ ParameterBindingContext context;
+ context.shared = &sharedContext;
+ context.layoutRules = sharedContext.defaultLayoutRules;
+
+ // Walk through AST to discover all the parameters
+ collectParameters(&context, program);
+
+ // Now walk through the parameters to generate initial binding information
+ for( auto& parameter : sharedContext.parameters )
+ {
+ generateParameterBindings(&context, parameter);
+ }
+
+ bool anyGlobalUniforms = false;
+ for( auto& parameterInfo : sharedContext.parameters )
+ {
+ assert(parameterInfo->varLayouts.Count() != 0);
+ auto firstVarLayout = parameterInfo->varLayouts.First();
+
+ // Does the field have any uniform data?
+ if( firstVarLayout->typeLayout->FindResourceInfo(LayoutResourceKind::Uniform) )
+ {
+ anyGlobalUniforms = true;
+ break;
+ }
+ }
+
+ // If there are any global-scope uniforms, then we need to
+ // allocate a constant-buffer binding for them here.
+ ParameterBindingInfo globalConstantBufferBinding;
+ if( anyGlobalUniforms )
+ {
+ globalConstantBufferBinding.index =
+ context.shared->usedResourceRanges[
+ (int)LayoutResourceKind::ConstantBuffer].Allocate(1);
+
+ // For now we only auto-generate bindings in space zero
+ globalConstantBufferBinding.space = 0;
+ }
+
+
+ // Now walk through again to actually give everything
+ // ranges of registers...
+ for( auto& parameter : sharedContext.parameters )
+ {
+ completeBindingsForParameter(&context, parameter);
+ }
+
+ // TODO: need to deal with parameters declared inside entry-point
+ // parameter lists at some point...
+
+
+ // Next we need to create a type layout to reflect the information
+ // we have collected.
+
+ // We will lay out any bare uniforms at the global scope into
+ // a single constant buffer. This is appropriate for HLSL global-scope
+ // uniforms, and Vulkan GLSL doesn't allow uniforms at global scope,
+ // so it should work out.
+ //
+ // For legacy GLSL targets, we'd probably need a distinct resource
+ // kind and set of rules here, since legacy uniforms are not the
+ // same as the contents of a constant buffer.
+ auto globalScopeRules = context.layoutRules->getConstantBufferRules();
+
+ RefPtr<StructTypeLayout> globalScopeStructLayout = new StructTypeLayout();
+ globalScopeStructLayout->rules = globalScopeRules;
+
+ UniformLayoutInfo structLayoutInfo = globalScopeRules->BeginStructLayout();
+ for( auto& parameterInfo : sharedContext.parameters )
+ {
+ assert(parameterInfo->varLayouts.Count() != 0);
+ auto firstVarLayout = parameterInfo->varLayouts.First();
+
+ // Does the field have any uniform data?
+ auto layoutInfo = firstVarLayout->typeLayout->FindResourceInfo(LayoutResourceKind::Uniform);
+ size_t uniformSize = layoutInfo ? layoutInfo->count : 0;
+ if( uniformSize != 0 )
+ {
+ // Make sure uniform fields get laid out properly...
+
+ UniformLayoutInfo fieldInfo(
+ uniformSize,
+ firstVarLayout->typeLayout->uniformAlignment);
+
+ size_t uniformOffset = globalScopeRules->AddStructField(
+ &structLayoutInfo,
+ fieldInfo);
+
+ for( auto& varLayout : parameterInfo->varLayouts )
+ {
+ varLayout->findOrAddResourceInfo(LayoutResourceKind::Uniform)->index = uniformOffset;
+ }
+ }
+
+ globalScopeStructLayout->fields.Add(firstVarLayout);
+
+ for( auto& varLayout : parameterInfo->varLayouts )
+ {
+ globalScopeStructLayout->mapVarToLayout.Add(varLayout->varDecl.GetDecl(), varLayout);
+ }
+ }
+ globalScopeRules->EndStructLayout(&structLayoutInfo);
+
+ RefPtr<TypeLayout> globalScopeLayout = globalScopeStructLayout;
+
+ // If there are global-scope uniforms, then we need to wrap
+ // up a global constant buffer type layout to hold them
+ if( anyGlobalUniforms )
+ {
+ auto globalConstantBufferLayout = createParameterBlockTypeLayout(
+ nullptr,
+ globalScopeStructLayout,
+ globalScopeRules);
+
+ globalScopeLayout = globalConstantBufferLayout;
+ }
+
+ // We now have a bunch of layout information, which we should
+ // record into a suitable object that represents the program
+ programLayout->globalScopeLayout = globalScopeLayout;
+ program->layout = programLayout;
+}
+
+}}
diff --git a/source/slang/parameter-binding.h b/source/slang/parameter-binding.h
new file mode 100644
index 000000000..8165f1b2e
--- /dev/null
+++ b/source/slang/parameter-binding.h
@@ -0,0 +1,32 @@
+#ifndef SLANG_PARAMETER_BINDING_H
+#define SLANG_PARAMETER_BINDING_H
+
+#include "../core/basic.h"
+#include "syntax.h"
+
+#include "../../Slang.h"
+
+namespace Slang {
+namespace Compiler {
+
+class CollectionOfTranslationUnits;
+
+// The parameter-binding interface is responsible for assigning
+// binding locations/registers to every parameter of a shader
+// program. This can include both parameters declared on a
+// particular entry point, as well as parameters declared at
+// global scope.
+//
+
+
+// Generate binding information for the given program,
+// represented as a collection of different translation units,
+// and attach that information to the syntax nodes
+// of the program.
+
+void GenerateParameterBindings(
+ CollectionOfTranslationUnits* program);
+
+}}
+
+#endif // SLANG_REFLECTION_H
diff --git a/source/slang/parser.cpp b/source/slang/parser.cpp
new file mode 100644
index 000000000..773e3b74b
--- /dev/null
+++ b/source/slang/parser.cpp
@@ -0,0 +1,3106 @@
+#include "Parser.h"
+
+#include <assert.h>
+
+#include "lookup.h"
+
+namespace Slang
+{
+ namespace Compiler
+ {
+ enum Precedence : int
+ {
+ Invalid = -1,
+ Comma,
+ Assignment,
+ TernaryConditional,
+ LogicalOr,
+ LogicalAnd,
+ BitOr,
+ BitXor,
+ BitAnd,
+ EqualityComparison,
+ RelationalComparison,
+ BitShift,
+ Additive,
+ Multiplicative,
+ Prefix,
+ Postfix,
+ };
+
+ // TODO: implement two pass parsing for file reference and struct type recognition
+
+ class Parser
+ {
+ public:
+ CompileOptions& options;
+ int anonymousCounter = 0;
+
+ RefPtr<Scope> outerScope;
+ RefPtr<Scope> currentScope;
+
+ TokenReader tokenReader;
+ DiagnosticSink * sink;
+ String fileName;
+ int genericDepth = 0;
+
+ // Is the parser in a "recovering" state?
+ // During recovery we don't emit additional errors, until we find
+ // a token that we expected, when we exit recovery.
+ bool isRecovering = false;
+
+ void FillPosition(SyntaxNode * node)
+ {
+ node->Position = tokenReader.PeekLoc();
+ }
+ void PushScope(ContainerDecl* containerDecl)
+ {
+ RefPtr<Scope> newScope = new Scope();
+ newScope->containerDecl = containerDecl;
+ newScope->parent = currentScope;
+
+ currentScope = newScope;
+ }
+ void PopScope()
+ {
+ currentScope = currentScope->parent;
+ }
+ Parser(
+ CompileOptions& options,
+ TokenSpan const& _tokens,
+ DiagnosticSink * sink,
+ String _fileName,
+ RefPtr<Scope> const& outerScope)
+ : options(options)
+ , tokenReader(_tokens)
+ , sink(sink)
+ , fileName(_fileName)
+ , outerScope(outerScope)
+ {}
+ RefPtr<ProgramSyntaxNode> Parse();
+
+ Token ReadToken();
+ Token ReadToken(TokenType type);
+ Token ReadToken(const char * string);
+ bool LookAheadToken(TokenType type, int offset = 0);
+ bool LookAheadToken(const char * string, int offset = 0);
+ void parseSourceFile(ProgramSyntaxNode* program);
+ RefPtr<ProgramSyntaxNode> ParseProgram();
+ RefPtr<StructSyntaxNode> ParseStruct();
+ RefPtr<ClassSyntaxNode> ParseClass();
+ RefPtr<StatementSyntaxNode> ParseStatement();
+ RefPtr<StatementSyntaxNode> ParseBlockStatement();
+ RefPtr<VarDeclrStatementSyntaxNode> ParseVarDeclrStatement(Modifiers modifiers);
+ RefPtr<IfStatementSyntaxNode> ParseIfStatement();
+ RefPtr<ForStatementSyntaxNode> ParseForStatement();
+ RefPtr<WhileStatementSyntaxNode> ParseWhileStatement();
+ RefPtr<DoWhileStatementSyntaxNode> ParseDoWhileStatement();
+ RefPtr<BreakStatementSyntaxNode> ParseBreakStatement();
+ RefPtr<ContinueStatementSyntaxNode> ParseContinueStatement();
+ RefPtr<ReturnStatementSyntaxNode> ParseReturnStatement();
+ RefPtr<ExpressionStatementSyntaxNode> ParseExpressionStatement();
+ RefPtr<ExpressionSyntaxNode> ParseExpression(Precedence level = Precedence::Comma);
+
+ // Parse an expression that might be used in an initializer or argument context, so we should avoid operator-comma
+ inline RefPtr<ExpressionSyntaxNode> ParseInitExpr() { return ParseExpression(Precedence::Assignment); }
+ inline RefPtr<ExpressionSyntaxNode> ParseArgExpr() { return ParseExpression(Precedence::Assignment); }
+
+ RefPtr<ExpressionSyntaxNode> ParseLeafExpression();
+ RefPtr<ParameterSyntaxNode> ParseParameter();
+ RefPtr<ExpressionSyntaxNode> ParseType();
+ TypeExp ParseTypeExp();
+
+ Parser & operator = (const Parser &) = delete;
+ };
+
+ // Forward Declarations
+
+ static void ParseDeclBody(
+ Parser* parser,
+ ContainerDecl* containerDecl,
+ TokenType closingToken);
+
+ static RefPtr<Modifier> ParseOptSemantics(
+ Parser* parser);
+
+ static void ParseOptSemantics(
+ Parser* parser,
+ Decl* decl);
+
+ static RefPtr<DeclBase> ParseDecl(
+ Parser* parser,
+ ContainerDecl* containerDecl);
+
+ static RefPtr<Decl> ParseSingleDecl(
+ Parser* parser,
+ ContainerDecl* containerDecl);
+
+ //
+
+ static void Unexpected(
+ Parser* parser)
+ {
+ // Don't emit "unexpected token" errors if we are in recovering mode
+ if (!parser->isRecovering)
+ {
+ parser->sink->diagnose(parser->tokenReader.PeekLoc(), Diagnostics::unexpectedToken,
+ parser->tokenReader.PeekTokenType());
+
+ // Switch into recovery mode, to suppress additional errors
+ parser->isRecovering = true;
+ }
+ }
+
+ static void Unexpected(
+ Parser* parser,
+ char const* expected)
+ {
+ // Don't emit "unexpected token" errors if we are in recovering mode
+ if (!parser->isRecovering)
+ {
+ parser->sink->diagnose(parser->tokenReader.PeekLoc(), Diagnostics::unexpectedTokenExpectedTokenName,
+ parser->tokenReader.PeekTokenType(),
+ expected);
+
+ // Switch into recovery mode, to suppress additional errors
+ parser->isRecovering = true;
+ }
+ }
+
+ static void Unexpected(
+ Parser* parser,
+ TokenType expected)
+ {
+ // Don't emit "unexpected token" errors if we are in recovering mode
+ if (!parser->isRecovering)
+ {
+ parser->sink->diagnose(parser->tokenReader.PeekLoc(), Diagnostics::unexpectedTokenExpectedTokenType,
+ parser->tokenReader.PeekTokenType(),
+ expected);
+
+ // Switch into recovery mode, to suppress additional errors
+ parser->isRecovering = true;
+ }
+ }
+
+ static TokenType SkipToMatchingToken(TokenReader* reader, TokenType tokenType);
+
+ // Skip a singel balanced token, which is either a single token in
+ // the common case, or a matched pair of tokens for `()`, `[]`, and `{}`
+ static TokenType SkipBalancedToken(
+ TokenReader* reader)
+ {
+ TokenType tokenType = reader->AdvanceToken().Type;
+ switch (tokenType)
+ {
+ default:
+ break;
+
+ case TokenType::LParent: tokenType = SkipToMatchingToken(reader, TokenType::RParent); break;
+ case TokenType::LBrace: tokenType = SkipToMatchingToken(reader, TokenType::RBrace); break;
+ case TokenType::LBracket: tokenType = SkipToMatchingToken(reader, TokenType::RBracket); break;
+ }
+ return tokenType;
+ }
+
+ // Skip balanced
+ static TokenType SkipToMatchingToken(
+ TokenReader* reader,
+ TokenType tokenType)
+ {
+ for (;;)
+ {
+ if (reader->IsAtEnd()) return TokenType::EndOfFile;
+ if (reader->PeekTokenType() == tokenType)
+ {
+ reader->AdvanceToken();
+ return tokenType;
+ }
+ SkipBalancedToken(reader);
+ }
+ }
+
+ // Is the given token type one that is used to "close" a
+ // balanced construct.
+ static bool IsClosingToken(TokenType tokenType)
+ {
+ switch (tokenType)
+ {
+ case TokenType::EndOfFile:
+ case TokenType::RBracket:
+ case TokenType::RParent:
+ case TokenType::RBrace:
+ return true;
+
+ default:
+ return false;
+ }
+ }
+
+
+ // Expect an identifier token with the given content, and consume it.
+ Token Parser::ReadToken(const char* expected)
+ {
+ if (tokenReader.PeekTokenType() == TokenType::Identifier
+ && tokenReader.PeekToken().Content == expected)
+ {
+ isRecovering = false;
+ return tokenReader.AdvanceToken();
+ }
+
+ if (!isRecovering)
+ {
+ Unexpected(this, expected);
+ return tokenReader.PeekToken();
+ }
+ else
+ {
+ // Try to find a place to recover
+ for (;;)
+ {
+ // The token we expected?
+ // Then exit recovery mode and pretend like all is well.
+ if (tokenReader.PeekTokenType() == TokenType::Identifier
+ && tokenReader.PeekToken().Content == expected)
+ {
+ isRecovering = false;
+ return tokenReader.AdvanceToken();
+ }
+
+
+ // Don't skip past any "closing" tokens.
+ if (IsClosingToken(tokenReader.PeekTokenType()))
+ {
+ return tokenReader.PeekToken();
+ }
+
+ // Skip balanced tokens and try again.
+ SkipBalancedToken(&tokenReader);
+ }
+ }
+ }
+
+ Token Parser::ReadToken()
+ {
+ return tokenReader.AdvanceToken();
+ }
+
+ static bool TryRecover(
+ Parser* parser,
+ TokenType const* recoverBefore,
+ int recoverBeforeCount,
+ TokenType const* recoverAfter,
+ int recoverAfterCount)
+ {
+ if (!parser->isRecovering)
+ return true;
+
+ // Determine if we are looking for a closing token at all...
+ bool lookingForClose = false;
+ for (int ii = 0; ii < recoverBeforeCount; ++ii)
+ {
+ if (IsClosingToken(recoverBefore[ii]))
+ lookingForClose = true;
+ }
+ for (int ii = 0; ii < recoverAfterCount; ++ii)
+ {
+ if (IsClosingToken(recoverAfter[ii]))
+ lookingForClose = true;
+ }
+
+ TokenReader* tokenReader = &parser->tokenReader;
+ for (;;)
+ {
+ TokenType peek = tokenReader->PeekTokenType();
+
+ // Is the next token in our recover-before set?
+ // If so, then we have recovered successfully!
+ for (int ii = 0; ii < recoverBeforeCount; ++ii)
+ {
+ if (peek == recoverBefore[ii])
+ {
+ parser->isRecovering = false;
+ return true;
+ }
+ }
+
+ // If we are looking at a token in our recover-after set,
+ // then consume it and recover
+ for (int ii = 0; ii < recoverAfterCount; ++ii)
+ {
+ if (peek == recoverAfter[ii])
+ {
+ tokenReader->AdvanceToken();
+ parser->isRecovering = false;
+ return true;
+ }
+ }
+
+ // Don't try to skip past end of file
+ if (peek == TokenType::EndOfFile)
+ return false;
+
+ switch (peek)
+ {
+ // Don't skip past simple "closing" tokens, *unless*
+ // we are looking for a closing token
+ case TokenType::RParent:
+ case TokenType::RBracket:
+ if (!lookingForClose)
+ return false;
+ break;
+
+ // never skip a `}`, to avoid spurious errors
+ case TokenType::RBrace:
+ return false;
+ }
+
+ // Skip balanced tokens and try again.
+ TokenType skipped = SkipBalancedToken(tokenReader);
+
+ // If we happened to find a matched pair of tokens, and
+ // the end of it was a token we were looking for,
+ // then recover here
+ for (int ii = 0; ii < recoverAfterCount; ++ii)
+ {
+ if (skipped == recoverAfter[ii])
+ {
+ parser->isRecovering = false;
+ return true;
+ }
+ }
+ }
+ }
+
+ static bool TryRecoverBefore(
+ Parser* parser,
+ TokenType before0)
+ {
+ TokenType recoverBefore[] = { before0 };
+ return TryRecover(parser, recoverBefore, 1, nullptr, 0);
+ }
+
+ // Default recovery strategy, to use inside `{}`-delimeted blocks.
+ static bool TryRecover(
+ Parser* parser)
+ {
+ TokenType recoverBefore[] = { TokenType::RBrace };
+ TokenType recoverAfter[] = { TokenType::Semicolon };
+ return TryRecover(parser, recoverBefore, 1, recoverAfter, 1);
+ }
+
+ Token Parser::ReadToken(TokenType expected)
+ {
+ if (tokenReader.PeekTokenType() == expected)
+ {
+ isRecovering = false;
+ return tokenReader.AdvanceToken();
+ }
+
+ if (!isRecovering)
+ {
+ Unexpected(this, expected);
+ return tokenReader.PeekToken();
+ }
+ else
+ {
+ // Try to find a place to recover
+ if (TryRecoverBefore(this, expected))
+ {
+ isRecovering = false;
+ return tokenReader.AdvanceToken();
+ }
+
+ return tokenReader.PeekToken();
+ }
+ }
+
+ bool Parser::LookAheadToken(const char * string, int offset)
+ {
+ TokenReader r = tokenReader;
+ for (int ii = 0; ii < offset; ++ii)
+ r.AdvanceToken();
+
+ return r.PeekTokenType() == TokenType::Identifier
+ && r.PeekToken().Content == string;
+ }
+
+ bool Parser::LookAheadToken(TokenType type, int offset)
+ {
+ TokenReader r = tokenReader;
+ for (int ii = 0; ii < offset; ++ii)
+ r.AdvanceToken();
+
+ return r.PeekTokenType() == type;
+ }
+
+ // Consume a token and return true it if matches, otherwise false
+ bool AdvanceIf(Parser* parser, TokenType tokenType)
+ {
+ if (parser->LookAheadToken(tokenType))
+ {
+ parser->ReadToken();
+ return true;
+ }
+ return false;
+ }
+
+ // Consume a token and return true it if matches, otherwise false
+ bool AdvanceIf(Parser* parser, char const* text)
+ {
+ if (parser->LookAheadToken(text))
+ {
+ parser->ReadToken();
+ return true;
+ }
+ return false;
+ }
+
+ // Consume a token and return true if it matches, otherwise check
+ // for end-of-file and expect that token (potentially producing
+ // an error) and return true to maintain forward progress.
+ // Otherwise return false.
+ bool AdvanceIfMatch(Parser* parser, TokenType tokenType)
+ {
+ // If we've run into a syntax error, but haven't recovered inside
+ // the block, then try to recover here.
+ if (parser->isRecovering)
+ {
+ TryRecoverBefore(parser, tokenType);
+ }
+ if (AdvanceIf(parser, tokenType))
+ return true;
+ if (parser->tokenReader.PeekTokenType() == TokenType::EndOfFile)
+ {
+ parser->ReadToken(tokenType);
+ return true;
+ }
+ return false;
+ }
+
+ RefPtr<ProgramSyntaxNode> Parser::Parse()
+ {
+ return ParseProgram();
+ }
+
+ RefPtr<TypeDefDecl> ParseTypeDef(Parser* parser)
+ {
+ // Consume the `typedef` keyword
+ parser->ReadToken("typedef");
+
+ // TODO(tfoley): parse an actual declarator
+ auto type = parser->ParseTypeExp();
+
+ auto nameToken = parser->ReadToken(TokenType::Identifier);
+
+ RefPtr<TypeDefDecl> typeDefDecl = new TypeDefDecl();
+ typeDefDecl->Name = nameToken;
+ typeDefDecl->Type = type;
+
+ return typeDefDecl;
+ }
+
+ // Add a modifier to a list of modifiers being built
+ static void AddModifier(RefPtr<Modifier>** ioModifierLink, RefPtr<Modifier> modifier)
+ {
+ RefPtr<Modifier>*& modifierLink = *ioModifierLink;
+
+ while(*modifierLink)
+ modifierLink = &(*modifierLink)->next;
+
+ *modifierLink = modifier;
+ modifierLink = &modifier->next;
+ }
+
+ void addModifier(
+ RefPtr<ModifiableSyntaxNode> syntax,
+ RefPtr<Modifier> modifier)
+ {
+ auto modifierLink = &syntax->modifiers.first;
+ AddModifier(&modifierLink, modifier);
+ }
+
+ // Parse HLSL-style `[name(arg, ...)]` style "attribute" modifiers
+ static void ParseSquareBracketAttributes(Parser* parser, RefPtr<Modifier>** ioModifierLink)
+ {
+ parser->ReadToken(TokenType::LBracket);
+ for(;;)
+ {
+ auto nameToken = parser->ReadToken(TokenType::Identifier);
+ RefPtr<HLSLUncheckedAttribute> modifier = new HLSLUncheckedAttribute();
+ modifier->nameToken = nameToken;
+
+ if (AdvanceIf(parser, TokenType::LParent))
+ {
+ // HLSL-style `[name(arg0, ...)]` attribute
+
+ while (!AdvanceIfMatch(parser, TokenType::RParent))
+ {
+ auto arg = parser->ParseArgExpr();
+ if (arg)
+ {
+ modifier->args.Add(arg);
+ }
+
+ if (AdvanceIfMatch(parser, TokenType::RParent))
+ break;
+
+ parser->ReadToken(TokenType::Comma);
+ }
+ }
+ AddModifier(ioModifierLink, modifier);
+
+
+ if (AdvanceIfMatch(parser, TokenType::RBracket))
+ break;
+
+ parser->ReadToken(TokenType::Comma);
+ }
+ }
+
+ static Modifiers ParseModifiers(Parser* parser)
+ {
+ Modifiers modifiers;
+ RefPtr<Modifier>* modifierLink = &modifiers.first;
+ for (;;)
+ {
+ CodePosition loc = parser->tokenReader.PeekLoc();
+
+ if (0) {}
+
+ #define CASE(KEYWORD, TYPE) \
+ else if(AdvanceIf(parser, #KEYWORD)) do { \
+ RefPtr<TYPE> modifier = new TYPE(); \
+ modifier->Position = loc; \
+ AddModifier(&modifierLink, modifier); \
+ } while(0)
+
+ CASE(in, InModifier);
+ CASE(input, InputModifier);
+ CASE(out, OutModifier);
+ CASE(inout, InOutModifier);
+ CASE(const, ConstModifier);
+ CASE(instance, InstanceModifier);
+ CASE(__builtin, BuiltinModifier);
+
+ CASE(inline, InlineModifier);
+ CASE(public, PublicModifier);
+ CASE(require, RequireModifier);
+ CASE(param, ParamModifier);
+ CASE(extern, ExternModifier);
+
+ CASE(row_major, HLSLRowMajorLayoutModifier);
+ CASE(column_major, HLSLColumnMajorLayoutModifier);
+
+ CASE(nointerpolation, HLSLNoInterpolationModifier);
+ CASE(linear, HLSLLinearModifier);
+ CASE(sample, HLSLSampleModifier);
+ CASE(centroid, HLSLCentroidModifier);
+ CASE(precise, HLSLPreciseModifier);
+ CASE(shared, HLSLEffectSharedModifier);
+ CASE(groupshared, HLSLGroupSharedModifier);
+ CASE(static, HLSLStaticModifier);
+ CASE(uniform, HLSLUniformModifier);
+ CASE(volatile, HLSLVolatileModifier);
+
+ // Modifiers for geometry shader input
+ CASE(point, HLSLPointModifier);
+ CASE(line, HLSLLineModifier);
+ CASE(triangle, HLSLTriangleModifier);
+ CASE(lineadj, HLSLLineAdjModifier);
+ CASE(triangleadj, HLSLTriangleAdjModifier);
+
+ // Modifiers for unary operator declarations
+ CASE(__prefix, PrefixModifier);
+ CASE(__postfix, PostfixModifier);
+
+ #undef CASE
+
+ else if (AdvanceIf(parser, "__intrinsic"))
+ {
+ auto modifier = new IntrinsicModifier();
+ modifier->Position = loc;
+
+ if (AdvanceIf(parser, TokenType::LParent))
+ {
+ if (parser->LookAheadToken(TokenType::IntLiterial))
+ {
+ modifier->op = (IntrinsicOp)StringToInt(parser->ReadToken().Content);
+ }
+ else
+ {
+ modifier->opToken = parser->ReadToken(TokenType::Identifier);
+
+ modifier->op = findIntrinsicOp(modifier->opToken.Content.Buffer());
+
+ if (modifier->op == IntrinsicOp::Unknown)
+ {
+ parser->sink->diagnose(loc, Diagnostics::unimplemented, "unknown intrinsic op");
+ }
+ }
+
+ parser->ReadToken(TokenType::RParent);
+ }
+
+ AddModifier(&modifierLink, modifier);
+ }
+
+
+ else if (AdvanceIf(parser, "layout"))
+ {
+ parser->ReadToken(TokenType::LParent);
+ while (!AdvanceIfMatch(parser, TokenType::RParent))
+ {
+ auto nameToken = parser->ReadToken(TokenType::Identifier);
+
+ RefPtr<GLSLLayoutModifier> modifier;
+
+ // TODO: better handling of this choise (e.g., lookup in scope)
+ if(0) {}
+ #define CASE(KEYWORD, CLASS) \
+ else if(nameToken.Content == #KEYWORD) modifier = new CLASS()
+
+ CASE(constant_id, GLSLConstantIDLayoutModifier);
+ CASE(binding, GLSLBindingLayoutModifier);
+ CASE(set, GLSLSetLayoutModifier);
+ CASE(location, GLSLLocationLayoutModifier);
+
+ #undef CASE
+ else
+ {
+ modifier = new GLSLUnparsedLayoutModifier();
+ }
+
+ modifier->nameToken = nameToken;
+
+ if(AdvanceIf(parser, TokenType::OpAssign))
+ {
+ modifier->valToken = parser->ReadToken(TokenType::IntLiterial);
+ }
+
+ AddModifier(&modifierLink, modifier);
+
+ if (AdvanceIf(parser, TokenType::RParent))
+ break;
+ parser->ReadToken(TokenType::Comma);
+ }
+ }
+ else if (parser->tokenReader.PeekTokenType() == TokenType::LBracket)
+ {
+ ParseSquareBracketAttributes(parser, &modifierLink);
+ }
+ else if (AdvanceIf(parser,"__builtin_type"))
+ {
+ RefPtr<BuiltinTypeModifier> modifier = new BuiltinTypeModifier();
+ parser->ReadToken(TokenType::LParent);
+ modifier->tag = BaseType(StringToInt(parser->ReadToken(TokenType::IntLiterial).Content));
+ parser->ReadToken(TokenType::RParent);
+
+ AddModifier(&modifierLink, modifier);
+ }
+ else if (AdvanceIf(parser,"__magic_type"))
+ {
+ RefPtr<MagicTypeModifier> modifier = new MagicTypeModifier();
+ parser->ReadToken(TokenType::LParent);
+ modifier->name = parser->ReadToken(TokenType::Identifier).Content;
+ if (AdvanceIf(parser, TokenType::Comma))
+ {
+ modifier->tag = uint32_t(StringToInt(parser->ReadToken(TokenType::IntLiterial).Content));
+ }
+ parser->ReadToken(TokenType::RParent);
+
+ AddModifier(&modifierLink, modifier);
+ }
+ else
+ {
+ // Fallback case if none of the above explicit cases matched.
+
+ // If we are looking at an identifier, then it may map to a
+ // modifier declaration visible in the current scope
+ if( parser->LookAheadToken(TokenType::Identifier) )
+ {
+ LookupResult lookupResult = LookUp(
+ parser->tokenReader.PeekToken().Content,
+ parser->currentScope);
+
+ if( lookupResult.isValid() && !lookupResult.isOverloaded() )
+ {
+ auto& item = lookupResult.item;
+ auto decl = item.declRef.GetDecl();
+
+ if( auto modifierDecl = dynamic_cast<ModifierDecl*>(decl) )
+ {
+ // We found a declaration for some modifier syntax,
+ // so lets create an instance of the type it names
+ // here.
+
+ auto syntax = createInstanceOfSyntaxClassByName(modifierDecl->classNameToken.Content);
+ auto modifier = dynamic_cast<Modifier*>(syntax);
+
+ if( modifier )
+ {
+ modifier->Position = parser->tokenReader.PeekLoc();
+ modifier->nameToken = parser->ReadToken(TokenType::Identifier);
+
+ AddModifier(&modifierLink, modifier);
+ continue;
+ }
+ else
+ {
+ parser->ReadToken(TokenType::Identifier);
+ assert(!"unexpected");
+ }
+ }
+ }
+ }
+
+ // Done with modifier list
+ return modifiers;
+ }
+ }
+ }
+
+ static RefPtr<Decl> ParseUsing(
+ Parser* parser)
+ {
+ parser->ReadToken("using");
+ if (parser->tokenReader.PeekTokenType() == TokenType::StringLiterial)
+ {
+ auto usingDecl = new UsingFileDecl();
+ usingDecl->fileName = parser->ReadToken(TokenType::StringLiterial);
+ parser->ReadToken(TokenType::Semicolon);
+ return usingDecl;
+ }
+ else
+ {
+ unexpected();
+ }
+ }
+
+ static Token ParseDeclName(
+ Parser* parser)
+ {
+ Token name;
+ if (AdvanceIf(parser, "operator"))
+ {
+ name = parser->ReadToken();
+ switch (name.Type)
+ {
+ case TokenType::OpAdd: case TokenType::OpSub: case TokenType::OpMul: case TokenType::OpDiv:
+ case TokenType::OpMod: case TokenType::OpNot: case TokenType::OpBitNot: case TokenType::OpLsh: case TokenType::OpRsh:
+ case TokenType::OpEql: case TokenType::OpNeq: case TokenType::OpGreater: case TokenType::OpLess: case TokenType::OpGeq:
+ case TokenType::OpLeq: case TokenType::OpAnd: case TokenType::OpOr: case TokenType::OpBitXor: case TokenType::OpBitAnd:
+ case TokenType::OpBitOr: case TokenType::OpInc: case TokenType::OpDec:
+ case TokenType::OpAddAssign:
+ case TokenType::OpSubAssign:
+ case TokenType::OpMulAssign:
+ case TokenType::OpDivAssign:
+ case TokenType::OpModAssign:
+ case TokenType::OpShlAssign:
+ case TokenType::OpShrAssign:
+ case TokenType::OpOrAssign:
+ case TokenType::OpAndAssign:
+ case TokenType::OpXorAssign:
+
+ // Note(tfoley): A bit of a hack:
+ case TokenType::Comma:
+ case TokenType::OpAssign:
+ break;
+
+ // Note(tfoley): Even more of a hack!
+ case TokenType::QuestionMark:
+ if (AdvanceIf(parser, TokenType::Colon))
+ {
+ name.Content = name.Content + ":";
+ break;
+ }
+
+ default:
+ parser->sink->diagnose(name.Position, Diagnostics::invalidOperator, name.Content);
+ break;
+ }
+ }
+ else
+ {
+ name = parser->ReadToken(TokenType::Identifier);
+ }
+ return name;
+ }
+
+ // A "declarator" as used in C-style languages
+ struct Declarator : RefObject
+ {
+ // Different cases of declarator appear as "flavors" here
+ enum class Flavor
+ {
+ Name,
+ Pointer,
+ Array,
+ };
+ Flavor flavor;
+ };
+
+ // The most common case of declarator uses a simple name
+ struct NameDeclarator : Declarator
+ {
+ Token nameToken;
+ };
+
+ // A declarator that declares a pointer type
+ struct PointerDeclarator : Declarator
+ {
+ // location of the `*` token
+ CodePosition starLoc;
+
+ RefPtr<Declarator> inner;
+ };
+
+ // A declarator that declares an array type
+ struct ArrayDeclarator : Declarator
+ {
+ RefPtr<Declarator> inner;
+
+ // location of the `[` token
+ CodePosition openBracketLoc;
+
+ // The expression that yields the element count, or NULL
+ RefPtr<ExpressionSyntaxNode> elementCountExpr;
+ };
+
+ // "Unwrapped" information about a declarator
+ struct DeclaratorInfo
+ {
+ RefPtr<ExpressionSyntaxNode> typeSpec;
+ Token nameToken;
+ RefPtr<Modifier> semantics;
+ RefPtr<ExpressionSyntaxNode> initializer;
+ };
+
+ // Add a member declaration to its container, and ensure that its
+ // parent link is set up correctly.
+ static void AddMember(RefPtr<ContainerDecl> container, RefPtr<Decl> member)
+ {
+ if (container)
+ {
+ member->ParentDecl = container.Ptr();
+ container->Members.Add(member);
+
+ container->memberDictionaryIsValid = false;
+ }
+ }
+
+ static void AddMember(RefPtr<Scope> scope, RefPtr<Decl> member)
+ {
+ if (scope)
+ {
+ AddMember(scope->containerDecl, member);
+ }
+ }
+
+ static void parseParameterList(
+ Parser* parser,
+ RefPtr<CallableDecl> decl)
+ {
+ parser->ReadToken(TokenType::LParent);
+
+ // Allow a declaration to use the keyword `void` for a parameter list,
+ // since that was required in ancient C, and continues to be supported
+ // in a bunc hof its derivatives even if it is a Bad Design Choice
+ //
+ // TODO: conditionalize this so we don't keep this around for "pure"
+ // Slang code
+ if( parser->LookAheadToken("void") && parser->LookAheadToken(TokenType::RParent, 1) )
+ {
+ parser->ReadToken("void");
+ parser->ReadToken(TokenType::RParent);
+ return;
+ }
+
+ while (!AdvanceIfMatch(parser, TokenType::RParent))
+ {
+ AddMember(decl, parser->ParseParameter());
+ if (AdvanceIf(parser, TokenType::RParent))
+ break;
+ parser->ReadToken(TokenType::Comma);
+ }
+ }
+
+ static void ParseFuncDeclHeader(
+ Parser* parser,
+ DeclaratorInfo const& declaratorInfo,
+ RefPtr<FunctionSyntaxNode> decl)
+ {
+ parser->PushScope(decl.Ptr());
+
+ parser->FillPosition(decl.Ptr());
+ decl->Position = declaratorInfo.nameToken.Position;
+
+ decl->Name = declaratorInfo.nameToken;
+ decl->ReturnType = TypeExp(declaratorInfo.typeSpec);
+ parseParameterList(parser, decl);
+ ParseOptSemantics(parser, decl.Ptr());
+ }
+
+ static RefPtr<Decl> ParseFuncDecl(
+ Parser* parser,
+ ContainerDecl* /*containerDecl*/,
+ DeclaratorInfo const& declaratorInfo)
+ {
+ RefPtr<FunctionSyntaxNode> decl = new FunctionSyntaxNode();
+ ParseFuncDeclHeader(parser, declaratorInfo, decl);
+
+ if (AdvanceIf(parser, TokenType::Semicolon))
+ {
+ // empty body
+ }
+ else
+ {
+ decl->Body = parser->ParseBlockStatement();
+ }
+
+ parser->PopScope();
+ return decl;
+ }
+
+ static RefPtr<VarDeclBase> CreateVarDeclForContext(
+ ContainerDecl* containerDecl )
+ {
+ if (dynamic_cast<StructSyntaxNode*>(containerDecl) || dynamic_cast<ClassSyntaxNode*>(containerDecl))
+ {
+ return new StructField();
+ }
+ else if (dynamic_cast<CallableDecl*>(containerDecl))
+ {
+ return new ParameterSyntaxNode();
+ }
+ else
+ {
+ return new Variable();
+ }
+ }
+
+ // Add modifiers to the end of the modifier list for a declaration
+ void AddModifiers(Decl* decl, RefPtr<Modifier> modifiers)
+ {
+ if (!modifiers)
+ return;
+
+ RefPtr<Modifier>* link = &decl->modifiers.first;
+ while (*link)
+ {
+ link = &(*link)->next;
+ }
+ *link = modifiers;
+ }
+
+
+ static String GenerateName(Parser* parser, String const& base)
+ {
+ // TODO: somehow mangle the name to avoid clashes
+ return base;
+ }
+
+ static String GenerateName(Parser* parser)
+ {
+ return GenerateName(parser, "_anonymous_" + String(parser->anonymousCounter++));
+ }
+
+
+ // Set up a variable declaration based on what we saw in its declarator...
+ static void CompleteVarDecl(
+ Parser* parser,
+ RefPtr<VarDeclBase> decl,
+ DeclaratorInfo const& declaratorInfo)
+ {
+ parser->FillPosition(decl.Ptr());
+
+ if( declaratorInfo.nameToken.Type == TokenType::Unknown )
+ {
+ // HACK(tfoley): we always give a name, even if the declarator didn't include one... :(
+ decl->Name.Content = GenerateName(parser);
+ }
+ else
+ {
+ decl->Position = declaratorInfo.nameToken.Position;
+ decl->Name = declaratorInfo.nameToken;
+ }
+ decl->Type = TypeExp(declaratorInfo.typeSpec);
+
+ AddModifiers(decl.Ptr(), declaratorInfo.semantics);
+
+ decl->Expr = declaratorInfo.initializer;
+ }
+
+ static RefPtr<Declarator> ParseDeclarator(Parser* parser);
+
+ static RefPtr<Declarator> ParseDirectAbstractDeclarator(
+ Parser* parser)
+ {
+ RefPtr<Declarator> declarator;
+ switch( parser->tokenReader.PeekTokenType() )
+ {
+ case TokenType::Identifier:
+ {
+ auto nameDeclarator = new NameDeclarator();
+ nameDeclarator->flavor = Declarator::Flavor::Name;
+ nameDeclarator->nameToken = ParseDeclName(parser);
+ declarator = nameDeclarator;
+ }
+ break;
+
+ case TokenType::LParent:
+ {
+ // Note(tfoley): This is a point where disambiguation is required.
+ // We could be looking at an abstract declarator for a function-type
+ // parameter:
+ //
+ // void F( int(int) );
+ //
+ // Or we could be looking at the use of parenthesese in an ordinary
+ // declarator:
+ //
+ // void (*f)(int);
+ //
+ // The difference really doesn't matter right now, but we err in
+ // the direction of assuming the second case.
+ parser->ReadToken(TokenType::LParent);
+ declarator = ParseDeclarator(parser);
+ parser->ReadToken(TokenType::RParent);
+ }
+ break;
+
+ default:
+ // an empty declarator is allowed
+ return nullptr;
+ }
+
+ // postifx additions
+ for( ;;)
+ {
+ switch( parser->tokenReader.PeekTokenType() )
+ {
+ case TokenType::LBracket:
+ {
+ auto arrayDeclarator = new ArrayDeclarator();
+ arrayDeclarator->openBracketLoc = parser->tokenReader.PeekLoc();
+ arrayDeclarator->flavor = Declarator::Flavor::Array;
+ arrayDeclarator->inner = declarator;
+
+ parser->ReadToken(TokenType::LBracket);
+ if( parser->tokenReader.PeekTokenType() != TokenType::RBracket )
+ {
+ arrayDeclarator->elementCountExpr = parser->ParseExpression();
+ }
+ parser->ReadToken(TokenType::RBracket);
+
+ declarator = arrayDeclarator;
+ continue;
+ }
+
+ case TokenType::LParent:
+ break;
+
+ default:
+ break;
+ }
+
+ break;
+ }
+
+ return declarator;
+ }
+
+ // Parse a declarator (or at least as much of one as we support)
+ static RefPtr<Declarator> ParseDeclarator(
+ Parser* parser)
+ {
+ if( parser->tokenReader.PeekTokenType() == TokenType::OpMul )
+ {
+ auto ptrDeclarator = new PointerDeclarator();
+ ptrDeclarator->starLoc = parser->tokenReader.PeekLoc();
+ ptrDeclarator->flavor = Declarator::Flavor::Pointer;
+
+ parser->ReadToken(TokenType::OpMul);
+
+ // TODO(tfoley): allow qualifiers like `const` here?
+
+ ptrDeclarator->inner = ParseDeclarator(parser);
+ return ptrDeclarator;
+ }
+ else
+ {
+ return ParseDirectAbstractDeclarator(parser);
+ }
+ }
+
+ // A declarator plus optional semantics and initializer
+ struct InitDeclarator
+ {
+ RefPtr<Declarator> declarator;
+ RefPtr<Modifier> semantics;
+ RefPtr<ExpressionSyntaxNode> initializer;
+ };
+
+ // Parse a declarator plus optional semantics
+ static InitDeclarator ParseSemanticDeclarator(
+ Parser* parser)
+ {
+ InitDeclarator result;
+ result.declarator = ParseDeclarator(parser);
+ result.semantics = ParseOptSemantics(parser);
+ return result;
+ }
+
+ // Parse a declarator plus optional semantics and initializer
+ static InitDeclarator ParseInitDeclarator(
+ Parser* parser)
+ {
+ InitDeclarator result = ParseSemanticDeclarator(parser);
+ if (AdvanceIf(parser, TokenType::OpAssign))
+ {
+ result.initializer = parser->ParseInitExpr();
+ }
+ return result;
+ }
+
+ static void UnwrapDeclarator(
+ RefPtr<Declarator> declarator,
+ DeclaratorInfo* ioInfo)
+ {
+ while( declarator )
+ {
+ switch(declarator->flavor)
+ {
+ case Declarator::Flavor::Name:
+ {
+ auto nameDeclarator = (NameDeclarator*) declarator.Ptr();
+ ioInfo->nameToken = nameDeclarator->nameToken;
+ return;
+ }
+ break;
+
+ case Declarator::Flavor::Pointer:
+ {
+ auto ptrDeclarator = (PointerDeclarator*) declarator.Ptr();
+
+ // TODO(tfoley): we don't support pointers for now
+ // ioInfo->typeSpec = new PointerTypeExpr(ioInfo->typeSpec);
+
+ declarator = ptrDeclarator->inner;
+ }
+ break;
+
+ case Declarator::Flavor::Array:
+ {
+ // TODO(tfoley): we don't support pointers for now
+ auto arrayDeclarator = (ArrayDeclarator*) declarator.Ptr();
+
+ auto arrayTypeExpr = new IndexExpressionSyntaxNode();
+ arrayTypeExpr->Position = arrayDeclarator->openBracketLoc;
+ arrayTypeExpr->BaseExpression = ioInfo->typeSpec;
+ arrayTypeExpr->IndexExpression = arrayDeclarator->elementCountExpr;
+ ioInfo->typeSpec = arrayTypeExpr;
+
+ declarator = arrayDeclarator->inner;
+ }
+ break;
+
+ default:
+ SLANG_UNREACHABLE("all cases handled");
+ break;
+ }
+ }
+ }
+
+ static void UnwrapDeclarator(
+ InitDeclarator const& initDeclarator,
+ DeclaratorInfo* ioInfo)
+ {
+ UnwrapDeclarator(initDeclarator.declarator, ioInfo);
+ ioInfo->semantics = initDeclarator.semantics;
+ ioInfo->initializer = initDeclarator.initializer;
+ }
+
+ // Either a single declaration, or a group of them
+ struct DeclGroupBuilder
+ {
+ CodePosition startPosition;
+ RefPtr<Decl> decl;
+ RefPtr<DeclGroup> group;
+
+ // Add a new declaration to the potential group
+ void addDecl(
+ RefPtr<Decl> newDecl)
+ {
+ assert(newDecl);
+
+ if( decl )
+ {
+ group = new DeclGroup();
+ group->Position = startPosition;
+ group->decls.Add(decl);
+ decl = nullptr;
+ }
+
+ if( group )
+ {
+ group->decls.Add(newDecl);
+ }
+ else
+ {
+ decl = newDecl;
+ }
+ }
+
+ RefPtr<DeclBase> getResult()
+ {
+ if(group) return group;
+ return decl;
+ }
+ };
+
+ // Pares an argument to an application of a generic
+ RefPtr<ExpressionSyntaxNode> ParseGenericArg(Parser* parser)
+ {
+ return parser->ParseArgExpr();
+ }
+
+ // Create a type expression that will refer to the given declaration
+ static RefPtr<ExpressionSyntaxNode>
+ createDeclRefType(Parser* parser, RefPtr<Decl> decl)
+ {
+ // For now we just construct an expression that
+ // will look up the given declaration by name.
+ //
+ // TODO: do this better, e.g. by filling in the `declRef` field directly
+
+ auto expr = new VarExpressionSyntaxNode();
+ expr->scope = parser->currentScope.Ptr();
+ expr->Position = decl->getNameToken().Position;
+ expr->Variable = decl->getName();
+ return expr;
+ }
+
+ // Representation for a parsed type specifier, which might
+ // include a declaration (e.g., of a `struct` type)
+ struct TypeSpec
+ {
+ // If the type-spec declared something, then put it here
+ RefPtr<Decl> decl;
+
+ // Put the resulting expression (which should evaluate to a type) here
+ RefPtr<ExpressionSyntaxNode> expr;
+ };
+
+ static TypeSpec
+ parseTypeSpec(Parser* parser)
+ {
+ TypeSpec typeSpec;
+
+ // We may see a `struct` type specified here, and need to act accordingly
+ //
+ // TODO(tfoley): Handle the case where the user is just using `struct`
+ // as a way to name an existing struct "tag" (e.g., `struct Foo foo;`)
+ //
+ if( parser->LookAheadToken("struct") )
+ {
+ auto decl = parser->ParseStruct();
+ typeSpec.decl = decl;
+ typeSpec.expr = createDeclRefType(parser, decl);
+ return typeSpec;
+ }
+ else if( parser->LookAheadToken("class") )
+ {
+ auto decl = parser->ParseClass();
+ typeSpec.decl = decl;
+ typeSpec.expr = createDeclRefType(parser, decl);
+ return typeSpec;
+ }
+
+ Token typeName = parser->ReadToken(TokenType::Identifier);
+
+ auto basicType = new VarExpressionSyntaxNode();
+ basicType->scope = parser->currentScope.Ptr();
+ basicType->Position = typeName.Position;
+ basicType->Variable = typeName.Content;
+
+ RefPtr<ExpressionSyntaxNode> typeExpr = basicType;
+
+ if (parser->LookAheadToken(TokenType::OpLess))
+ {
+ RefPtr<GenericAppExpr> gtype = new GenericAppExpr();
+ parser->FillPosition(gtype.Ptr()); // set up scope for lookup
+ gtype->Position = typeName.Position;
+ gtype->FunctionExpr = typeExpr;
+ parser->ReadToken(TokenType::OpLess);
+ parser->genericDepth++;
+ // For now assume all generics have at least one argument
+ gtype->Arguments.Add(ParseGenericArg(parser));
+ while (AdvanceIf(parser, TokenType::Comma))
+ {
+ gtype->Arguments.Add(ParseGenericArg(parser));
+ }
+ parser->genericDepth--;
+ parser->ReadToken(TokenType::OpGreater);
+ typeExpr = gtype;
+ }
+
+ typeSpec.expr = typeExpr;
+ return typeSpec;
+ }
+
+
+ static RefPtr<DeclBase> ParseDeclaratorDecl(
+ Parser* parser,
+ ContainerDecl* containerDecl)
+ {
+ CodePosition startPosition = parser->tokenReader.PeekLoc();
+
+ auto typeSpec = parseTypeSpec(parser);
+
+ // We may need to build up multiple declarations in a group,
+ // but the common case will be when we have just a single
+ // declaration
+ DeclGroupBuilder declGroupBuilder;
+ declGroupBuilder.startPosition = startPosition;
+
+ // The type specifier may include a declaration. E.g.,
+ // it might declare a `struct` type.
+ if(typeSpec.decl)
+ declGroupBuilder.addDecl(typeSpec.decl);
+
+ if( AdvanceIf(parser, TokenType::Semicolon) )
+ {
+ // No actual variable is being declared here, but
+ // that might not be an error.
+
+ auto result = declGroupBuilder.getResult();
+ if( !result )
+ {
+ parser->sink->diagnose(startPosition, Diagnostics::declarationDidntDeclareAnything);
+ }
+ return result;
+ }
+
+
+ InitDeclarator initDeclarator = ParseInitDeclarator(parser);
+
+ DeclaratorInfo declaratorInfo;
+ declaratorInfo.typeSpec = typeSpec.expr;
+
+
+ // Rather than parse function declarators properly for now,
+ // we'll just do a quick disambiguation here. This won't
+ // matter unless we actually decide to support function-type parameters,
+ // using C syntax.
+ //
+ if( parser->tokenReader.PeekTokenType() == TokenType::LParent
+
+ // Only parse as a function if we didn't already see mutually-exclusive
+ // constructs when parsing the declarator.
+ && !initDeclarator.initializer
+ && !initDeclarator.semantics)
+ {
+ // Looks like a function, so parse it like one.
+ UnwrapDeclarator(initDeclarator, &declaratorInfo);
+ return ParseFuncDecl(parser, containerDecl, declaratorInfo);
+ }
+
+ // Otherwise we are looking at a variable declaration, which could be one in a sequence...
+
+ if( AdvanceIf(parser, TokenType::Semicolon) )
+ {
+ // easy case: we only had a single declaration!
+ UnwrapDeclarator(initDeclarator, &declaratorInfo);
+ RefPtr<VarDeclBase> firstDecl = CreateVarDeclForContext(containerDecl);
+ CompleteVarDecl(parser, firstDecl, declaratorInfo);
+
+ declGroupBuilder.addDecl(firstDecl);
+ return declGroupBuilder.getResult();
+
+ return firstDecl;
+ }
+
+ // Otherwise we have multiple declarations in a sequence, and these
+ // declarations need to somehow share both the type spec and modifiers.
+ //
+ // If there are any errors in the type specifier, we only want to hear
+ // about it once, so we need to share structure rather than just
+ // clone syntax.
+
+ auto sharedTypeSpec = new SharedTypeExpr();
+ sharedTypeSpec->Position = typeSpec.expr->Position;
+ sharedTypeSpec->base = TypeExp(typeSpec.expr);
+
+ for(;;)
+ {
+ declaratorInfo.typeSpec = sharedTypeSpec;
+ UnwrapDeclarator(initDeclarator, &declaratorInfo);
+
+ RefPtr<VarDeclBase> varDecl = CreateVarDeclForContext(containerDecl);
+ CompleteVarDecl(parser, varDecl, declaratorInfo);
+
+ declGroupBuilder.addDecl(varDecl);
+
+ // end of the sequence?
+ if(AdvanceIf(parser, TokenType::Semicolon))
+ return declGroupBuilder.getResult();
+
+ // ad-hoc recovery, to avoid infinite loops
+ if( parser->isRecovering )
+ {
+ parser->ReadToken(TokenType::Semicolon);
+ return declGroupBuilder.getResult();
+ }
+
+ // Let's default to assuming that a missing `,`
+ // indicates the end of a declaration,
+ // where a `;` would be expected, and not
+ // a continuation of this declaration, where
+ // a `,` would be expected (this is tailoring
+ // the diagnostic message a bit).
+ //
+ // TODO: a more advanced heuristic here might
+ // look at whether the next token is on the
+ // same line, to predict whether `,` or `;`
+ // would be more likely...
+
+ if (!AdvanceIf(parser, TokenType::Comma))
+ {
+ parser->ReadToken(TokenType::Semicolon);
+ return declGroupBuilder.getResult();
+ }
+
+ // expect another variable declaration...
+ initDeclarator = ParseInitDeclarator(parser);
+ }
+ }
+
+ //
+ // layout-semantic ::= (register | packoffset) '(' register-name component-mask? ')'
+ // register-name ::= identifier
+ // component-mask ::= '.' identifier
+ //
+ static void ParseHLSLLayoutSemantic(
+ Parser* parser,
+ HLSLLayoutSemantic* semantic)
+ {
+ semantic->name = parser->ReadToken(TokenType::Identifier);
+
+ parser->ReadToken(TokenType::LParent);
+ semantic->registerName = parser->ReadToken(TokenType::Identifier);
+ if (AdvanceIf(parser, TokenType::Dot))
+ {
+ semantic->componentMask = parser->ReadToken(TokenType::Identifier);
+ }
+ parser->ReadToken(TokenType::RParent);
+ }
+
+ //
+ // semantic ::= identifier ( '(' args ')' )?
+ //
+ static RefPtr<Modifier> ParseSemantic(
+ Parser* parser)
+ {
+ if (parser->LookAheadToken("register"))
+ {
+ RefPtr<HLSLRegisterSemantic> semantic = new HLSLRegisterSemantic();
+ ParseHLSLLayoutSemantic(parser, semantic.Ptr());
+ return semantic;
+ }
+ else if (parser->LookAheadToken("packoffset"))
+ {
+ RefPtr<HLSLPackOffsetSemantic> semantic = new HLSLPackOffsetSemantic();
+ ParseHLSLLayoutSemantic(parser, semantic.Ptr());
+ return semantic;
+ }
+ else
+ {
+ RefPtr<HLSLSimpleSemantic> semantic = new HLSLSimpleSemantic();
+ semantic->name = parser->ReadToken(TokenType::Identifier);
+ return semantic;
+ }
+ }
+
+ //
+ // opt-semantics ::= (':' semantic)*
+ //
+ static RefPtr<Modifier> ParseOptSemantics(
+ Parser* parser)
+ {
+ if (!AdvanceIf(parser, TokenType::Colon))
+ return nullptr;
+
+ RefPtr<Modifier> result;
+ RefPtr<Modifier>* link = &result;
+ assert(!*link);
+
+ for (;;)
+ {
+ RefPtr<Modifier> semantic = ParseSemantic(parser);
+ if (semantic)
+ {
+ *link = semantic;
+ link = &semantic->next;
+ }
+
+ switch (parser->tokenReader.PeekTokenType())
+ {
+ case TokenType::LBrace:
+ case TokenType::Semicolon:
+ case TokenType::Comma:
+ case TokenType::RParent:
+ case TokenType::EndOfFile:
+ return result;
+
+ default:
+ break;
+ }
+
+ parser->ReadToken(TokenType::Colon);
+ }
+
+ }
+
+
+ static void ParseOptSemantics(
+ Parser* parser,
+ Decl* decl)
+ {
+ AddModifiers(decl, ParseOptSemantics(parser));
+ }
+
+ static RefPtr<Decl> ParseHLSLBufferDecl(
+ Parser* parser)
+ {
+ // An HLSL declaration of a constant buffer like this:
+ //
+ // cbuffer Foo : register(b0) { int a; float b; };
+ //
+ // is treated as syntax sugar for a type declaration
+ // and then a global variable declaration using that type:
+ //
+ // struct $anonymous { int a; float b; };
+ // ConstantBuffer<$anonymous> Foo;
+ //
+ // where `$anonymous` is a fresh name, and the variable
+ // declaration is made to be "transparent" so that lookup
+ // will see through it to the members inside.
+
+ // We first look at the declaration keywrod to determine
+ // the type of buffer to declare:
+ String bufferWrapperTypeName;
+ CodePosition bufferWrapperTypeNamePos = parser->tokenReader.PeekLoc();
+ if (AdvanceIf(parser, "cbuffer"))
+ {
+ bufferWrapperTypeName = "ConstantBuffer";
+ }
+ else if (AdvanceIf(parser, "tbuffer"))
+ {
+ bufferWrapperTypeName = "TextureBuffer";
+ }
+ else
+ {
+ Unexpected(parser);
+ }
+
+ // We are going to represent each buffer as a pair of declarations.
+ // The first is a type declaration that holds all the members, while
+ // the second is a variable declaration that uses the buffer type.
+ RefPtr<StructSyntaxNode> bufferDataTypeDecl = new StructSyntaxNode();
+ RefPtr<Variable> bufferVarDecl = new Variable();
+
+ // Both declarations will have a location that points to the name
+ parser->FillPosition(bufferDataTypeDecl.Ptr());
+ parser->FillPosition(bufferVarDecl.Ptr());
+
+ auto reflectionNameToken = parser->ReadToken(TokenType::Identifier);
+
+ // Attach the reflection name to the block so we can use it
+ auto reflectionNameModifier = new ParameterBlockReflectionName();
+ reflectionNameModifier->nameToken = reflectionNameToken;
+ addModifier(bufferVarDecl, reflectionNameModifier);
+
+ // Both the buffer variable and its type need to have names generated
+ bufferVarDecl->Name.Content = GenerateName(parser, "SLANG_constantBuffer_" + reflectionNameToken.Content);
+ bufferDataTypeDecl->Name.Content = GenerateName(parser, "SLANG_ConstantBuffer_" + reflectionNameToken.Content);
+
+ addModifier(bufferDataTypeDecl, new ImplicitParameterBlockElementTypeModifier());
+ addModifier(bufferVarDecl, new ImplicitParameterBlockVariableModifier());
+
+ // TODO(tfoley): We end up constructing unchecked syntax here that
+ // is expected to type check into the right form, but it might be
+ // cleaner to have a more explicit desugaring pass where we parse
+ // these constructs directly into the AST and *then* desugar them.
+
+ // Construct a type expression to reference the buffer data type
+ auto bufferDataTypeExpr = new VarExpressionSyntaxNode();
+ bufferDataTypeExpr->Position = bufferDataTypeDecl->Position;
+ bufferDataTypeExpr->Variable = bufferDataTypeDecl->Name.Content;
+ bufferDataTypeExpr->scope = parser->currentScope.Ptr();
+
+ // Construct a type exrpession to reference the type constructor
+ auto bufferWrapperTypeExpr = new VarExpressionSyntaxNode();
+ bufferWrapperTypeExpr->Position = bufferWrapperTypeNamePos;
+ bufferWrapperTypeExpr->Variable = bufferWrapperTypeName;
+
+ // Always need to look this up in the outer scope,
+ // so that it won't collide with, e.g., a local variable called `ConstantBuffer`
+ bufferWrapperTypeExpr->scope = parser->outerScope;
+
+ // Construct a type expression that represents the type for the variable,
+ // which is the wrapper type applied to the data type
+ auto bufferVarTypeExpr = new GenericAppExpr();
+ bufferVarTypeExpr->Position = bufferVarDecl->Position;
+ bufferVarTypeExpr->FunctionExpr = bufferWrapperTypeExpr;
+ bufferVarTypeExpr->Arguments.Add(bufferDataTypeExpr);
+
+ bufferVarDecl->Type.exp = bufferVarTypeExpr;
+
+ // Any semantics applied to the bufer declaration are taken as applying
+ // to the variable instead.
+ ParseOptSemantics(parser, bufferVarDecl.Ptr());
+
+ // The declarations in the body belong to the data type.
+ parser->ReadToken(TokenType::LBrace);
+ ParseDeclBody(parser, bufferDataTypeDecl.Ptr(), TokenType::RBrace);
+
+ // All HLSL buffer declarations are "transparent" in that their
+ // members are implicitly made visible in the parent scope.
+ // We achieve this by applying the transparent modifier to the variable.
+ auto transparentModifier = new TransparentModifier();
+ transparentModifier->next = bufferVarDecl->modifiers.first;
+ bufferVarDecl->modifiers.first = transparentModifier;
+
+ // Because we are constructing two declarations, we have a thorny
+ // issue that were are only supposed to return one.
+ // For now we handle this by adding the type declaration to
+ // the current scope manually, and then returning the variable
+ // declaration.
+ //
+ // Note: this means that any modifiers that have already been parsed
+ // will get attached to the variable declaration, not the type.
+ // There might be cases where we need to shuffle things around.
+
+ AddMember(parser->currentScope, bufferDataTypeDecl);
+
+ return bufferVarDecl;
+ }
+
+ static void removeModifier(
+ Modifiers& modifiers,
+ RefPtr<Modifier> modifier)
+ {
+ RefPtr<Modifier>* link = &modifiers.first;
+ while (*link)
+ {
+ if (*link == modifier)
+ {
+ *link = (*link)->next;
+ return;
+ }
+
+ link = &(*link)->next;
+ }
+ }
+
+ static RefPtr<Decl> parseGLSLBlockDecl(
+ Parser* parser,
+ Modifiers& modifiers)
+ {
+ // An GLSL block like this:
+ //
+ // uniform Foo { int a; float b; } foo;
+ //
+ // is treated as syntax sugar for a type declaration
+ // and then a global variable declaration using that type:
+ //
+ // struct $anonymous { int a; float b; };
+ // Block<$anonymous> foo;
+ //
+ // where `$anonymous` is a fresh name.
+ //
+ // If a "local name" like `foo` is not given, then
+ // we make the declaration "transparent" so that lookup
+ // will see through it to the members inside.
+
+
+ CodePosition pos = parser->tokenReader.PeekLoc();
+
+ // The initial name before the `{` is only supposed
+ // to be made visible to reflection
+ auto reflectionNameToken = parser->ReadToken(TokenType::Identifier);
+
+ // Look at the qualifiers present on the block to decide what kind
+ // of block we are looking at. Also *remove* those qualifiers so
+ // that they don't interfere with downstream work.
+ String blockWrapperTypeName;
+ if( auto uniformMod = modifiers.findModifier<HLSLUniformModifier>() )
+ {
+ removeModifier(modifiers, uniformMod);
+ blockWrapperTypeName = "ConstantBuffer";
+ }
+ else if( auto inMod = modifiers.findModifier<InModifier>() )
+ {
+ removeModifier(modifiers, inMod);
+ blockWrapperTypeName = "__GLSLInputParameterBlock";
+ }
+ else if( auto outMod = modifiers.findModifier<OutModifier>() )
+ {
+ removeModifier(modifiers, outMod);
+ blockWrapperTypeName = "__GLSLOutputParameterBlock";
+ }
+ else if( auto bufferMod = modifiers.findModifier<GLSLBufferModifier>() )
+ {
+ removeModifier(modifiers, bufferMod);
+ blockWrapperTypeName = "__GLSLShaderStorageBuffer";
+ }
+ else
+ {
+ // Unknown case: just map to a constant buffer and hope for the best
+ blockWrapperTypeName = "ConstantBuffer";
+ }
+
+ // We are going to represent each buffer as a pair of declarations.
+ // The first is a type declaration that holds all the members, while
+ // the second is a variable declaration that uses the buffer type.
+ RefPtr<StructSyntaxNode> blockDataTypeDecl = new StructSyntaxNode();
+ RefPtr<Variable> blockVarDecl = new Variable();
+
+ addModifier(blockDataTypeDecl, new ImplicitParameterBlockElementTypeModifier());
+ addModifier(blockVarDecl, new ImplicitParameterBlockVariableModifier());
+
+ // Attach the reflection name to the block so we can use it
+ auto reflectionNameModifier = new ParameterBlockReflectionName();
+ reflectionNameModifier->nameToken = reflectionNameToken;
+ addModifier(blockVarDecl, reflectionNameModifier);
+
+ // Both declarations will have a location that points to the name
+ parser->FillPosition(blockDataTypeDecl.Ptr());
+ parser->FillPosition(blockVarDecl.Ptr());
+
+ // Generate a unique name for the data type
+ blockDataTypeDecl->Name.Content = GenerateName(parser, "SLANG_ParameterBlock_" + reflectionNameToken.Content);
+
+ // TODO(tfoley): We end up constructing unchecked syntax here that
+ // is expected to type check into the right form, but it might be
+ // cleaner to have a more explicit desugaring pass where we parse
+ // these constructs directly into the AST and *then* desugar them.
+
+ // Construct a type expression to reference the buffer data type
+ auto blockDataTypeExpr = new VarExpressionSyntaxNode();
+ blockDataTypeExpr->Position = blockDataTypeDecl->Position;
+ blockDataTypeExpr->Variable = blockDataTypeDecl->Name.Content;
+ blockDataTypeExpr->scope = parser->currentScope.Ptr();
+
+ // Construct a type exrpession to reference the type constructor
+ auto blockWrapperTypeExpr = new VarExpressionSyntaxNode();
+ blockWrapperTypeExpr->Position = pos;
+ blockWrapperTypeExpr->Variable = blockWrapperTypeName;
+ // Always need to look this up in the outer scope,
+ // so that it won't collide with, e.g., a local variable called `ConstantBuffer`
+ blockWrapperTypeExpr->scope = parser->outerScope;
+
+ // Construct a type expression that represents the type for the variable,
+ // which is the wrapper type applied to the data type
+ auto blockVarTypeExpr = new GenericAppExpr();
+ blockVarTypeExpr->Position = blockVarDecl->Position;
+ blockVarTypeExpr->FunctionExpr = blockWrapperTypeExpr;
+ blockVarTypeExpr->Arguments.Add(blockDataTypeExpr);
+
+ blockVarDecl->Type.exp = blockVarTypeExpr;
+
+ // The declarations in the body belong to the data type.
+ parser->ReadToken(TokenType::LBrace);
+ ParseDeclBody(parser, blockDataTypeDecl.Ptr(), TokenType::RBrace);
+
+ if( parser->LookAheadToken(TokenType::Identifier) )
+ {
+ // The user gave an explicit name to the block,
+ // so we need to use that as our variable name
+ blockVarDecl->Name = parser->ReadToken(TokenType::Identifier);
+
+ // TODO: in this case we make actually have a more complex
+ // declarator, including `[]` brackets.
+ }
+ else
+ {
+ // synthesize a dummy name
+ blockVarDecl->Name.Content = GenerateName(parser, "SLANG_parameterBlock_" + reflectionNameToken.Content);
+
+ // Otherwise we have a transparent declaration, similar
+ // to an HLSL `cbuffer`
+ auto transparentModifier = new TransparentModifier();
+ transparentModifier->Position = pos;
+ addModifier(blockVarDecl, transparentModifier);
+ }
+
+ // Expect a trailing `;`
+ parser->ReadToken(TokenType::Semicolon);
+
+ // Because we are constructing two declarations, we have a thorny
+ // issue that were are only supposed to return one.
+ // For now we handle this by adding the type declaration to
+ // the current scope manually, and then returning the variable
+ // declaration.
+ //
+ // Note: this means that any modifiers that have already been parsed
+ // will get attached to the variable declaration, not the type.
+ // There might be cases where we need to shuffle things around.
+
+ AddMember(parser->currentScope, blockDataTypeDecl);
+
+ return blockVarDecl;
+ }
+
+
+
+
+ static RefPtr<Decl> ParseGenericParamDecl(
+ Parser* parser,
+ RefPtr<GenericDecl> genericDecl)
+ {
+ // simple syntax to introduce a value parameter
+ if (AdvanceIf(parser, "let"))
+ {
+ // default case is a type parameter
+ auto paramDecl = new GenericValueParamDecl();
+ paramDecl->Name = parser->ReadToken(TokenType::Identifier);
+ if (AdvanceIf(parser, TokenType::Colon))
+ {
+ paramDecl->Type = parser->ParseTypeExp();
+ }
+ if (AdvanceIf(parser, TokenType::OpAssign))
+ {
+ paramDecl->Expr = parser->ParseInitExpr();
+ }
+ return paramDecl;
+ }
+ else
+ {
+ // default case is a type parameter
+ auto paramDecl = new GenericTypeParamDecl();
+ parser->FillPosition(paramDecl);
+ paramDecl->Name = parser->ReadToken(TokenType::Identifier);
+ if (AdvanceIf(parser, TokenType::Colon))
+ {
+ // The user is apply a constraint to this type parameter...
+
+ auto paramConstraint = new GenericTypeConstraintDecl();
+ parser->FillPosition(paramConstraint);
+
+ auto paramType = DeclRefType::Create(DeclRef(paramDecl, nullptr));
+
+ auto paramTypeExpr = new SharedTypeExpr();
+ paramTypeExpr->Position = paramDecl->Position;
+ paramTypeExpr->base.type = paramType;
+ paramTypeExpr->Type = new TypeType(paramType);
+
+ paramConstraint->sub = TypeExp(paramTypeExpr);
+ paramConstraint->sup = parser->ParseTypeExp();
+
+ AddMember(genericDecl, paramConstraint);
+
+
+ }
+ if (AdvanceIf(parser, TokenType::OpAssign))
+ {
+ paramDecl->initType = parser->ParseTypeExp();
+ }
+ return paramDecl;
+ }
+ }
+
+ static RefPtr<Decl> ParseGenericDecl(
+ Parser* parser)
+ {
+ RefPtr<GenericDecl> decl = new GenericDecl();
+ parser->FillPosition(decl.Ptr());
+ parser->PushScope(decl.Ptr());
+ parser->ReadToken("__generic");
+ parser->ReadToken(TokenType::OpLess);
+ parser->genericDepth++;
+ while (!parser->LookAheadToken(TokenType::OpGreater))
+ {
+ AddMember(decl, ParseGenericParamDecl(parser, decl));
+
+if( parser->LookAheadToken(TokenType::OpGreater) )
+break;
+
+parser->ReadToken(TokenType::Comma);
+ }
+ parser->genericDepth--;
+ parser->ReadToken(TokenType::OpGreater);
+
+ decl->inner = ParseSingleDecl(parser, decl.Ptr());
+
+ // A generic decl hijacks the name of the declaration
+ // it wraps, so that lookup can find it.
+ decl->Name = decl->inner->Name;
+
+ parser->PopScope();
+ return decl;
+ }
+
+ static RefPtr<Decl> ParseTraitConformanceDecl(
+ Parser* parser)
+ {
+ RefPtr<TraitConformanceDecl> decl = new TraitConformanceDecl();
+ parser->FillPosition(decl.Ptr());
+ parser->ReadToken("__conforms");
+
+ decl->base = parser->ParseTypeExp();
+
+ return decl;
+ }
+
+
+ static RefPtr<ExtensionDecl> ParseExtensionDecl(Parser* parser)
+ {
+ RefPtr<ExtensionDecl> decl = new ExtensionDecl();
+ parser->FillPosition(decl.Ptr());
+ parser->ReadToken("__extension");
+ decl->targetType = parser->ParseTypeExp();
+ parser->ReadToken(TokenType::LBrace);
+ ParseDeclBody(parser, decl.Ptr(), TokenType::RBrace);
+ return decl;
+ }
+
+ static RefPtr<TraitDecl> ParseTraitDecl(Parser* parser)
+ {
+ RefPtr<TraitDecl> decl = new TraitDecl();
+ parser->FillPosition(decl.Ptr());
+ parser->ReadToken("__trait");
+ decl->Name = parser->ReadToken(TokenType::Identifier);
+
+ if( AdvanceIf(parser, TokenType::Colon) )
+ {
+ do
+ {
+ auto base = parser->ParseTypeExp();
+ decl->bases.Add(base);
+ } while( AdvanceIf(parser, TokenType::Comma) );
+ }
+
+ parser->ReadToken(TokenType::LBrace);
+ ParseDeclBody(parser, decl.Ptr(), TokenType::RBrace);
+ return decl;
+ }
+
+ static RefPtr<ConstructorDecl> ParseConstructorDecl(Parser* parser)
+ {
+ RefPtr<ConstructorDecl> decl = new ConstructorDecl();
+ parser->FillPosition(decl.Ptr());
+ parser->ReadToken("__init");
+
+ parseParameterList(parser, decl);
+
+ if( AdvanceIf(parser, TokenType::Semicolon) )
+ {
+ // empty body
+ }
+ else
+ {
+ decl->Body = parser->ParseBlockStatement();
+ }
+ return decl;
+ }
+
+ static RefPtr<AccessorDecl> parseAccessorDecl(Parser* parser)
+ {
+ RefPtr<AccessorDecl> decl;
+ if( AdvanceIf(parser, "get") )
+ {
+ decl = new GetterDecl();
+ }
+ else if( AdvanceIf(parser, "set") )
+ {
+ decl = new SetterDecl();
+ }
+ else
+ {
+ Unexpected(parser);
+ return nullptr;
+ }
+
+ if( parser->tokenReader.PeekTokenType() == TokenType::LBrace )
+ {
+ decl->Body = parser->ParseBlockStatement();
+ }
+ else
+ {
+ parser->ReadToken(TokenType::Semicolon);
+ }
+
+ return decl;
+ }
+
+ static RefPtr<SubscriptDecl> ParseSubscriptDecl(Parser* parser)
+ {
+ RefPtr<SubscriptDecl> decl = new SubscriptDecl();
+ parser->FillPosition(decl.Ptr());
+ parser->ReadToken("__subscript");
+
+ // TODO: the use of this name here is a bit magical...
+ decl->Name.Content = "operator[]";
+
+ parseParameterList(parser, decl);
+
+ if( AdvanceIf(parser, TokenType::RightArrow) )
+ {
+ decl->ReturnType = parser->ParseTypeExp();
+ }
+
+ if( AdvanceIf(parser, TokenType::LBrace) )
+ {
+ // We want to parse nested "accessor" declarations
+ while( !AdvanceIfMatch(parser, TokenType::RBrace) )
+ {
+ auto accessor = parseAccessorDecl(parser);
+ AddMember(decl, accessor);
+ }
+ }
+ else
+ {
+ parser->ReadToken(TokenType::Semicolon);
+
+ // empty body should be treated like `{ get; }`
+ }
+
+ return decl;
+ }
+
+ // Parse a declaration of a new modifier keyword
+ static RefPtr<ModifierDecl> parseModifierDecl(Parser* parser)
+ {
+ RefPtr<ModifierDecl> decl = new ModifierDecl();
+
+ // read the `__modifier` keyword
+ parser->ReadToken(TokenType::Identifier);
+
+ parser->ReadToken(TokenType::LParent);
+ decl->classNameToken = parser->ReadToken(TokenType::Identifier);
+ parser->ReadToken(TokenType::RParent);
+
+ parser->FillPosition(decl.Ptr());
+ decl->Name = parser->ReadToken(TokenType::Identifier);
+
+ parser->ReadToken(TokenType::Semicolon);
+ return decl;
+ }
+
+ // Finish up work on a declaration that was parsed
+ static void CompleteDecl(
+ Parser* /*parser*/,
+ RefPtr<Decl> decl,
+ ContainerDecl* containerDecl,
+ Modifiers modifiers)
+ {
+ // Add any modifiers we parsed before the declaration to the list
+ // of modifiers on the declaration itself.
+ AddModifiers(decl.Ptr(), modifiers.first);
+
+ // Make sure the decl is properly nested inside its lexical parent
+ if (containerDecl)
+ {
+ AddMember(containerDecl, decl);
+ }
+ }
+
+ static RefPtr<DeclBase> ParseDeclWithModifiers(
+ Parser* parser,
+ ContainerDecl* containerDecl,
+ Modifiers modifiers )
+ {
+ RefPtr<DeclBase> decl;
+
+ auto loc = parser->tokenReader.PeekLoc();
+
+ // TODO: actual dispatch!
+ if (parser->LookAheadToken("struct"))
+ decl = ParseDeclaratorDecl(parser, containerDecl);
+ else if (parser->LookAheadToken("class"))
+ decl = ParseDeclaratorDecl(parser, containerDecl);
+ else if (parser->LookAheadToken("typedef"))
+ decl = ParseTypeDef(parser);
+ else if (parser->LookAheadToken("using"))
+ decl = ParseUsing(parser);
+ else if (parser->LookAheadToken("cbuffer") || parser->LookAheadToken("tbuffer"))
+ decl = ParseHLSLBufferDecl(parser);
+ else if (parser->LookAheadToken("__generic"))
+ decl = ParseGenericDecl(parser);
+ else if (parser->LookAheadToken("__conforms"))
+ decl = ParseTraitConformanceDecl(parser);
+ else if (parser->LookAheadToken("__extension"))
+ decl = ParseExtensionDecl(parser);
+ else if (parser->LookAheadToken("__init"))
+ decl = ParseConstructorDecl(parser);
+ else if (parser->LookAheadToken("__subscript"))
+ decl = ParseSubscriptDecl(parser);
+ else if (parser->LookAheadToken("__trait"))
+ decl = ParseTraitDecl(parser);
+ else if(parser->LookAheadToken("__modifier"))
+ decl = parseModifierDecl(parser);
+ else if (AdvanceIf(parser, TokenType::Semicolon))
+ {
+ decl = new EmptyDecl();
+ decl->Position = loc;
+ }
+ // GLSL requires that we be able to parse "block" declarations,
+ // which look superficially similar to declarator declarations
+ else if( parser->LookAheadToken(TokenType::Identifier)
+ && parser->LookAheadToken(TokenType::LBrace, 1) )
+ {
+ decl = parseGLSLBlockDecl(parser, modifiers);
+ }
+ else
+ {
+ // Default case: just parse a declarator-based declaration
+ decl = ParseDeclaratorDecl(parser, containerDecl);
+ }
+
+ if (decl)
+ {
+ if( auto dd = decl.As<Decl>() )
+ {
+ CompleteDecl(parser, dd, containerDecl, modifiers);
+ }
+ else if(auto declGroup = decl.As<DeclGroup>())
+ {
+ // We are going to add the same modifiers to *all* of these declarations,
+ // so we want to give later passes a way to detect which modifiers
+ // were shared, vs. which ones are specific to a single declaration.
+
+ auto sharedModifiers = new SharedModifiers();
+ sharedModifiers->next = modifiers.first;
+ modifiers.first = sharedModifiers;
+
+ for( auto subDecl : declGroup->decls )
+ {
+ CompleteDecl(parser, subDecl, containerDecl, modifiers);
+ }
+ }
+ }
+ return decl;
+ }
+
+ static RefPtr<DeclBase> ParseDecl(
+ Parser* parser,
+ ContainerDecl* containerDecl)
+ {
+ Modifiers modifiers = ParseModifiers(parser);
+ return ParseDeclWithModifiers(parser, containerDecl, modifiers);
+ }
+
+ static RefPtr<Decl> ParseSingleDecl(
+ Parser* parser,
+ ContainerDecl* containerDecl)
+ {
+ auto declBase = ParseDecl(parser, containerDecl);
+ if(!declBase)
+ return nullptr;
+ if( auto decl = declBase.As<Decl>() )
+ {
+ return decl;
+ }
+ else if( auto declGroup = declBase.As<DeclGroup>() )
+ {
+ if( declGroup->decls.Count() == 1 )
+ {
+ return declGroup->decls[0];
+ }
+ }
+
+ parser->sink->diagnose(declBase->Position, Diagnostics::unimplemented, "didn't expect multiple declarations here");
+ return nullptr;
+ }
+
+
+ // Parse a body consisting of declarations
+ static void ParseDeclBody(
+ Parser* parser,
+ ContainerDecl* containerDecl,
+ TokenType closingToken)
+ {
+ while(!AdvanceIfMatch(parser, closingToken))
+ {
+ ParseDecl(parser, containerDecl);
+ TryRecover(parser);
+ }
+ }
+
+ void Parser::parseSourceFile(ProgramSyntaxNode* program)
+ {
+ if (outerScope)
+ {
+ currentScope = outerScope;
+ }
+
+ PushScope(program);
+ program->Position = CodePosition(0, 0, 0, fileName);
+ ParseDeclBody(this, program, TokenType::EndOfFile);
+ PopScope();
+
+ assert(currentScope == outerScope);
+ currentScope = nullptr;
+ }
+
+ RefPtr<ProgramSyntaxNode> Parser::ParseProgram()
+ {
+ RefPtr<ProgramSyntaxNode> program = new ProgramSyntaxNode();
+
+ parseSourceFile(program.Ptr());
+
+ return program;
+ }
+
+ RefPtr<StructSyntaxNode> Parser::ParseStruct()
+ {
+ RefPtr<StructSyntaxNode> rs = new StructSyntaxNode();
+ FillPosition(rs.Ptr());
+ ReadToken("struct");
+ rs->Name = ReadToken(TokenType::Identifier);
+ ReadToken(TokenType::LBrace);
+ ParseDeclBody(this, rs.Ptr(), TokenType::RBrace);
+
+ return rs;
+ }
+
+ RefPtr<ClassSyntaxNode> Parser::ParseClass()
+ {
+ RefPtr<ClassSyntaxNode> rs = new ClassSyntaxNode();
+ FillPosition(rs.Ptr());
+ ReadToken("class");
+ rs->Name = ReadToken(TokenType::Identifier);
+ ReadToken(TokenType::LBrace);
+ ParseDeclBody(this, rs.Ptr(), TokenType::RBrace);
+ return rs;
+ }
+
+ static RefPtr<StatementSyntaxNode> ParseSwitchStmt(Parser* parser)
+ {
+ RefPtr<SwitchStmt> stmt = new SwitchStmt();
+ parser->FillPosition(stmt.Ptr());
+ parser->ReadToken("switch");
+ parser->ReadToken(TokenType::LParent);
+ stmt->condition = parser->ParseExpression();
+ parser->ReadToken(TokenType::RParent);
+ stmt->body = parser->ParseBlockStatement();
+ return stmt;
+ }
+
+ static RefPtr<StatementSyntaxNode> ParseCaseStmt(Parser* parser)
+ {
+ RefPtr<CaseStmt> stmt = new CaseStmt();
+ parser->FillPosition(stmt.Ptr());
+ parser->ReadToken("case");
+ stmt->expr = parser->ParseExpression();
+ parser->ReadToken(TokenType::Colon);
+ return stmt;
+ }
+
+ static RefPtr<StatementSyntaxNode> ParseDefaultStmt(Parser* parser)
+ {
+ RefPtr<DefaultStmt> stmt = new DefaultStmt();
+ parser->FillPosition(stmt.Ptr());
+ parser->ReadToken("default");
+ parser->ReadToken(TokenType::Colon);
+ return stmt;
+ }
+
+ static bool peekTypeName(Parser* parser)
+ {
+ if(!parser->LookAheadToken(TokenType::Identifier))
+ return false;
+
+ auto name = parser->tokenReader.PeekToken().Content;
+
+ auto lookupResult = LookUp(name, parser->currentScope);
+ if(!lookupResult.isValid() || lookupResult.isOverloaded())
+ return false;
+
+ auto decl = lookupResult.item.declRef.GetDecl();
+ if( auto typeDecl = dynamic_cast<AggTypeDecl*>(decl) )
+ {
+ return true;
+ }
+ else if( auto typeVarDecl = dynamic_cast<SimpleTypeDecl*>(decl) )
+ {
+ return true;
+ }
+ else
+ {
+ return false;
+ }
+ }
+
+ RefPtr<StatementSyntaxNode> Parser::ParseStatement()
+ {
+ auto modifiers = ParseModifiers(this);
+
+ RefPtr<StatementSyntaxNode> statement;
+ if (LookAheadToken(TokenType::LBrace))
+ statement = ParseBlockStatement();
+ else if (peekTypeName(this))
+ statement = ParseVarDeclrStatement(modifiers);
+ else if (LookAheadToken("if"))
+ statement = ParseIfStatement();
+ else if (LookAheadToken("for"))
+ statement = ParseForStatement();
+ else if (LookAheadToken("while"))
+ statement = ParseWhileStatement();
+ else if (LookAheadToken("do"))
+ statement = ParseDoWhileStatement();
+ else if (LookAheadToken("break"))
+ statement = ParseBreakStatement();
+ else if (LookAheadToken("continue"))
+ statement = ParseContinueStatement();
+ else if (LookAheadToken("return"))
+ statement = ParseReturnStatement();
+ else if (LookAheadToken("discard"))
+ {
+ statement = new DiscardStatementSyntaxNode();
+ FillPosition(statement.Ptr());
+ ReadToken("discard");
+ ReadToken(TokenType::Semicolon);
+ }
+ else if (LookAheadToken("switch"))
+ statement = ParseSwitchStmt(this);
+ else if (LookAheadToken("case"))
+ statement = ParseCaseStmt(this);
+ else if (LookAheadToken("default"))
+ statement = ParseDefaultStmt(this);
+ else if (LookAheadToken(TokenType::Identifier))
+ {
+ // We might be looking at a local declaration, or an
+ // expression statement, and we need to figure out which.
+ //
+ // We'll solve this with backtracking for now.
+
+ Token* startPos = tokenReader.mCursor;
+
+ // Try to parse a type (knowing that the type grammar is
+ // a subset of the expression grammar, and so this should
+ // always succeed).
+ RefPtr<ExpressionSyntaxNode> type = ParseType();
+ // We don't actually care about the type, though, so
+ // don't retain it
+ type = nullptr;
+
+ // If the next token after we parsed a type looks like
+ // we are going to declare a variable, then lets guess
+ // that this is a declaration.
+ //
+ // TODO(tfoley): this wouldn't be robust for more
+ // general kinds of declarators (notably pointer declarators),
+ // so we'll need to be careful about this.
+ if (LookAheadToken(TokenType::Identifier))
+ {
+ // Reset the cursor and try to parse a declaration now.
+ // Note: the declaration will consume any modifiers
+ // that had been in place on the statement.
+ tokenReader.mCursor = startPos;
+ statement = ParseVarDeclrStatement(modifiers);
+ return statement;
+ }
+
+ // Fallback: reset and parse an expression
+ tokenReader.mCursor = startPos;
+ statement = ParseExpressionStatement();
+ }
+ else if (LookAheadToken(TokenType::Semicolon))
+ {
+ statement = new EmptyStatementSyntaxNode();
+ FillPosition(statement.Ptr());
+ ReadToken(TokenType::Semicolon);
+ }
+ else
+ {
+ // Default case should always fall back to parsing an expression,
+ // and then let that detect any errors
+ statement = ParseExpressionStatement();
+ }
+
+ if (statement)
+ {
+ // Install any modifiers onto the statement.
+ // Note: this path is bypassed in the case of a
+ // declaration statement, so we don't end up
+ // doubling up the modifiers.
+ statement->modifiers = modifiers;
+ }
+
+ return statement;
+ }
+
+ RefPtr<StatementSyntaxNode> Parser::ParseBlockStatement()
+ {
+ if( options.flags & SLANG_COMPILE_FLAG_NO_CHECKING )
+ {
+ // We have been asked to parse the input, but not attempt to understand it.
+
+ // TODO: record start/end locations...
+
+ List<Token> tokens;
+
+ ReadToken(TokenType::LBrace);
+
+ int depth = 1;
+ for( ;;)
+ {
+ switch( tokenReader.PeekTokenType() )
+ {
+ case TokenType::EndOfFile:
+ goto done;
+
+ case TokenType::RBrace:
+ depth--;
+ if(depth == 0)
+ goto done;
+ break;
+
+ case TokenType::LBrace:
+ depth++;
+ break;
+
+ default:
+ break;
+ }
+
+ auto token = tokenReader.AdvanceToken();
+ tokens.Add(token);
+ }
+ done:
+ ReadToken(TokenType::RBrace);
+
+ RefPtr<UnparsedStmt> unparsedStmt = new UnparsedStmt();
+ unparsedStmt->tokens = tokens;
+ return unparsedStmt;
+ }
+
+
+ RefPtr<ScopeDecl> scopeDecl = new ScopeDecl();
+ RefPtr<BlockStatementSyntaxNode> blockStatement = new BlockStatementSyntaxNode();
+ blockStatement->scopeDecl = scopeDecl;
+ PushScope(scopeDecl.Ptr());
+ ReadToken(TokenType::LBrace);
+ if(!tokenReader.IsAtEnd())
+ {
+ FillPosition(blockStatement.Ptr());
+ }
+ while (!AdvanceIfMatch(this, TokenType::RBrace))
+ {
+ auto stmt = ParseStatement();
+ if(stmt)
+ {
+ blockStatement->Statements.Add(stmt);
+ }
+ TryRecover(this);
+ }
+ PopScope();
+ return blockStatement;
+ }
+
+ RefPtr<VarDeclrStatementSyntaxNode> Parser::ParseVarDeclrStatement(
+ Modifiers modifiers)
+ {
+ RefPtr<VarDeclrStatementSyntaxNode>varDeclrStatement = new VarDeclrStatementSyntaxNode();
+
+ FillPosition(varDeclrStatement.Ptr());
+ auto decl = ParseDeclWithModifiers(this, currentScope->containerDecl, modifiers);
+ varDeclrStatement->decl = decl;
+ return varDeclrStatement;
+ }
+
+ RefPtr<IfStatementSyntaxNode> Parser::ParseIfStatement()
+ {
+ RefPtr<IfStatementSyntaxNode> ifStatement = new IfStatementSyntaxNode();
+ FillPosition(ifStatement.Ptr());
+ ReadToken("if");
+ ReadToken(TokenType::LParent);
+ ifStatement->Predicate = ParseExpression();
+ ReadToken(TokenType::RParent);
+ ifStatement->PositiveStatement = ParseStatement();
+ if (LookAheadToken("else"))
+ {
+ ReadToken("else");
+ ifStatement->NegativeStatement = ParseStatement();
+ }
+ return ifStatement;
+ }
+
+ RefPtr<ForStatementSyntaxNode> Parser::ParseForStatement()
+ {
+ RefPtr<ScopeDecl> scopeDecl = new ScopeDecl();
+ RefPtr<ForStatementSyntaxNode> stmt = new ForStatementSyntaxNode();
+ stmt->scopeDecl = scopeDecl;
+
+ // Note(tfoley): HLSL implements `for` with incorrect scoping.
+ // We need an option to turn on this behavior in a kind of "legacy" mode
+// PushScope(scopeDecl.Ptr());
+ FillPosition(stmt.Ptr());
+ ReadToken("for");
+ ReadToken(TokenType::LParent);
+ if (peekTypeName(this))
+ {
+ stmt->InitialStatement = ParseVarDeclrStatement(Modifiers());
+ }
+ else
+ {
+ if (!LookAheadToken(TokenType::Semicolon))
+ {
+ stmt->InitialStatement = ParseExpressionStatement();
+ }
+ else
+ {
+ ReadToken(TokenType::Semicolon);
+ }
+ }
+ if (!LookAheadToken(TokenType::Semicolon))
+ stmt->PredicateExpression = ParseExpression();
+ ReadToken(TokenType::Semicolon);
+ if (!LookAheadToken(TokenType::RParent))
+ stmt->SideEffectExpression = ParseExpression();
+ ReadToken(TokenType::RParent);
+ stmt->Statement = ParseStatement();
+// PopScope();
+ return stmt;
+ }
+
+ RefPtr<WhileStatementSyntaxNode> Parser::ParseWhileStatement()
+ {
+ RefPtr<WhileStatementSyntaxNode> whileStatement = new WhileStatementSyntaxNode();
+ FillPosition(whileStatement.Ptr());
+ ReadToken("while");
+ ReadToken(TokenType::LParent);
+ whileStatement->Predicate = ParseExpression();
+ ReadToken(TokenType::RParent);
+ whileStatement->Statement = ParseStatement();
+ return whileStatement;
+ }
+
+ RefPtr<DoWhileStatementSyntaxNode> Parser::ParseDoWhileStatement()
+ {
+ RefPtr<DoWhileStatementSyntaxNode> doWhileStatement = new DoWhileStatementSyntaxNode();
+ FillPosition(doWhileStatement.Ptr());
+ ReadToken("do");
+ doWhileStatement->Statement = ParseStatement();
+ ReadToken("while");
+ ReadToken(TokenType::LParent);
+ doWhileStatement->Predicate = ParseExpression();
+ ReadToken(TokenType::RParent);
+ ReadToken(TokenType::Semicolon);
+ return doWhileStatement;
+ }
+
+ RefPtr<BreakStatementSyntaxNode> Parser::ParseBreakStatement()
+ {
+ RefPtr<BreakStatementSyntaxNode> breakStatement = new BreakStatementSyntaxNode();
+ FillPosition(breakStatement.Ptr());
+ ReadToken("break");
+ ReadToken(TokenType::Semicolon);
+ return breakStatement;
+ }
+
+ RefPtr<ContinueStatementSyntaxNode> Parser::ParseContinueStatement()
+ {
+ RefPtr<ContinueStatementSyntaxNode> continueStatement = new ContinueStatementSyntaxNode();
+ FillPosition(continueStatement.Ptr());
+ ReadToken("continue");
+ ReadToken(TokenType::Semicolon);
+ return continueStatement;
+ }
+
+ RefPtr<ReturnStatementSyntaxNode> Parser::ParseReturnStatement()
+ {
+ RefPtr<ReturnStatementSyntaxNode> returnStatement = new ReturnStatementSyntaxNode();
+ FillPosition(returnStatement.Ptr());
+ ReadToken("return");
+ if (!LookAheadToken(TokenType::Semicolon))
+ returnStatement->Expression = ParseExpression();
+ ReadToken(TokenType::Semicolon);
+ return returnStatement;
+ }
+
+ RefPtr<ExpressionStatementSyntaxNode> Parser::ParseExpressionStatement()
+ {
+ RefPtr<ExpressionStatementSyntaxNode> statement = new ExpressionStatementSyntaxNode();
+
+ FillPosition(statement.Ptr());
+ statement->Expression = ParseExpression();
+
+ ReadToken(TokenType::Semicolon);
+ return statement;
+ }
+
+ RefPtr<ParameterSyntaxNode> Parser::ParseParameter()
+ {
+ RefPtr<ParameterSyntaxNode> parameter = new ParameterSyntaxNode();
+ parameter->modifiers = ParseModifiers(this);
+
+ DeclaratorInfo declaratorInfo;
+ declaratorInfo.typeSpec = ParseType();
+
+ InitDeclarator initDeclarator = ParseInitDeclarator(this);
+ UnwrapDeclarator(initDeclarator, &declaratorInfo);
+
+ // Assume it is a variable-like declarator
+ CompleteVarDecl(this, parameter, declaratorInfo);
+ return parameter;
+ }
+
+ RefPtr<ExpressionSyntaxNode> Parser::ParseType()
+ {
+ auto typeSpec = parseTypeSpec(this);
+ if( typeSpec.decl )
+ {
+ AddMember(currentScope, typeSpec.decl);
+ }
+ auto typeExpr = typeSpec.expr;
+
+ while (LookAheadToken(TokenType::LBracket))
+ {
+ RefPtr<IndexExpressionSyntaxNode> arrType = new IndexExpressionSyntaxNode();
+ arrType->Position = typeExpr->Position;
+ arrType->BaseExpression = typeExpr;
+ ReadToken(TokenType::LBracket);
+ if (!LookAheadToken(TokenType::RBracket))
+ {
+ arrType->IndexExpression = ParseExpression();
+ }
+ ReadToken(TokenType::RBracket);
+ typeExpr = arrType;
+ }
+
+ return typeExpr;
+ }
+
+
+
+ TypeExp Parser::ParseTypeExp()
+ {
+ return TypeExp(ParseType());
+ }
+
+ enum class Associativity
+ {
+ Left, Right
+ };
+
+
+
+ Associativity GetAssociativityFromLevel(Precedence level)
+ {
+ if (level == Precedence::Assignment)
+ return Associativity::Right;
+ else
+ return Associativity::Left;
+ }
+
+
+
+
+ Precedence GetOpLevel(Parser* parser, TokenType type)
+ {
+ switch(type)
+ {
+ case TokenType::Comma:
+ return Precedence::Comma;
+ case TokenType::OpAssign:
+ case TokenType::OpMulAssign:
+ case TokenType::OpDivAssign:
+ case TokenType::OpAddAssign:
+ case TokenType::OpSubAssign:
+ case TokenType::OpModAssign:
+ case TokenType::OpShlAssign:
+ case TokenType::OpShrAssign:
+ case TokenType::OpOrAssign:
+ case TokenType::OpAndAssign:
+ case TokenType::OpXorAssign:
+ return Precedence::Assignment;
+ case TokenType::OpOr:
+ return Precedence::LogicalOr;
+ case TokenType::OpAnd:
+ return Precedence::LogicalAnd;
+ case TokenType::OpBitOr:
+ return Precedence::BitOr;
+ case TokenType::OpBitXor:
+ return Precedence::BitXor;
+ case TokenType::OpBitAnd:
+ return Precedence::BitAnd;
+ case TokenType::OpEql:
+ case TokenType::OpNeq:
+ return Precedence::EqualityComparison;
+ case TokenType::OpGreater:
+ case TokenType::OpGeq:
+ // Don't allow these ops inside a generic argument
+ if (parser->genericDepth > 0) return Precedence::Invalid;
+ case TokenType::OpLeq:
+ case TokenType::OpLess:
+ return Precedence::RelationalComparison;
+ case TokenType::OpRsh:
+ // Don't allow this op inside a generic argument
+ if (parser->genericDepth > 0) return Precedence::Invalid;
+ case TokenType::OpLsh:
+ return Precedence::BitShift;
+ case TokenType::OpAdd:
+ case TokenType::OpSub:
+ return Precedence::Additive;
+ case TokenType::OpMul:
+ case TokenType::OpDiv:
+ case TokenType::OpMod:
+ return Precedence::Multiplicative;
+ default:
+ return Precedence::Invalid;
+ }
+ }
+
+ Operator GetOpFromToken(Token & token)
+ {
+ switch(token.Type)
+ {
+ case TokenType::Comma:
+ return Operator::Sequence;
+ case TokenType::OpAssign:
+ return Operator::Assign;
+ case TokenType::OpAddAssign:
+ return Operator::AddAssign;
+ case TokenType::OpSubAssign:
+ return Operator::SubAssign;
+ case TokenType::OpMulAssign:
+ return Operator::MulAssign;
+ case TokenType::OpDivAssign:
+ return Operator::DivAssign;
+ case TokenType::OpModAssign:
+ return Operator::ModAssign;
+ case TokenType::OpShlAssign:
+ return Operator::LshAssign;
+ case TokenType::OpShrAssign:
+ return Operator::RshAssign;
+ case TokenType::OpOrAssign:
+ return Operator::OrAssign;
+ case TokenType::OpAndAssign:
+ return Operator::AddAssign;
+ case TokenType::OpXorAssign:
+ return Operator::XorAssign;
+ case TokenType::OpOr:
+ return Operator::Or;
+ case TokenType::OpAnd:
+ return Operator::And;
+ case TokenType::OpBitOr:
+ return Operator::BitOr;
+ case TokenType::OpBitXor:
+ return Operator::BitXor;
+ case TokenType::OpBitAnd:
+ return Operator::BitAnd;
+ case TokenType::OpEql:
+ return Operator::Eql;
+ case TokenType::OpNeq:
+ return Operator::Neq;
+ case TokenType::OpGeq:
+ return Operator::Geq;
+ case TokenType::OpLeq:
+ return Operator::Leq;
+ case TokenType::OpGreater:
+ return Operator::Greater;
+ case TokenType::OpLess:
+ return Operator::Less;
+ case TokenType::OpLsh:
+ return Operator::Lsh;
+ case TokenType::OpRsh:
+ return Operator::Rsh;
+ case TokenType::OpAdd:
+ return Operator::Add;
+ case TokenType::OpSub:
+ return Operator::Sub;
+ case TokenType::OpMul:
+ return Operator::Mul;
+ case TokenType::OpDiv:
+ return Operator::Div;
+ case TokenType::OpMod:
+ return Operator::Mod;
+ case TokenType::OpInc:
+ return Operator::PostInc;
+ case TokenType::OpDec:
+ return Operator::PostDec;
+ case TokenType::OpNot:
+ return Operator::Not;
+ case TokenType::OpBitNot:
+ return Operator::BitNot;
+ default:
+ throw "Illegal TokenType.";
+ }
+ }
+
+ static RefPtr<ExpressionSyntaxNode> parseOperator(Parser* parser)
+ {
+ Token opToken;
+ switch(parser->tokenReader.PeekTokenType())
+ {
+ case TokenType::QuestionMark:
+ opToken = parser->ReadToken();
+ opToken.Content = "?:";
+ break;
+
+ default:
+ opToken = parser->ReadToken();
+ break;
+ }
+
+ auto opExpr = new VarExpressionSyntaxNode();
+ opExpr->Variable = opToken.Content;
+ opExpr->scope = parser->currentScope;
+ opExpr->Position = opToken.Position;
+
+ return opExpr;
+
+ }
+
+ RefPtr<ExpressionSyntaxNode> Parser::ParseExpression(Precedence level)
+ {
+ if (level == Precedence::Prefix)
+ return ParseLeafExpression();
+ if (level == Precedence::TernaryConditional)
+ {
+ // parse select clause
+ auto condition = ParseExpression(Precedence(level + 1));
+ if (LookAheadToken(TokenType::QuestionMark))
+ {
+ RefPtr<SelectExpressionSyntaxNode> select = new SelectExpressionSyntaxNode();
+ FillPosition(select.Ptr());
+
+ select->Arguments.Add(condition);
+
+ select->FunctionExpr = parseOperator(this);
+
+ select->Arguments.Add(ParseExpression(level));
+ ReadToken(TokenType::Colon);
+ select->Arguments.Add(ParseExpression(level));
+ return select;
+ }
+ else
+ return condition;
+ }
+ else
+ {
+ if (GetAssociativityFromLevel(level) == Associativity::Left)
+ {
+ auto left = ParseExpression(Precedence(level + 1));
+ while (GetOpLevel(this, tokenReader.PeekTokenType()) == level)
+ {
+ RefPtr<OperatorExpressionSyntaxNode> tmp = new InfixExpr();
+ tmp->FunctionExpr = parseOperator(this);
+
+ tmp->Arguments.Add(left);
+ FillPosition(tmp.Ptr());
+ tmp->Arguments.Add(ParseExpression(Precedence(level + 1)));
+ left = tmp;
+ }
+ return left;
+ }
+ else
+ {
+ auto left = ParseExpression(Precedence(level + 1));
+ if (GetOpLevel(this, tokenReader.PeekTokenType()) == level)
+ {
+ RefPtr<OperatorExpressionSyntaxNode> tmp = new InfixExpr();
+ tmp->Arguments.Add(left);
+ FillPosition(tmp.Ptr());
+ tmp->FunctionExpr = parseOperator(this);
+ tmp->Arguments.Add(ParseExpression(level));
+ left = tmp;
+ }
+ return left;
+ }
+ }
+ }
+
+ RefPtr<ExpressionSyntaxNode> Parser::ParseLeafExpression()
+ {
+ RefPtr<ExpressionSyntaxNode> rs;
+ if (LookAheadToken(TokenType::OpInc) ||
+ LookAheadToken(TokenType::OpDec) ||
+ LookAheadToken(TokenType::OpNot) ||
+ LookAheadToken(TokenType::OpBitNot) ||
+ LookAheadToken(TokenType::OpSub))
+ {
+ RefPtr<OperatorExpressionSyntaxNode> unaryExpr = new PrefixExpr();
+ FillPosition(unaryExpr.Ptr());
+ unaryExpr->FunctionExpr = parseOperator(this);
+ unaryExpr->Arguments.Add(ParseLeafExpression());
+ rs = unaryExpr;
+ return rs;
+ }
+
+ if (LookAheadToken(TokenType::LParent))
+ {
+ ReadToken(TokenType::LParent);
+ RefPtr<ExpressionSyntaxNode> expr;
+ if (peekTypeName(this) && LookAheadToken(TokenType::RParent, 1))
+ {
+ RefPtr<TypeCastExpressionSyntaxNode> tcexpr = new TypeCastExpressionSyntaxNode();
+ FillPosition(tcexpr.Ptr());
+ tcexpr->TargetType = ParseTypeExp();
+ ReadToken(TokenType::RParent);
+ tcexpr->Expression = ParseExpression(Precedence::Multiplicative); // Note(tfoley): need to double-check this
+ expr = tcexpr;
+ }
+ else
+ {
+ expr = ParseExpression();
+ ReadToken(TokenType::RParent);
+ }
+ rs = expr;
+ }
+ else if( LookAheadToken(TokenType::LBrace) )
+ {
+ RefPtr<InitializerListExpr> initExpr = new InitializerListExpr();
+ FillPosition(initExpr.Ptr());
+
+ // Initializer list
+ ReadToken(TokenType::LBrace);
+
+ List<RefPtr<ExpressionSyntaxNode>> exprs;
+
+ for(;;)
+ {
+ if(AdvanceIfMatch(this, TokenType::RBrace))
+ break;
+
+ auto expr = ParseArgExpr();
+ if( expr )
+ {
+ initExpr->args.Add(expr);
+ }
+
+ if(AdvanceIfMatch(this, TokenType::RBrace))
+ break;
+
+ ReadToken(TokenType::Comma);
+ }
+ rs = initExpr;
+ }
+
+ else if (LookAheadToken(TokenType::IntLiterial) ||
+ LookAheadToken(TokenType::DoubleLiterial))
+ {
+ RefPtr<ConstantExpressionSyntaxNode> constExpr = new ConstantExpressionSyntaxNode();
+ auto token = tokenReader.AdvanceToken();
+ FillPosition(constExpr.Ptr());
+ if (token.Type == TokenType::IntLiterial)
+ {
+ constExpr->ConstType = ConstantExpressionSyntaxNode::ConstantType::Int;
+ constExpr->IntValue = StringToInt(token.Content);
+ }
+ else if (token.Type == TokenType::DoubleLiterial)
+ {
+ constExpr->ConstType = ConstantExpressionSyntaxNode::ConstantType::Float;
+ constExpr->FloatValue = (FloatingPointLiteralValue) StringToDouble(token.Content);
+ }
+ rs = constExpr;
+ }
+ else if (LookAheadToken("true") || LookAheadToken("false"))
+ {
+ RefPtr<ConstantExpressionSyntaxNode> constExpr = new ConstantExpressionSyntaxNode();
+ auto token = tokenReader.AdvanceToken();
+ FillPosition(constExpr.Ptr());
+ constExpr->ConstType = ConstantExpressionSyntaxNode::ConstantType::Bool;
+ constExpr->IntValue = token.Content == "true" ? 1 : 0;
+ rs = constExpr;
+ }
+ else if (LookAheadToken(TokenType::Identifier))
+ {
+ RefPtr<VarExpressionSyntaxNode> varExpr = new VarExpressionSyntaxNode();
+ varExpr->scope = currentScope.Ptr();
+ FillPosition(varExpr.Ptr());
+ auto token = ReadToken(TokenType::Identifier);
+ varExpr->Variable = token.Content;
+ rs = varExpr;
+ }
+
+ while (!tokenReader.IsAtEnd() &&
+ (LookAheadToken(TokenType::OpInc) ||
+ LookAheadToken(TokenType::OpDec) ||
+ LookAheadToken(TokenType::Dot) ||
+ LookAheadToken(TokenType::LBracket) ||
+ LookAheadToken(TokenType::LParent)))
+ {
+ if (LookAheadToken(TokenType::OpInc))
+ {
+ RefPtr<OperatorExpressionSyntaxNode> unaryExpr = new PostfixExpr();
+ FillPosition(unaryExpr.Ptr());
+ unaryExpr->FunctionExpr = parseOperator(this);
+ unaryExpr->Arguments.Add(rs);
+ rs = unaryExpr;
+ }
+ else if (LookAheadToken(TokenType::OpDec))
+ {
+ RefPtr<OperatorExpressionSyntaxNode> unaryExpr = new PostfixExpr();
+ FillPosition(unaryExpr.Ptr());
+ unaryExpr->FunctionExpr = parseOperator(this);
+ unaryExpr->Arguments.Add(rs);
+ rs = unaryExpr;
+ }
+ else if (LookAheadToken(TokenType::LBracket))
+ {
+ RefPtr<IndexExpressionSyntaxNode> indexExpr = new IndexExpressionSyntaxNode();
+ indexExpr->BaseExpression = rs;
+ FillPosition(indexExpr.Ptr());
+ ReadToken(TokenType::LBracket);
+ indexExpr->IndexExpression = ParseExpression();
+ ReadToken(TokenType::RBracket);
+ rs = indexExpr;
+ }
+ else if (LookAheadToken(TokenType::LParent))
+ {
+ RefPtr<InvokeExpressionSyntaxNode> invokeExpr = new InvokeExpressionSyntaxNode();
+ invokeExpr->FunctionExpr = rs;
+ FillPosition(invokeExpr.Ptr());
+ ReadToken(TokenType::LParent);
+ while (!tokenReader.IsAtEnd())
+ {
+ if (!LookAheadToken(TokenType::RParent))
+ invokeExpr->Arguments.Add(ParseArgExpr());
+ else
+ {
+ break;
+ }
+ if (!LookAheadToken(TokenType::Comma))
+ break;
+ ReadToken(TokenType::Comma);
+ }
+ ReadToken(TokenType::RParent);
+ rs = invokeExpr;
+ }
+ else if (LookAheadToken(TokenType::Dot))
+ {
+ RefPtr<MemberExpressionSyntaxNode> memberExpr = new MemberExpressionSyntaxNode();
+ memberExpr->scope = currentScope.Ptr();
+ FillPosition(memberExpr.Ptr());
+ memberExpr->BaseExpression = rs;
+ ReadToken(TokenType::Dot);
+ memberExpr->MemberName = ReadToken(TokenType::Identifier).Content;
+ rs = memberExpr;
+ }
+ }
+ if (!rs)
+ {
+ sink->diagnose(tokenReader.PeekLoc(), Diagnostics::syntaxError);
+ }
+ return rs;
+ }
+
+ // Parse a source file into an existing translation unit
+ void parseSourceFile(
+ ProgramSyntaxNode* translationUnitSyntax,
+ CompileOptions& options,
+ TokenSpan const& tokens,
+ DiagnosticSink* sink,
+ String const& fileName,
+ RefPtr<Scope> const&outerScope)
+ {
+ Parser parser(options, tokens, sink, fileName, outerScope);
+ return parser.parseSourceFile(translationUnitSyntax);
+ }
+ }
+} \ No newline at end of file
diff --git a/source/slang/parser.h b/source/slang/parser.h
new file mode 100644
index 000000000..90af69158
--- /dev/null
+++ b/source/slang/parser.h
@@ -0,0 +1,23 @@
+#ifndef RASTER_RENDERER_PARSER_H
+#define RASTER_RENDERER_PARSER_H
+
+#include "lexer.h"
+#include "compiler.h"
+#include "syntax.h"
+
+namespace Slang
+{
+ namespace Compiler
+ {
+ // Parse a source file into an existing translation unit
+ void parseSourceFile(
+ ProgramSyntaxNode* translationUnitSyntax,
+ CompileOptions& options,
+ TokenSpan const& tokens,
+ DiagnosticSink* sink,
+ String const& fileName,
+ RefPtr<Scope> const&outerScope);
+ }
+}
+
+#endif \ No newline at end of file
diff --git a/source/slang/preprocessor.cpp b/source/slang/preprocessor.cpp
new file mode 100644
index 000000000..cdde2591d
--- /dev/null
+++ b/source/slang/preprocessor.cpp
@@ -0,0 +1,2032 @@
+// Preprocessor.cpp
+#include "Preprocessor.h"
+
+#include "Diagnostics.h"
+#include "Lexer.h"
+
+// Needed so that we can construct modifier syntax
+// to represent GLSL directives
+#include "Syntax.h"
+
+#include <assert.h>
+
+using namespace CoreLib;
+
+// This file provides an implementation of a simple C-style preprocessor.
+// It does not aim for 100% compatibility with any particular preprocessor
+// specification, but the goal is to have it accept the most common
+// idioms for using the preprocessor, found in shader code in the wild.
+
+
+namespace Slang{ namespace Compiler {
+
+// State of a preprocessor conditional, which can change when
+// we encounter directives like `#elif` or `#endif`
+enum class PreprocessorConditionalState
+{
+ Before, // We have not yet seen a branch with a `true` condition.
+ During, // We are inside the branch with a `true` condition.
+ After, // We have already seen the branch with a `true` condition.
+};
+
+// Represents a preprocessor conditional that we are currently
+// nested inside.
+struct PreprocessorConditional
+{
+ // The next outer conditional in the current file/stream, or NULL.
+ PreprocessorConditional* parent;
+
+ // The directive token that started the conditional (an `#if` or `#ifdef`)
+ Token ifToken;
+
+ // The `#else` directive token, if one has been seen (otherwise `TokenType::Unknown`)
+ Token elseToken;
+
+ // The state of the conditional
+ PreprocessorConditionalState state;
+};
+
+struct PreprocessorMacro;
+
+struct PreprocessorEnvironment
+{
+ // The "outer" environment, to be used if lookup in this env fails
+ PreprocessorEnvironment* parent = NULL;
+
+ // Macros defined in this environment
+ Dictionary<String, PreprocessorMacro*> macros;
+
+ ~PreprocessorEnvironment();
+};
+
+// Input tokens can either come from source text, or from macro expansion.
+// In general, input streams can be nested, so we have to keep a conceptual
+// stack of input.
+
+// A stream of input tokens to be consumed
+struct PreprocessorInputStream
+{
+ // The next input stream up the stack, if any.
+ PreprocessorInputStream* parent;
+
+ // The deepest preprocessor conditional active for this stream.
+ PreprocessorConditional* conditional;
+
+ // Environment to use when looking up macros
+ PreprocessorEnvironment* environment;
+
+ // Reader for pre-tokenized input
+ TokenReader tokenReader;
+
+ // If we are clobbering source locations with `#line`, then
+ // the state is tracked here:
+
+ // Are we overriding source locations?
+ bool isOverridingSourceLoc;
+
+ // What is the file name we are overriding to?
+ String overrideFileName;
+
+ // What is the relative offset to apply to any line numbers?
+ int overrideLineOffset;
+
+ // Destructor is virtual so that we can clean up
+ // after concrete subtypes.
+ virtual ~PreprocessorInputStream() = default;
+};
+
+struct SourceTextInputStream : PreprocessorInputStream
+{
+ // The pre-tokenized input
+ TokenList lexedTokens;
+};
+
+struct MacroExpansion : PreprocessorInputStream
+{
+ // The macro we will expand
+ PreprocessorMacro* macro;
+};
+
+struct ObjectLikeMacroExpansion : MacroExpansion
+{
+};
+
+struct FunctionLikeMacroExpansion : MacroExpansion
+{
+ // Environment for macro arguments
+ PreprocessorEnvironment argumentEnvironment;
+};
+
+// An enumeration for the diferent types of macros
+enum class PreprocessorMacroFlavor
+{
+ ObjectLike,
+ FunctionArg,
+ FunctionLike,
+};
+
+// In the current design (which we may want to re-consider),
+// a macro is a specialized flavor of input stream, that
+// captures the token list in its expansion, and then
+// can be "played back."
+struct PreprocessorMacro
+{
+ // The name under which the macro was `#define`d
+ Token nameToken;
+
+ // Parameters of the macro, in case of a function-like macro
+ List<Token> params;
+
+ // The tokens that make up the macro body
+ TokenList tokens;
+
+ // The flavor of macro
+ PreprocessorMacroFlavor flavor;
+
+ // The environment in which this macro needs to be expanded.
+ // For ordinary macros this will be the global environment,
+ // while for function-like macro arguments, it will be
+ // the environment of the macro invocation.
+ PreprocessorEnvironment* environment;
+};
+
+// State of the preprocessor
+struct Preprocessor
+{
+ // diagnostics sink to use when writing messages
+ DiagnosticSink* sink;
+
+ // An external callback interface to use when looking
+ // for files in a `#include` directive
+ IncludeHandler* includeHandler;
+
+ // Current input stream (top of the stack of input)
+ PreprocessorInputStream* inputStream;
+
+ // Currently-defined macros
+ PreprocessorEnvironment globalEnv;
+
+ // A pre-allocated token that can be returned to
+ // represent end-of-input situations.
+ Token endOfFileToken;
+
+ // Syntax for the program we are trying to parse
+ ProgramSyntaxNode* syntax;
+};
+
+// Convenience routine to access the diagnostic sink
+static DiagnosticSink* GetSink(Preprocessor* preprocessor)
+{
+ return preprocessor->sink;
+}
+
+//
+// Forward declarations
+//
+
+static void DestroyConditional(PreprocessorConditional* conditional);
+static void DestroyMacro(Preprocessor* preprocessor, PreprocessorMacro* macro);
+
+//
+// Basic Input Handling
+//
+
+// Create a fresh input stream
+static void InitializeInputStream(Preprocessor* preprocessor, PreprocessorInputStream* inputStream)
+{
+ inputStream->parent = NULL;
+ inputStream->conditional = NULL;
+ inputStream->environment = &preprocessor->globalEnv;
+}
+
+// Destroy an input stream
+static void DestroyInputStream(Preprocessor* /*preprocessor*/, PreprocessorInputStream* inputStream)
+{
+ delete inputStream;
+}
+
+// Create an input stream to represent a pre-tokenized input file.
+// TODO(tfoley): pre-tokenizing files isn't going to work in the long run.
+static PreprocessorInputStream* CreateInputStreamForSource(Preprocessor* preprocessor, CoreLib::String const& source, CoreLib::String const& fileName)
+{
+ SourceTextInputStream* inputStream = new SourceTextInputStream();
+ InitializeInputStream(preprocessor, inputStream);
+
+ // Use existing `Lexer` to generate a token stream.
+ Lexer lexer(fileName, source, GetSink(preprocessor));
+ inputStream->lexedTokens = lexer.lexAllTokens();
+ inputStream->tokenReader = TokenReader(inputStream->lexedTokens);
+
+ return inputStream;
+}
+
+
+
+static void PushInputStream(Preprocessor* preprocessor, PreprocessorInputStream* inputStream)
+{
+ inputStream->parent = preprocessor->inputStream;
+ preprocessor->inputStream = inputStream;
+}
+
+// Called when we reach the end of an input stream.
+// Performs some validation and then destroys the input stream if required.
+static void EndInputStream(Preprocessor* preprocessor, PreprocessorInputStream* inputStream)
+{
+ // If there are any conditionals that weren't completed, then it is an error
+ if (inputStream->conditional)
+ {
+ PreprocessorConditional* conditional = inputStream->conditional;
+
+ GetSink(preprocessor)->diagnose(conditional->ifToken.Position, Diagnostics::endOfFileInPreprocessorConditional);
+
+ while (conditional)
+ {
+ PreprocessorConditional* parent = conditional->parent;
+ DestroyConditional(conditional);
+ conditional = parent;
+ }
+ }
+
+ DestroyInputStream(preprocessor, inputStream);
+}
+
+// Potentially clobber source location information based on `#line`
+static Token PossiblyOverrideSourceLoc(PreprocessorInputStream* inputStream, Token const& token)
+{
+ Token result = token;
+ if( inputStream->isOverridingSourceLoc )
+ {
+ result.Position.FileName = inputStream->overrideFileName;
+ result.Position.Line += inputStream->overrideLineOffset;
+ }
+ return result;
+}
+
+// Consume one token from an input stream
+static Token AdvanceRawToken(PreprocessorInputStream* inputStream)
+{
+ return PossiblyOverrideSourceLoc(inputStream, inputStream->tokenReader.AdvanceToken());
+}
+
+// Peek one token from an input stream
+static Token PeekRawToken(PreprocessorInputStream* inputStream)
+{
+ return PossiblyOverrideSourceLoc(inputStream, inputStream->tokenReader.PeekToken());
+}
+
+// Peek one token type from an input stream
+static TokenType PeekRawTokenType(PreprocessorInputStream* inputStream)
+{
+ return inputStream->tokenReader.PeekTokenType();
+}
+
+
+// Read one token in "raw" mode (meaning don't expand macros)
+static Token AdvanceRawToken(Preprocessor* preprocessor)
+{
+ for (;;)
+ {
+ // Look at the input stream on top of the stack
+ PreprocessorInputStream* inputStream = preprocessor->inputStream;
+
+ // If there isn't one, then there is no more input left to read.
+ if (!inputStream)
+ {
+ return preprocessor->endOfFileToken;
+ }
+
+ // The top-most input stream may be at its end
+ if (PeekRawTokenType(inputStream) == TokenType::EndOfFile)
+ {
+ // If there is another stream remaining, switch to it
+ if (inputStream->parent)
+ {
+ preprocessor->inputStream = inputStream->parent;
+ EndInputStream(preprocessor, inputStream);
+ continue;
+ }
+ else
+ {
+ // HACK(tfoley): A place to fall into debugger...
+ int f = 0;
+ }
+ }
+
+ // Everything worked, so read a token from the top-most stream
+ return AdvanceRawToken(inputStream);
+ }
+}
+
+// Return the next token in "raw" mode, but don't advance the
+// current token state.
+static Token PeekRawToken(Preprocessor* preprocessor)
+{
+ // We need to find the strema that `advanceRawToken` would read from.
+ PreprocessorInputStream* inputStream = preprocessor->inputStream;
+ for (;;)
+ {
+ if (!inputStream)
+ {
+ // No more input streams left to read
+ return preprocessor->endOfFileToken;
+ }
+
+ // The top-most input stream may be at its end, so
+ // look one entry up the stack (don't actually pop
+ // here, since we are just peeking)
+ if (PeekRawTokenType(inputStream) == TokenType::EndOfFile)
+ {
+ if (inputStream->parent)
+ {
+ inputStream = inputStream->parent;
+ continue;
+ }
+ else
+ {
+ // HACK(tfoley): A place to fall into debugger...
+ int f = 0;
+ }
+ }
+
+ // Everything worked, so the token we just peeked is fine.
+ return PeekRawToken(inputStream);
+ }
+}
+
+// Without advancing preprocessor state, look *two* raw tokens ahead
+// (This is only needed in order to determine when we are possibly
+// expanding a function-style macro)
+TokenType PeekSecondRawTokenType(Preprocessor* preprocessor)
+{
+ // We need to find the strema that `advanceRawToken` would read from.
+ PreprocessorInputStream* inputStream = preprocessor->inputStream;
+ int count = 1;
+ for (;;)
+ {
+ if (!inputStream)
+ {
+ // No more input streams left to read
+ return TokenType::EndOfFile;
+ }
+
+ // The top-most input stream may be at its end, so
+ // look one entry up the stack (don't actually pop
+ // here, since we are just peeking)
+
+ TokenReader reader = inputStream->tokenReader;
+ if (reader.PeekTokenType() == TokenType::EndOfFile)
+ {
+ inputStream = inputStream->parent;
+ continue;
+ }
+
+ if (count)
+ {
+ count--;
+
+ // Note: we are advancing our temporary
+ // copy of the token reader
+ reader.AdvanceToken();
+ if (reader.PeekTokenType() == TokenType::EndOfFile)
+ {
+ inputStream = inputStream->parent;
+ continue;
+ }
+ }
+
+ // Everything worked, so peek a token from the top-most stream
+ return reader.PeekTokenType();
+ }
+}
+
+
+// Get the location of the current (raw) token
+static CodePosition PeekLoc(Preprocessor* preprocessor)
+{
+ return PeekRawToken(preprocessor).Position;
+}
+
+// Get the `TokenType` of the current (raw) token
+static TokenType PeekRawTokenType(Preprocessor* preprocessor)
+{
+ return PeekRawToken(preprocessor).Type;
+}
+
+//
+// Macros
+//
+
+// Create a macro
+static PreprocessorMacro* CreateMacro(Preprocessor* preprocessor)
+{
+ // TODO(tfoley): Allocate these more intelligently.
+ // For example, consider pooling them on the preprocessor.
+
+ PreprocessorMacro* macro = new PreprocessorMacro();
+ macro->flavor = PreprocessorMacroFlavor::ObjectLike;
+ macro->environment = &preprocessor->globalEnv;
+ return macro;
+}
+
+// Destroy a macro
+static void DestroyMacro(Preprocessor* /*preprocessor*/, PreprocessorMacro* macro)
+{
+ delete macro;
+}
+
+
+// Find the currently-defined macro of the given name, or return NULL
+static PreprocessorMacro* LookupMacro(PreprocessorEnvironment* environment, String const& name)
+{
+ for(PreprocessorEnvironment* e = environment; e; e = e->parent)
+ {
+ PreprocessorMacro* macro = NULL;
+ if (e->macros.TryGetValue(name, macro))
+ return macro;
+ }
+
+ return NULL;
+}
+
+static PreprocessorEnvironment* GetCurrentEnvironment(Preprocessor* preprocessor)
+{
+ PreprocessorInputStream* inputStream = preprocessor->inputStream;
+ return inputStream ? inputStream->environment : &preprocessor->globalEnv;
+}
+
+static PreprocessorMacro* LookupMacro(Preprocessor* preprocessor, String const& name)
+{
+ return LookupMacro(GetCurrentEnvironment(preprocessor), name);
+}
+
+// A macro is "busy" if it is currently being used for expansion.
+// A macro cannot be expanded again while busy, to avoid infinite recursion.
+static bool IsMacroBusy(PreprocessorMacro* /*macro*/)
+{
+ // TODO: need to implement this correctly
+ //
+ // The challenge here is that we are implementing expansion
+ // for argumenst to function-like macros in a "lazy" fashion.
+ //
+ // The letter of the spec is that we should macro expand
+ // each argument *before* substitution, and then go and
+ // macro-expand the substituted body. This means that we
+ // can invoke a macro as part of an argument to an
+ // invocation of the same macro:
+ //
+ // FOO( 1, FOO(22), 333 );
+ //
+ // In our implementation, the "inner" invocation of `FOO`
+ // gets expanded at the point where it gets referenced
+ // in the body of the "outer" invocation of `FOO`.
+ // Doing things this way leads to greatly simplified
+ // code for handling expansion.
+ //
+ // A proper implementation of `IsMacroBusy` needs to
+ // take context into account, so that it bans recursive
+ // use of a macro when it occurs (indirectly) through
+ // the *body* of the expansion, but not when it occcurs
+ // only through an *argument*.
+ return false;
+}
+
+//
+// Reading Tokens With Expansion
+//
+
+static void InitializeMacroExpansion(
+ Preprocessor* preprocessor,
+ MacroExpansion* expansion,
+ PreprocessorMacro* macro)
+{
+ InitializeInputStream(preprocessor, expansion);
+ expansion->environment = macro->environment;
+ expansion->macro = macro;
+ expansion->tokenReader = TokenReader(macro->tokens);
+}
+
+static void PushMacroExpansion(
+ Preprocessor* preprocessor,
+ MacroExpansion* expansion)
+{
+ PushInputStream(preprocessor, expansion);
+}
+
+static void AddEndOfStreamToken(
+ Preprocessor* preprocessor,
+ PreprocessorMacro* macro)
+{
+ Token token = PeekRawToken(preprocessor);
+ token.Type = TokenType::EndOfFile;
+ macro->tokens.mTokens.Add(token);
+}
+
+// Check whether the current token on the given input stream should be
+// treated as a macro invocation, and if so set up state for expanding
+// that macro.
+static void MaybeBeginMacroExpansion(
+ Preprocessor* preprocessor )
+{
+ // We iterate because the first token in the expansion of one
+ // macro may be another macro invocation.
+ for (;;)
+ {
+ // Look at the next token ahead of us
+ Token const& token = PeekRawToken(preprocessor);
+
+ // Not an identifier? Can't be a macro.
+ if (token.Type != TokenType::Identifier)
+ return;
+
+ // Look for a macro with the given name.
+ String name = token.Content;
+ PreprocessorMacro* macro = LookupMacro(preprocessor, name);
+
+ // Not a macro? Can't be an invocation.
+ if (!macro)
+ return;
+
+ // If the macro is busy (already being expanded),
+ // don't try to trigger recursive expansion
+ if (IsMacroBusy(macro))
+ return;
+
+ // A function-style macro invocation should only match
+ // if the token *after* the identifier is `(`. This
+ // requires more lookahead than we usually have/need
+ if (macro->flavor == PreprocessorMacroFlavor::FunctionLike)
+ {
+ if(PeekSecondRawTokenType(preprocessor) != TokenType::LParent)
+ return;
+
+ // Consume the token that triggered macro expansion
+ AdvanceRawToken(preprocessor);
+
+ // Consume the opening `(`
+ Token leftParen = AdvanceRawToken(preprocessor);
+
+ FunctionLikeMacroExpansion* expansion = new FunctionLikeMacroExpansion();
+ InitializeMacroExpansion(preprocessor, expansion, macro);
+ expansion->argumentEnvironment.parent = &preprocessor->globalEnv;
+ expansion->environment = &expansion->argumentEnvironment;
+
+ // Try to read any arguments present.
+ int paramCount = macro->params.Count();
+ int argIndex = 0;
+
+ switch (PeekRawTokenType(preprocessor))
+ {
+ case TokenType::EndOfFile:
+ case TokenType::RParent:
+ // No arguments.
+ break;
+
+ default:
+ // At least one argument
+ while(argIndex < paramCount)
+ {
+ // Read an argument
+
+ // Create the argument, represented as a special flavor of macro
+ PreprocessorMacro* arg = CreateMacro(preprocessor);
+ arg->flavor = PreprocessorMacroFlavor::FunctionArg;
+ arg->environment = GetCurrentEnvironment(preprocessor);
+
+ // Associate the new macro with its parameter name
+ Token paramToken = macro->params[argIndex];
+ String const& paramName = paramToken.Content;
+ arg->nameToken = paramToken;
+ expansion->argumentEnvironment.macros[paramName] = arg;
+ argIndex++;
+
+ // Read tokens for the argument
+
+ // We track the nesting depth, since we don't break
+ // arguments on a `,` nested in balanced parentheses
+ //
+ int nesting = 0;
+ for (;;)
+ {
+ switch (PeekRawTokenType(preprocessor))
+ {
+ case TokenType::EndOfFile:
+ // if we reach the end of the file,
+ // then we have an error, and need to
+ // bail out
+ AddEndOfStreamToken(preprocessor, arg);
+ goto doneWithAllArguments;
+
+ case TokenType::RParent:
+ // If we see a right paren when we aren't nested
+ // then we are at the end of an argument
+ if (nesting == 0)
+ {
+ AddEndOfStreamToken(preprocessor, arg);
+ goto doneWithAllArguments;
+ }
+ // Otherwise we decrease our nesting depth, add
+ // the token, and keep going
+ nesting--;
+ break;
+
+ case TokenType::Comma:
+ // If we see a comma when we aren't nested
+ // then we are at the end of an argument
+ if (nesting == 0)
+ {
+ AddEndOfStreamToken(preprocessor, arg);
+ AdvanceRawToken(preprocessor);
+ goto doneWithArgument;
+ }
+ // Otherwise we add it as a normal token
+ break;
+
+ case TokenType::LParent:
+ // If we see a left paren then we need to
+ // increase our tracking of nesting
+ nesting++;
+ break;
+
+ default:
+ break;
+ }
+
+ // Add the token and continue parsing.
+ arg->tokens.mTokens.Add(AdvanceRawToken(preprocessor));
+ }
+ doneWithArgument: {}
+ // We've parsed an argument and should move onto
+ // the next one.
+ }
+ break;
+ }
+ doneWithAllArguments:
+ // TODO: handle possible varargs
+
+ // Expect closing right paren
+ if (PeekRawTokenType(preprocessor) == TokenType::RParent)
+ {
+ AdvanceRawToken(preprocessor);
+ }
+ else
+ {
+ GetSink(preprocessor)->diagnose(PeekLoc(preprocessor), Diagnostics::expectedTokenInMacroArguments, TokenType::RParent, PeekRawTokenType(preprocessor));
+ }
+
+ int argCount = argIndex;
+ if (argCount != paramCount)
+ {
+ // TODO: diagnose
+ throw 99;
+ }
+
+ // We are ready to expand.
+ PushMacroExpansion(preprocessor, expansion);
+ }
+ else
+ {
+ // Consume the token that triggered macro expansion
+ AdvanceRawToken(preprocessor);
+
+ // Object-like macros are the easy case.
+ ObjectLikeMacroExpansion* expansion = new ObjectLikeMacroExpansion();
+ InitializeMacroExpansion(preprocessor, expansion, macro);
+ PushMacroExpansion(preprocessor, expansion);
+ }
+ }
+}
+
+// Read one token with macro-expansion enabled.
+static Token AdvanceToken(Preprocessor* preprocessor)
+{
+top:
+ // Check whether we need to macro expand at the cursor.
+ MaybeBeginMacroExpansion(preprocessor);
+
+ // Read a raw token (now that expansion has been triggered)
+ Token token = AdvanceRawToken(preprocessor);
+
+ // Check if we need to perform token pasting
+ if (PeekRawTokenType(preprocessor) != TokenType::PoundPound)
+ {
+ // If we aren't token pasting, then we are done
+ return token;
+ }
+ else
+ {
+ // We are pasting tokens, which could get messy
+
+ StringBuilder sb;
+ sb << token.Content;
+
+ while (PeekRawTokenType(preprocessor) == TokenType::PoundPound)
+ {
+ // Consume the `##`
+ AdvanceRawToken(preprocessor);
+
+ // Possibly macro-expand the next token
+ MaybeBeginMacroExpansion(preprocessor);
+
+ // Read the next raw token (now that expansion has been triggered)
+ Token nextToken = AdvanceRawToken(preprocessor);
+
+ sb << nextToken.Content;
+ }
+
+ // Now re-lex the input
+ PreprocessorInputStream* inputStream = CreateInputStreamForSource(preprocessor, sb.ProduceString(), "token paste");
+ if (inputStream->tokenReader.GetCount() != 1)
+ {
+ // We expect a token paste to produce a single token
+ // TODO(tfoley): emit a diagnostic here
+ }
+
+ PushInputStream(preprocessor, inputStream);
+ goto top;
+ }
+}
+
+// Read one token with macro-expansion enabled.
+//
+// Note that because triggering macro expansion may
+// involve changing the input-stream state, this
+// operation *can* have side effects.
+static Token PeekToken(Preprocessor* preprocessor)
+{
+ // Check whether we need to macro expand at the cursor.
+ MaybeBeginMacroExpansion(preprocessor);
+
+ // Peek a raw token (now that expansion has been triggered)
+ return PeekRawToken(preprocessor);
+
+ // TODO: need a plan for how to handle token pasting
+ // here without it being onerous. Would be nice if we
+ // didn't have to re-do pasting on a "peek"...
+}
+
+// Peek the type of the next token, including macro expansion.
+static TokenType PeekTokenType(Preprocessor* preprocessor)
+{
+ return PeekToken(preprocessor).Type;
+}
+
+//
+// Preprocessor Directives
+//
+
+// When reading a preprocessor directive, we use a context
+// to wrap the direct preprocessor routines defines so far.
+//
+// One of the most important things the directive context
+// does is give us a convenient way to read tokens with
+// a guarantee that we won't read past the end of a line.
+struct PreprocessorDirectiveContext
+{
+ // The preprocessor that is parsing the directive.
+ Preprocessor* preprocessor;
+
+ // The directive token (e.g., the `if` in `#if`).
+ // Useful for reference in diagnostic messages.
+ Token directiveToken;
+
+ // Has any kind of parse error been encountered in
+ // the directive so far?
+ bool parseError;
+
+ // Have we done the necessary checks at the end
+ // of the directive already?
+ bool haveDoneEndOfDirectiveChecks;
+};
+
+// Get the token for the preprocessor directive being parsed.
+inline Token const& GetDirective(PreprocessorDirectiveContext* context)
+{
+ return context->directiveToken;
+}
+
+// Get the name of the directive being parsed.
+inline String const& GetDirectiveName(PreprocessorDirectiveContext* context)
+{
+ return context->directiveToken.Content;
+}
+
+// Get the location of the directive being parsed.
+inline CodePosition const& GetDirectiveLoc(PreprocessorDirectiveContext* context)
+{
+ return context->directiveToken.Position;
+}
+
+// Wrapper to get the diagnostic sink in the context of a directive.
+static inline DiagnosticSink* GetSink(PreprocessorDirectiveContext* context)
+{
+ return GetSink(context->preprocessor);
+}
+
+// Wrapper to get a "current" location when parsing a directive
+static CodePosition PeekLoc(PreprocessorDirectiveContext* context)
+{
+ return PeekLoc(context->preprocessor);
+}
+
+// Wrapper to look up a macro in the context of a directive.
+static PreprocessorMacro* LookupMacro(PreprocessorDirectiveContext* context, String const& name)
+{
+ return LookupMacro(context->preprocessor, name);
+}
+
+// Determine if we have read everthing on the directive's line.
+static bool IsEndOfLine(PreprocessorDirectiveContext* context)
+{
+ return PeekRawToken(context->preprocessor).Type == TokenType::EndOfDirective;
+}
+
+// Peek one raw token in a directive, without going past the end of the line.
+static Token PeekRawToken(PreprocessorDirectiveContext* context)
+{
+ return PeekRawToken(context->preprocessor);
+}
+
+// Read one raw token in a directive, without going past the end of the line.
+static Token AdvanceRawToken(PreprocessorDirectiveContext* context)
+{
+ if (IsEndOfLine(context))
+ return PeekRawToken(context);
+ return AdvanceRawToken(context->preprocessor);
+}
+
+// Peek next raw token type, without going past the end of the line.
+static TokenType PeekRawTokenType(PreprocessorDirectiveContext* context)
+{
+ return PeekRawTokenType(context->preprocessor);
+}
+
+// Read one token, with macro-expansion, without going past the end of the line.
+static Token AdvanceToken(PreprocessorDirectiveContext* context)
+{
+ if (IsEndOfLine(context))
+ return PeekRawToken(context);
+ return AdvanceToken(context->preprocessor);
+}
+
+// Peek one token, with macro-expansion, without going past the end of the line.
+static Token PeekToken(PreprocessorDirectiveContext* context)
+{
+ if (IsEndOfLine(context))
+ context->preprocessor->endOfFileToken;
+ return PeekToken(context->preprocessor);
+}
+
+// Peek next token type, with macro-expansion, without going past the end of the line.
+static TokenType PeekTokenType(PreprocessorDirectiveContext* context)
+{
+ if (IsEndOfLine(context))
+ return TokenType::EndOfDirective;
+ return PeekTokenType(context->preprocessor);
+}
+
+// Skip to the end of the line (useful for recovering from errors in a directive)
+static void SkipToEndOfLine(PreprocessorDirectiveContext* context)
+{
+ while(!IsEndOfLine(context))
+ {
+ AdvanceRawToken(context);
+ }
+}
+
+static bool ExpectRaw(PreprocessorDirectiveContext* context, TokenType tokenType, DiagnosticInfo const& diagnostic, Token* outToken = NULL)
+{
+ if (PeekRawTokenType(context) != tokenType)
+ {
+ // Only report the first parse error within a directive
+ if (!context->parseError)
+ {
+ GetSink(context)->diagnose(PeekLoc(context), diagnostic, tokenType, GetDirectiveName(context));
+ }
+ context->parseError = true;
+ return false;
+ }
+ Token const& token = AdvanceRawToken(context);
+ if (outToken)
+ *outToken = token;
+ return true;
+}
+
+static bool Expect(PreprocessorDirectiveContext* context, TokenType tokenType, DiagnosticInfo const& diagnostic, Token* outToken = NULL)
+{
+ if (PeekTokenType(context) != tokenType)
+ {
+ // Only report the first parse error within a directive
+ if (!context->parseError)
+ {
+ GetSink(context)->diagnose(PeekLoc(context), diagnostic, tokenType, GetDirectiveName(context));
+ context->parseError = true;
+ }
+ return false;
+ }
+ Token const& token = AdvanceToken(context);
+ if (outToken)
+ *outToken = token;
+ return true;
+}
+
+
+
+//
+// Preprocessor Conditionals
+//
+
+// Determine whether the current preprocessor state means we
+// should be skipping tokens.
+static bool IsSkipping(Preprocessor* preprocessor)
+{
+ PreprocessorInputStream* inputStream = preprocessor->inputStream;
+ if (!inputStream) return false;
+
+ // If we are not inside a preprocessor conditional, then don't skip
+ PreprocessorConditional* conditional = inputStream->conditional;
+ if (!conditional) return false;
+
+ // skip tokens unless the conditional is inside its `true` case
+ return conditional->state != PreprocessorConditionalState::During;
+}
+
+// Wrapper for use inside directives
+static inline bool IsSkipping(PreprocessorDirectiveContext* context)
+{
+ return IsSkipping(context->preprocessor);
+}
+
+// Create a preprocessor conditional
+static PreprocessorConditional* CreateConditional(Preprocessor* /*preprocessor*/)
+{
+ // TODO(tfoley): allocate these more intelligently (for example,
+ // pool them on the `Preprocessor`.
+ return new PreprocessorConditional();
+}
+
+// Destroy a preprocessor conditional.
+static void DestroyConditional(PreprocessorConditional* conditional)
+{
+ delete conditional;
+}
+
+// Start a preprocessor conditional, with an initial enable/disable state.
+static void BeginConditional(PreprocessorDirectiveContext* context, bool enable)
+{
+ Preprocessor* preprocessor = context->preprocessor;
+ PreprocessorInputStream* inputStream = preprocessor->inputStream;
+ assert(inputStream);
+
+ PreprocessorConditional* conditional = CreateConditional(preprocessor);
+
+ conditional->ifToken = context->directiveToken;
+
+ // Set state of this condition appropriately.
+ //
+ // Default to the "haven't yet seen a `true` branch" state.
+ PreprocessorConditionalState state = PreprocessorConditionalState::Before;
+ //
+ // If we are nested inside a `false` branch of another condition, then
+ // we never want to enable, so we act as if we already *saw* the `true` branch.
+ //
+ if (IsSkipping(preprocessor)) state = PreprocessorConditionalState::After;
+ //
+ // Similarly, if we ran into any parse errors when dealing with the
+ // opening directive, then things are probably screwy and we should just
+ // skip all the branches.
+ if (IsSkipping(preprocessor)) state = PreprocessorConditionalState::After;
+ //
+ // Otherwise, if our condition was true, then set us to be inside the `true` branch
+ else if (enable) state = PreprocessorConditionalState::During;
+
+ conditional->state = state;
+
+ // Push conditional onto the stack
+ conditional->parent = inputStream->conditional;
+ inputStream->conditional = conditional;
+}
+
+//
+// Preprocessor Conditional Expressions
+//
+
+// Conditional expressions are always of type `int`
+typedef int PreprocessorExpressionValue;
+
+// Forward-declaretion
+static PreprocessorExpressionValue ParseAndEvaluateExpression(PreprocessorDirectiveContext* context);
+
+// Parse a unary (prefix) expression inside of a preprocessor directive.
+static PreprocessorExpressionValue ParseAndEvaluateUnaryExpression(PreprocessorDirectiveContext* context)
+{
+ switch (PeekTokenType(context))
+ {
+ // handle prefix unary ops
+ case TokenType::OpSub:
+ AdvanceToken(context);
+ return -ParseAndEvaluateUnaryExpression(context);
+ case TokenType::OpNot:
+ AdvanceToken(context);
+ return !ParseAndEvaluateUnaryExpression(context);
+ case TokenType::OpBitNot:
+ AdvanceToken(context);
+ return ~ParseAndEvaluateUnaryExpression(context);
+
+ // handle parenthized sub-expression
+ case TokenType::LParent:
+ {
+ Token leftParen = AdvanceToken(context);
+ PreprocessorExpressionValue value = ParseAndEvaluateExpression(context);
+ if (!Expect(context, TokenType::RParent, Diagnostics::expectedTokenInPreprocessorExpression))
+ {
+ GetSink(context)->diagnose(leftParen.Position, Diagnostics::seeOpeningToken, leftParen);
+ }
+ return value;
+ }
+
+ case TokenType::IntLiterial:
+ return StringToInt(AdvanceToken(context).Content);
+
+ case TokenType::Identifier:
+ {
+ Token token = AdvanceToken(context);
+ if (token.Content == "defined")
+ {
+ // handle `defined(someName)`
+
+ // Possibly parse a `(`
+ Token leftParen;
+ if (PeekRawTokenType(context) == TokenType::LParent)
+ {
+ leftParen = AdvanceRawToken(context);
+ }
+
+ // Expect an identifier
+ Token nameToken;
+ if (!ExpectRaw(context, TokenType::Identifier, Diagnostics::expectedTokenInDefinedExpression, &nameToken))
+ {
+ return 0;
+ }
+ String name = nameToken.Content;
+
+ // If we saw an opening `(`, then expect one to close
+ if (leftParen.Type != TokenType::Unknown)
+ {
+ if(!ExpectRaw(context, TokenType::RParent, Diagnostics::expectedTokenInDefinedExpression))
+ {
+ GetSink(context)->diagnose(leftParen.Position, Diagnostics::seeOpeningToken, leftParen);
+ return 0;
+ }
+ }
+
+ return LookupMacro(context, name) != NULL;
+ }
+
+ // An identifier here means it was not defined as a macro (or
+ // it is defined, but as a function-like macro. These should
+ // just evaluate to zero (possibly with a warning)
+ return 0;
+ }
+
+ default:
+ GetSink(context)->diagnose(PeekLoc(context), Diagnostics::syntaxErrorInPreprocessorExpression);
+ return 0;
+ }
+}
+
+// Determine the precedence level of an infix operator
+// for use in parsing preprocessor conditionals.
+static int GetInfixOpPrecedence(Token const& opToken)
+{
+ // If token is on another line, it is not part of the
+ // expression
+ if (opToken.flags & TokenFlag::AtStartOfLine)
+ return -1;
+
+ // otherwise we look at the token type to figure
+ // out what precednece it should be parse with
+ switch (opToken.Type)
+ {
+ default:
+ // tokens that aren't infix operators should
+ // cause us to stop parsing an expression
+ return -1;
+
+ case TokenType::OpMul: return 10;
+ case TokenType::OpDiv: return 10;
+ case TokenType::OpMod: return 10;
+
+ case TokenType::OpAdd: return 9;
+ case TokenType::OpSub: return 9;
+
+ case TokenType::OpLsh: return 8;
+ case TokenType::OpRsh: return 8;
+
+ case TokenType::OpLess: return 7;
+ case TokenType::OpGreater: return 7;
+ case TokenType::OpLeq: return 7;
+ case TokenType::OpGeq: return 7;
+
+ case TokenType::OpEql: return 6;
+ case TokenType::OpNeq: return 6;
+
+ case TokenType::OpBitAnd: return 5;
+ case TokenType::OpBitOr: return 4;
+ case TokenType::OpBitXor: return 3;
+ case TokenType::OpAnd: return 2;
+ case TokenType::OpOr: return 1;
+ }
+};
+
+// Evaluate one infix operation in a preprocessor
+// conditional expression
+static PreprocessorExpressionValue EvaluateInfixOp(
+ PreprocessorDirectiveContext* context,
+ Token const& opToken,
+ PreprocessorExpressionValue left,
+ PreprocessorExpressionValue right)
+{
+ switch (opToken.Type)
+ {
+ default:
+// SLANG_INTERNAL_ERROR(getSink(preprocessor), opToken);
+ return 0;
+ break;
+
+ case TokenType::OpMul: return left * right;
+ case TokenType::OpDiv:
+ {
+ if (right == 0)
+ {
+ if (!context->parseError)
+ {
+ GetSink(context)->diagnose(opToken.Position, Diagnostics::divideByZeroInPreprocessorExpression);
+ }
+ return 0;
+ }
+ return left / right;
+ }
+ case TokenType::OpMod:
+ {
+ if (right == 0)
+ {
+ if (!context->parseError)
+ {
+ GetSink(context)->diagnose(opToken.Position, Diagnostics::divideByZeroInPreprocessorExpression);
+ }
+ return 0;
+ }
+ return left % right;
+ }
+ case TokenType::OpAdd: return left + right;
+ case TokenType::OpSub: return left - right;
+ case TokenType::OpLsh: return left << right;
+ case TokenType::OpRsh: return left >> right;
+ case TokenType::OpLess: return left < right ? 1 : 0;
+ case TokenType::OpGreater: return left > right ? 1 : 0;
+ case TokenType::OpLeq: return left <= right ? 1 : 0;
+ case TokenType::OpGeq: return left <= right ? 1 : 0;
+ case TokenType::OpEql: return left == right ? 1 : 0;
+ case TokenType::OpNeq: return left != right ? 1 : 0;
+ case TokenType::OpBitAnd: return left & right;
+ case TokenType::OpBitOr: return left | right;
+ case TokenType::OpBitXor: return left ^ right;
+ case TokenType::OpAnd: return left && right;
+ case TokenType::OpOr: return left || right;
+ }
+}
+
+// Parse the rest of an infix preprocessor expression with
+// precedence greater than or equal to the given `precedence` argument.
+// The value of the left-hand-side expression is provided as
+// an argument.
+// This is used to form a simple recursive-descent expression parser.
+static PreprocessorExpressionValue ParseAndEvaluateInfixExpressionWithPrecedence(
+ PreprocessorDirectiveContext* context,
+ PreprocessorExpressionValue left,
+ int precedence)
+{
+ for (;;)
+ {
+ // Look at the next token, and see if it is an operator of
+ // high enough precedence to be included in our expression
+ Token opToken = PeekToken(context);
+ int opPrecedence = GetInfixOpPrecedence(opToken);
+
+ // If it isn't an operator of high enough precendece, we are done.
+ if(opPrecedence < precedence)
+ break;
+
+ // Otherwise we need to consume the operator token.
+ AdvanceToken(context);
+
+ // Next we parse a right-hand-side expression by starting with
+ // a unary expression and absorbing and many infix operators
+ // as possible with strictly higher precedence than the operator
+ // we found above.
+ PreprocessorExpressionValue right = ParseAndEvaluateUnaryExpression(context);
+ for (;;)
+ {
+ // Look for an operator token
+ Token rightOpToken = PeekToken(context);
+ int rightOpPrecedence = GetInfixOpPrecedence(rightOpToken);
+
+ // If no operator was found, or the operator wasn't high
+ // enough precedence to fold into the right-hand-side,
+ // exit this loop.
+ if (rightOpPrecedence <= opPrecedence)
+ break;
+
+ // Now invoke the parser recursively, passing in our
+ // existing right-hand side to form an even larger one.
+ right = ParseAndEvaluateInfixExpressionWithPrecedence(
+ context,
+ right,
+ rightOpPrecedence);
+ }
+
+ // Now combine the left- and right-hand sides using
+ // the operator we found above.
+ left = EvaluateInfixOp(context, opToken, left, right);
+ }
+ return left;
+}
+
+// Parse a complete (infix) preprocessor expression, and return its value
+static PreprocessorExpressionValue ParseAndEvaluateExpression(PreprocessorDirectiveContext* context)
+{
+ // First read in the left-hand side (or the whole expression in the unary case)
+ PreprocessorExpressionValue value = ParseAndEvaluateUnaryExpression(context);
+
+ // Try to read in trailing infix operators with correct precedence
+ return ParseAndEvaluateInfixExpressionWithPrecedence(context, value, 0);
+}
+
+// Handle a `#if` directive
+static void HandleIfDirective(PreprocessorDirectiveContext* context)
+{
+ // Parse a preprocessor expression.
+ PreprocessorExpressionValue value = ParseAndEvaluateExpression(context);
+
+ // Begin a preprocessor block, enabled based on the expression.
+ BeginConditional(context, value != 0);
+}
+
+// Handle a `#ifdef` directive
+static void HandleIfDefDirective(PreprocessorDirectiveContext* context)
+{
+ // Expect a raw identifier, so we can check if it is defined
+ Token nameToken;
+ if(!ExpectRaw(context, TokenType::Identifier, Diagnostics::expectedTokenInPreprocessorDirective, &nameToken))
+ return;
+ String name = nameToken.Content;
+
+ // Check if the name is defined.
+ BeginConditional(context, LookupMacro(context, name) != NULL);
+}
+
+// Handle a `#ifndef` directive
+static void HandleIfNDefDirective(PreprocessorDirectiveContext* context)
+{
+ // Expect a raw identifier, so we can check if it is defined
+ Token nameToken;
+ if(!ExpectRaw(context, TokenType::Identifier, Diagnostics::expectedTokenInPreprocessorDirective, &nameToken))
+ return;
+ String name = nameToken.Content;
+
+ // Check if the name is defined.
+ BeginConditional(context, LookupMacro(context, name) == NULL);
+}
+
+// Handle a `#else` directive
+static void HandleElseDirective(PreprocessorDirectiveContext* context)
+{
+ PreprocessorInputStream* inputStream = context->preprocessor->inputStream;
+ assert(inputStream);
+
+ // if we aren't inside a conditional, then error
+ PreprocessorConditional* conditional = inputStream->conditional;
+ if (!conditional)
+ {
+ GetSink(context)->diagnose(GetDirectiveLoc(context), Diagnostics::directiveWithoutIf, GetDirectiveName(context));
+ return;
+ }
+
+ // if we've already seen a `#else`, then it is an error
+ if (conditional->elseToken.Type != TokenType::Unknown)
+ {
+ GetSink(context)->diagnose(GetDirectiveLoc(context), Diagnostics::directiveAfterElse, GetDirectiveName(context));
+ GetSink(context)->diagnose(conditional->elseToken.Position, Diagnostics::seeDirective);
+ return;
+ }
+ conditional->elseToken = context->directiveToken;
+
+ switch (conditional->state)
+ {
+ case PreprocessorConditionalState::Before:
+ conditional->state = PreprocessorConditionalState::During;
+ break;
+
+ case PreprocessorConditionalState::During:
+ conditional->state = PreprocessorConditionalState::After;
+ break;
+
+ default:
+ break;
+ }
+}
+
+// Handle a `#elif` directive
+static void HandleElifDirective(PreprocessorDirectiveContext* context)
+{
+ // HACK(tfoley): handle an empty `elif` like an `else` directive
+ //
+ // This is the behavior expected by at least one input program.
+ // We will eventually want to be pedantic about this.
+ // even if t
+ if (PeekRawTokenType(context) == TokenType::EndOfDirective)
+ {
+ GetSink(context)->diagnose(GetDirectiveLoc(context), Diagnostics::directiveExpectsExpression, GetDirectiveName(context));
+ HandleElseDirective(context);
+ return;
+ }
+
+ PreprocessorExpressionValue value = ParseAndEvaluateExpression(context);
+
+ PreprocessorInputStream* inputStream = context->preprocessor->inputStream;
+ assert(inputStream);
+
+ // if we aren't inside a conditional, then error
+ PreprocessorConditional* conditional = inputStream->conditional;
+ if (!conditional)
+ {
+ GetSink(context)->diagnose(GetDirectiveLoc(context), Diagnostics::directiveWithoutIf, GetDirectiveName(context));
+ return;
+ }
+
+ // if we've already seen a `#else`, then it is an error
+ if (conditional->elseToken.Type != TokenType::Unknown)
+ {
+ GetSink(context)->diagnose(GetDirectiveLoc(context), Diagnostics::directiveAfterElse, GetDirectiveName(context));
+ GetSink(context)->diagnose(conditional->elseToken.Position, Diagnostics::seeDirective);
+ return;
+ }
+
+ switch (conditional->state)
+ {
+ case PreprocessorConditionalState::Before:
+ if(value)
+ conditional->state = PreprocessorConditionalState::During;
+ break;
+
+ case PreprocessorConditionalState::During:
+ conditional->state = PreprocessorConditionalState::After;
+ break;
+
+ default:
+ break;
+ }
+}
+
+// Handle a `#endif` directive
+static void HandleEndIfDirective(PreprocessorDirectiveContext* context)
+{
+ PreprocessorInputStream* inputStream = context->preprocessor->inputStream;
+ assert(inputStream);
+
+ // if we aren't inside a conditional, then error
+ PreprocessorConditional* conditional = inputStream->conditional;
+ if (!conditional)
+ {
+ GetSink(context)->diagnose(GetDirectiveLoc(context), Diagnostics::directiveWithoutIf, GetDirectiveName(context));
+ return;
+ }
+
+ inputStream->conditional = conditional->parent;
+ DestroyConditional(conditional);
+}
+
+// Helper routine to check that we find the end of a directive where
+// we expect it.
+//
+// Most directives do not need to call this directly, since we have
+// a catch-all case in the main `HandleDirective()` funciton.
+// The `#include` case will call it directly to avoid complications
+// when it switches the input stream.
+static void expectEndOfDirective(PreprocessorDirectiveContext* context)
+{
+ if(context->haveDoneEndOfDirectiveChecks)
+ return;
+
+ context->haveDoneEndOfDirectiveChecks = true;
+
+ if (!IsEndOfLine(context))
+ {
+ // If we already saw a previous parse error, then don't
+ // emit another one for the same directive.
+ if (!context->parseError)
+ {
+ GetSink(context)->diagnose(PeekLoc(context), Diagnostics::unexpectedTokensAfterDirective, GetDirectiveName(context));
+ }
+ SkipToEndOfLine(context);
+ }
+
+ // Clear out the end-of-directive token
+ AdvanceRawToken(context->preprocessor);
+}
+
+
+// Handle a `#include` directive
+static void HandleIncludeDirective(PreprocessorDirectiveContext* context)
+{
+ Token pathToken;
+ if(!Expect(context, TokenType::StringLiterial, Diagnostics::expectedTokenInPreprocessorDirective, &pathToken))
+ return;
+
+ String path = getFileNameTokenValue(pathToken);
+
+ // TODO(tfoley): make this robust in presence of `#line`
+ String pathIncludedFrom = GetDirectiveLoc(context).FileName;
+ String foundPath;
+ String foundSource;
+
+
+ IncludeHandler* includeHandler = context->preprocessor->includeHandler;
+ if (!includeHandler)
+ {
+ GetSink(context)->diagnose(pathToken.Position, Diagnostics::includeFailed, path);
+ GetSink(context)->diagnose(pathToken.Position, Diagnostics::noIncludeHandlerSpecified);
+ return;
+ }
+ if (!includeHandler->TryToFindIncludeFile(path, pathIncludedFrom, &foundPath, &foundSource))
+ {
+ GetSink(context)->diagnose(pathToken.Position, Diagnostics::includeFailed, path);
+ return;
+ }
+
+ // Do all checking related to the end of this directive before we push a new stream,
+ // just to avoid complications where that check would need to deal with
+ // a switch of input stream
+ expectEndOfDirective(context);
+
+ // Push the new file onto our stack of input streams
+ // TODO(tfoley): check if we have made our include stack too deep
+ PreprocessorInputStream* inputStream = CreateInputStreamForSource(context->preprocessor, foundSource, foundPath);
+ inputStream->parent = context->preprocessor->inputStream;
+ context->preprocessor->inputStream = inputStream;
+}
+
+// Handle a `#define` directive
+static void HandleDefineDirective(PreprocessorDirectiveContext* context)
+{
+ Token nameToken;
+ if (!Expect(context, TokenType::Identifier, Diagnostics::expectedTokenInPreprocessorDirective, &nameToken))
+ return;
+ String name = nameToken.Content;
+
+ PreprocessorMacro* macro = CreateMacro(context->preprocessor);
+ macro->nameToken = nameToken;
+
+ PreprocessorMacro* oldMacro = LookupMacro(&context->preprocessor->globalEnv, name);
+ if (oldMacro)
+ {
+ GetSink(context)->diagnose(nameToken.Position, Diagnostics::macroRedefinition, name);
+ GetSink(context)->diagnose(oldMacro->nameToken.Position, Diagnostics::seePreviousDefinitionOf, name);
+
+ DestroyMacro(context->preprocessor, oldMacro);
+ }
+ context->preprocessor->globalEnv.macros[name] = macro;
+
+ // If macro name is immediately followed (with no space) by `(`,
+ // then we have a function-like macro
+ if (PeekRawTokenType(context) == TokenType::LParent)
+ {
+ if (!(PeekRawToken(context).flags & TokenFlag::AfterWhitespace))
+ {
+ // This is a function-like macro, so we need to remember that
+ // and start capturing parameters
+ macro->flavor = PreprocessorMacroFlavor::FunctionLike;
+
+ AdvanceRawToken(context);
+
+ // If there are any parameters, parse them
+ if (PeekRawTokenType(context) != TokenType::RParent)
+ {
+ for (;;)
+ {
+ // TODO: handle elipsis (`...`) for varags
+
+ // A macro parameter name should be a raw identifier
+ Token paramToken;
+ if (!ExpectRaw(context, TokenType::Identifier, Diagnostics::expectedTokenInMacroParameters, &paramToken))
+ break;
+
+ // TODO(tfoley): some validation on parameter name.
+ // Certain names (e.g., `defined` and `__VA_ARGS__`
+ // are not allowed to be used as macros or parameters).
+
+ // Add the parameter to the macro being deifned
+ macro->params.Add(paramToken);
+
+ // If we see `)` then we are done with arguments
+ if (PeekRawTokenType(context) == TokenType::RParent)
+ break;
+
+ ExpectRaw(context, TokenType::Comma, Diagnostics::expectedTokenInMacroParameters);
+ }
+ }
+
+ ExpectRaw(context, TokenType::RParent, Diagnostics::expectedTokenInMacroParameters);
+ }
+ }
+
+ // consume tokens until end-of-line
+ for(;;)
+ {
+ Token token = AdvanceRawToken(context);
+ if( token.Type == TokenType::EndOfDirective )
+ {
+ // Last token on line will be turned into a conceptual end-of-file
+ // token for the sub-stream that the macro expands into.
+ token.Type = TokenType::EndOfFile;
+ macro->tokens.mTokens.Add(token);
+ break;
+ }
+
+ // In the ordinary case, we just add the token to the definition
+ macro->tokens.mTokens.Add(token);
+ }
+}
+
+// Handle a `#undef` directive
+static void HandleUndefDirective(PreprocessorDirectiveContext* context)
+{
+ Token nameToken;
+ if (!Expect(context, TokenType::Identifier, Diagnostics::expectedTokenInPreprocessorDirective, &nameToken))
+ return;
+ String name = nameToken.Content;
+
+ PreprocessorEnvironment* env = &context->preprocessor->globalEnv;
+ PreprocessorMacro* macro = LookupMacro(env, name);
+ if (macro != NULL)
+ {
+ // name was defined, so remove it
+ env->macros.Remove(name);
+
+ DestroyMacro(context->preprocessor, macro);
+ }
+ else
+ {
+ // name wasn't defined
+ GetSink(context)->diagnose(nameToken.Position, Diagnostics::macroNotDefined, name);
+ }
+}
+
+// Handle a `#warning` directive
+static void HandleWarningDirective(PreprocessorDirectiveContext* context)
+{
+ // TODO: read rest of line without actual tokenization
+ GetSink(context)->diagnose(GetDirectiveLoc(context), Diagnostics::userDefinedWarning, "user-defined warning");
+ SkipToEndOfLine(context);
+}
+
+// Handle a `#error` directive
+static void HandleErrorDirective(PreprocessorDirectiveContext* context)
+{
+ // TODO: read rest of line without actual tokenization
+ GetSink(context)->diagnose(GetDirectiveLoc(context), Diagnostics::userDefinedError, "user-defined warning");
+ SkipToEndOfLine(context);
+}
+
+// Handle a `#line` directive
+static void HandleLineDirective(PreprocessorDirectiveContext* context)
+{
+ int line = 0;
+ if (PeekTokenType(context) == TokenType::IntLiterial)
+ {
+ line = StringToInt(AdvanceToken(context).Content);
+ }
+ else if (PeekTokenType(context) == TokenType::Identifier
+ && PeekToken(context).Content == "default")
+ {
+ AdvanceToken(context);
+
+ // Stop overiding soure locations.
+ context->preprocessor->inputStream->isOverridingSourceLoc = false;
+ return;
+ }
+ else
+ {
+ GetSink(context)->diagnose(PeekLoc(context), Diagnostics::expected2TokensInPreprocessorDirective,
+ TokenType::IntLiterial,
+ "default",
+ GetDirectiveName(context));
+ context->parseError = true;
+ return;
+ }
+
+ CodePosition directiveLoc = GetDirectiveLoc(context);
+
+ String file;
+ if (PeekTokenType(context) == TokenType::EndOfDirective)
+ {
+ file = directiveLoc.FileName;
+ }
+ else if (PeekTokenType(context) == TokenType::StringLiterial)
+ {
+ file = AdvanceToken(context).Content;
+ }
+ else if (PeekTokenType(context) == TokenType::IntLiterial)
+ {
+ // Note(tfoley): GLSL allows the "source string" to be indicated by an integer
+ // TODO(tfoley): Figure out a better way to handle this, if it matters
+ file = AdvanceToken(context).Content;
+ }
+ else
+ {
+ Expect(context, TokenType::StringLiterial, Diagnostics::expectedTokenInPreprocessorDirective);
+ return;
+ }
+
+ PreprocessorInputStream* inputStream = context->preprocessor->inputStream;
+
+ inputStream->isOverridingSourceLoc = true;
+ inputStream->overrideFileName = file;
+ inputStream->overrideLineOffset = line - (directiveLoc.Line + 1);
+}
+
+// Handle a `#pragma` directive
+static void HandlePragmaDirective(PreprocessorDirectiveContext* context)
+{
+ // TODO(tfoley): figure out which pragmas to parse,
+ // and which to pass along
+ SkipToEndOfLine(context);
+}
+
+// Handle a `#version` directive
+static void handleGLSLVersionDirective(PreprocessorDirectiveContext* context)
+{
+ Token versionNumberToken;
+ if(!ExpectRaw(
+ context,
+ TokenType::IntLiterial,
+ Diagnostics::expectedTokenInPreprocessorDirective,
+ &versionNumberToken))
+ {
+ return;
+ }
+
+ Token glslProfileToken;
+ if(PeekTokenType(context) == TokenType::Identifier)
+ {
+ glslProfileToken = AdvanceToken(context);
+ }
+
+ // Need to construct a representation taht we can hook into our compilation result
+
+ auto modifier = new GLSLVersionDirective();
+ modifier->versionNumberToken = versionNumberToken;
+ modifier->glslProfileToken = glslProfileToken;
+
+ // Attach the modifier to the program we are parsing!
+
+ addModifier(
+ context->preprocessor->syntax,
+ modifier);
+}
+
+// Handle a `#extension` directive, e.g.,
+//
+// #extension some_extension_name : enable
+//
+static void handleGLSLExtensionDirective(PreprocessorDirectiveContext* context)
+{
+ Token extensionNameToken;
+ if(!ExpectRaw(
+ context,
+ TokenType::Identifier,
+ Diagnostics::expectedTokenInPreprocessorDirective,
+ &extensionNameToken))
+ {
+ return;
+ }
+
+ if( !ExpectRaw(context, TokenType::Colon, Diagnostics::expectedTokenInPreprocessorDirective) )
+ {
+ return;
+ }
+
+ Token dispositionToken;
+ if(!ExpectRaw(
+ context,
+ TokenType::Identifier,
+ Diagnostics::expectedTokenInPreprocessorDirective,
+ &dispositionToken))
+ {
+ return;
+ }
+
+ // Need to construct a representation taht we can hook into our compilation result
+
+ auto modifier = new GLSLExtensionDirective();
+ modifier->extensionNameToken = extensionNameToken;
+ modifier->dispositionToken = dispositionToken;
+
+ // Attach the modifier to the program we are parsing!
+
+ addModifier(
+ context->preprocessor->syntax,
+ modifier);
+}
+
+// Handle an invalid directive
+static void HandleInvalidDirective(PreprocessorDirectiveContext* context)
+{
+ GetSink(context)->diagnose(GetDirectiveLoc(context), Diagnostics::unknownPreprocessorDirective, GetDirectiveName(context));
+ SkipToEndOfLine(context);
+}
+
+// Callback interface used by preprocessor directives
+typedef void (*PreprocessorDirectiveCallback)(PreprocessorDirectiveContext* context);
+
+enum PreprocessorDirectiveFlag : unsigned int
+{
+ // Should this directive be handled even when skipping disbaled code?
+ ProcessWhenSkipping = 1 << 0,
+};
+
+// Information about a specific directive
+struct PreprocessorDirective
+{
+ // Name of the directive
+ char const* name;
+
+ // Callback to handle the directive
+ PreprocessorDirectiveCallback callback;
+
+ unsigned int flags;
+};
+
+// A simple array of all the directives we know how to handle.
+// TODO(tfoley): considering making this into a real hash map,
+// and then make it easy-ish for users of the codebase to add
+// their own directives as desired.
+static const PreprocessorDirective kDirectives[] =
+{
+ { "if", &HandleIfDirective, ProcessWhenSkipping },
+ { "ifdef", &HandleIfDefDirective, ProcessWhenSkipping },
+ { "ifndef", &HandleIfNDefDirective, ProcessWhenSkipping },
+ { "else", &HandleElseDirective, ProcessWhenSkipping },
+ { "elif", &HandleElifDirective, ProcessWhenSkipping },
+ { "endif", &HandleEndIfDirective, ProcessWhenSkipping },
+
+ { "include", &HandleIncludeDirective, 0 },
+ { "define", &HandleDefineDirective, 0 },
+ { "undef", &HandleUndefDirective, 0 },
+ { "warning", &HandleWarningDirective, 0 },
+ { "error", &HandleErrorDirective, 0 },
+ { "line", &HandleLineDirective, 0 },
+ { "pragma", &HandlePragmaDirective, 0 },
+
+ // TODO(tfoley): These are specific to GLSL, and probably
+ // shouldn't be enabled for HLSL or Slang
+ { "version", &handleGLSLVersionDirective, 0 },
+ { "extension", &handleGLSLExtensionDirective, 0 },
+
+ { NULL, NULL },
+};
+
+static const PreprocessorDirective kInvalidDirective = {
+ NULL, &HandleInvalidDirective, 0,
+};
+
+// Look up the directive with the given name.
+static PreprocessorDirective const* FindDirective(String const& name)
+{
+ char const* nameStr = name.Buffer();
+ for (int ii = 0; kDirectives[ii].name; ++ii)
+ {
+ if (strcmp(kDirectives[ii].name, nameStr) != 0)
+ continue;
+
+ return &kDirectives[ii];
+ }
+
+ return &kInvalidDirective;
+}
+
+// Process a directive, where the preprocessor has already consumed the
+// `#` token that started the directive line.
+static void HandleDirective(PreprocessorDirectiveContext* context)
+{
+ // Try to read the directive name.
+ context->directiveToken = PeekRawToken(context);
+
+ TokenType directiveTokenType = GetDirective(context).Type;
+
+ // An empty directive is allowed, and ignored.
+ if (directiveTokenType == TokenType::EndOfDirective)
+ {
+ return;
+ }
+ // Otherwise the directive name had better be an identifier
+ else if (directiveTokenType != TokenType::Identifier)
+ {
+ GetSink(context)->diagnose(GetDirectiveLoc(context), Diagnostics::expectedPreprocessorDirectiveName);
+ SkipToEndOfLine(context);
+ return;
+ }
+
+ // Consume the directive name token.
+ AdvanceRawToken(context);
+
+ // Look up the handler for the directive.
+ PreprocessorDirective const* directive = FindDirective(GetDirectiveName(context));
+
+ // If we are skipping disabled code, and the directive is not one
+ // of the small number that need to run even in that case, skip it.
+ if (IsSkipping(context) && !(directive->flags & PreprocessorDirectiveFlag::ProcessWhenSkipping))
+ {
+ SkipToEndOfLine(context);
+ return;
+ }
+
+ // Apply the directive-specific callback
+ (directive->callback)(context);
+
+ // We expect the directive callback to consume the entire line, so if
+ // it hasn't that is a parse error.
+ expectEndOfDirective(context);
+}
+
+// Read one token using the full preprocessor, with all its behaviors.
+static Token ReadToken(Preprocessor* preprocessor)
+{
+ for (;;)
+ {
+ // Look at the next raw token in the input.
+ Token const& token = PeekRawToken(preprocessor);
+
+ // If we have a directive (`#` at start of line) then handle it
+ if ((token.Type == TokenType::Pound) && (token.flags & TokenFlag::AtStartOfLine))
+ {
+ // Skip the `#`
+ AdvanceRawToken(preprocessor);
+
+ // Create a context for parsing the directive
+ PreprocessorDirectiveContext directiveContext;
+ directiveContext.preprocessor = preprocessor;
+ directiveContext.parseError = false;
+ directiveContext.haveDoneEndOfDirectiveChecks = false;
+
+ // Parse and handle the directive
+ HandleDirective(&directiveContext);
+ continue;
+ }
+
+ // otherwise, if we are currently in a skipping mode, then skip tokens
+ if (IsSkipping(preprocessor))
+ {
+ AdvanceRawToken(preprocessor);
+ continue;
+ }
+
+ // otherwise read a token, which may involve macro expansion
+ return AdvanceToken(preprocessor);
+ }
+}
+
+// intialize a preprocessor context, using the given sink for errros
+static void InitializePreprocessor(
+ Preprocessor* preprocessor,
+ DiagnosticSink* sink)
+{
+ preprocessor->sink = sink;
+ preprocessor->includeHandler = NULL;
+ preprocessor->endOfFileToken.Type = TokenType::EndOfFile;
+ preprocessor->endOfFileToken.flags = TokenFlag::AtStartOfLine;
+}
+
+// clean up after an environment
+PreprocessorEnvironment::~PreprocessorEnvironment()
+{
+ for (auto pair : this->macros)
+ {
+ DestroyMacro(NULL, pair.Value);
+ }
+}
+
+// finalize a preprocessor and free any memory still in use
+static void FinalizePreprocessor(
+ Preprocessor* preprocessor)
+{
+ // Clear out any waiting input streams
+ PreprocessorInputStream* input = preprocessor->inputStream;
+ while (input)
+ {
+ PreprocessorInputStream* parent = input->parent;
+ DestroyInputStream(preprocessor, input);
+ input = parent;
+ }
+
+#if 0
+ // clean up any macros that were allocated
+ for (auto pair : preprocessor->globalEnv.macros)
+ {
+ DestroyMacro(preprocessor, pair.Value);
+ }
+#endif
+}
+
+// Add a simple macro definition from a string (e.g., for a
+// `-D` option passed on the command line
+static void DefineMacro(
+ Preprocessor* preprocessor,
+ String const& key,
+ String const& value)
+{
+ String fileName = "command line";
+ PreprocessorMacro* macro = CreateMacro(preprocessor);
+
+ // Use existing `Lexer` to generate a token stream.
+ Lexer lexer(fileName, value, GetSink(preprocessor));
+ macro->tokens = lexer.lexAllTokens();
+ macro->nameToken = Token(TokenType::Identifier, key, 0, 0, 0, fileName);
+
+ PreprocessorMacro* oldMacro = NULL;
+ if (preprocessor->globalEnv.macros.TryGetValue(key, oldMacro))
+ {
+ DestroyMacro(preprocessor, oldMacro);
+ }
+
+ preprocessor->globalEnv.macros[key] = macro;
+}
+
+// read the entire input into tokens
+static TokenList ReadAllTokens(
+ Preprocessor* preprocessor)
+{
+ TokenList tokens;
+ for (;;)
+ {
+ Token token = ReadToken(preprocessor);
+
+ tokens.mTokens.Add(token);
+
+ // Note: we include the EOF token in the list,
+ // since that is expected by the `TokenList` type.
+ if (token.Type == TokenType::EndOfFile)
+ break;
+ }
+ return tokens;
+}
+
+TokenList preprocessSource(
+ CoreLib::String const& source,
+ CoreLib::String const& fileName,
+ DiagnosticSink* sink,
+ IncludeHandler* includeHandler,
+ CoreLib::Dictionary<CoreLib::String, CoreLib::String> defines,
+ ProgramSyntaxNode* syntax)
+{
+ Preprocessor preprocessor;
+ InitializePreprocessor(&preprocessor, sink);
+ preprocessor.syntax = syntax;
+
+ preprocessor.includeHandler = includeHandler;
+ for (auto p : defines)
+ {
+ DefineMacro(&preprocessor, p.Key, p.Value);
+ }
+
+ // create an initial input stream based on the provided buffer
+ preprocessor.inputStream = CreateInputStreamForSource(&preprocessor, source, fileName);
+
+ TokenList tokens = ReadAllTokens(&preprocessor);
+
+ FinalizePreprocessor(&preprocessor);
+
+ // debugging: build the pre-processed source back together
+#if 0
+ StringBuilder sb;
+ for (auto t : tokens)
+ {
+ if (t.flags & TokenFlag::AtStartOfLine)
+ {
+ sb << "\n";
+ }
+ else if (t.flags & TokenFlag::AfterWhitespace)
+ {
+ sb << " ";
+ }
+
+ sb << t.Content;
+ }
+
+ String s = sb.ProduceString();
+#endif
+
+ return tokens;
+}
+
+}}
diff --git a/source/slang/preprocessor.h b/source/slang/preprocessor.h
new file mode 100644
index 000000000..ab72f3f87
--- /dev/null
+++ b/source/slang/preprocessor.h
@@ -0,0 +1,35 @@
+// Preprocessor.h
+#ifndef SLANG_PREPROCESSOR_H_INCLUDED
+#define SLANG_PREPROCESSOR_H_INCLUDED
+
+#include "../core/basic.h"
+#include "../slang/lexer.h"
+
+namespace Slang{ namespace Compiler {
+
+class DiagnosticSink;
+class ProgramSyntaxNode;
+
+// Callback interface for the preprocessor to use when looking
+// for files in `#include` directives.
+struct IncludeHandler
+{
+ virtual bool TryToFindIncludeFile(
+ CoreLib::String const& pathToInclude,
+ CoreLib::String const& pathIncludedFrom,
+ CoreLib::String* outFoundPath,
+ CoreLib::String* outFoundSource) = 0;
+};
+
+// Take a string of source code and preprocess it into a list of tokens.
+TokenList preprocessSource(
+ CoreLib::String const& source,
+ CoreLib::String const& fileName,
+ DiagnosticSink* sink,
+ IncludeHandler* includeHandler,
+ CoreLib::Dictionary<CoreLib::String, CoreLib::String> defines,
+ ProgramSyntaxNode* syntax);
+
+}}
+
+#endif
diff --git a/source/slang/profile-defs.h b/source/slang/profile-defs.h
new file mode 100644
index 000000000..76ba476bb
--- /dev/null
+++ b/source/slang/profile-defs.h
@@ -0,0 +1,123 @@
+//
+
+// Define all the various language "profiles" we want to support.
+
+#ifndef LANGUAGE
+#define LANGUAGE(TAG, NAME) /* emptry */
+#endif
+
+#ifndef LANGUAGE_ALIAS
+#define LANGUAGE_ALIAS(TAG, NAME) /* empty */
+#endif
+
+#ifndef PROFILE_FAMILY
+#define PROFILE_FAMILY(TAG) /* empty */
+#endif
+
+#ifndef PROFILE_VERSION
+#define PROFILE_VERSION(TAG, FAMILY) /* empty */
+#endif
+
+#ifndef PROFILE_STAGE
+#define PROFILE_STAGE(TAG, NAME, VAL) /* empty */
+#endif
+
+#ifndef PROFILE_STAGE_ALIAS
+#define PROFILE_STAGE_ALIAS(TAG, NAME) /* empty */
+#endif
+
+
+#ifndef PROFILE
+#define PROFILE(TAG, NAME, STAGE, VERSION) /* empty */
+#endif
+
+#ifndef PROFILE_ALIAS
+#define PROFILE_ALIAS(TAG, NAME) /* empty */
+#endif
+
+// Source and destination languages
+
+LANGUAGE(HLSL, hlsl)
+LANGUAGE(DXBytecode, dxbc)
+LANGUAGE(DXBytecodeAssembly,dxbc_asm)
+LANGUAGE(DXIL, dxil)
+LANGUAGE(DXILAssembly, dxil_asm)
+LANGUAGE(GLSL, glsl)
+LANGUAGE(GLSL_ES, glsl_es)
+LANGUAGE(GLSL_VK, glsl_vk)
+LANGUAGE(SPIRV, spirv)
+LANGUAGE(SPIRV_GL, spirv_gl)
+
+LANGUAGE_ALIAS(GLSL, glsl_gl)
+LANGUAGE_ALIAS(SPIRV, spirv_vk)
+
+
+// Pipeline stages to target
+PROFILE_STAGE(Vertex, vertex, SLANG_STAGE_VERTEX)
+PROFILE_STAGE(Hull, hull, SLANG_STAGE_HULL)
+PROFILE_STAGE(Domain, domain, SLANG_STAGE_DOMAIN)
+PROFILE_STAGE(Geometry, geometry, SLANG_STAGE_GEOMETRY)
+PROFILE_STAGE(Fragment, fragment, SLANG_STAGE_FRAGMENT)
+PROFILE_STAGE(Compute, compute, SLANG_STAGE_COMPUTE)
+
+PROFILE_STAGE_ALIAS(Fragment, pixel)
+
+// Profile families
+
+PROFILE_FAMILY(DX)
+PROFILE_FAMILY(GLSL)
+PROFILE_FAMILY(SPRIV)
+
+// Profile versions
+
+
+PROFILE_VERSION(DX_4_0, DX)
+PROFILE_VERSION(DX_4_0_Level_9_0, DX)
+PROFILE_VERSION(DX_4_0_Level_9_1, DX)
+PROFILE_VERSION(DX_4_0_Level_9_3, DX)
+PROFILE_VERSION(DX_4_1, DX)
+PROFILE_VERSION(DX_5_0, DX)
+
+PROFILE_VERSION(GLSL, GLSL)
+
+
+// Specific profiles
+
+PROFILE(DX_Compute_4_0, cs_4_0, Compute, DX_4_0)
+PROFILE(DX_Compute_4_1, cs_4_1, Compute, DX_4_1)
+PROFILE(DX_Compute_5_0, cs_5_0, Compute, DX_5_0)
+PROFILE(DX_Domain_5_0, ds_5_0, Domain, DX_5_0)
+PROFILE(DX_Geometry_4_0, gs_4_0, Geometry, DX_4_0)
+PROFILE(DX_Geometry_4_1, gs_4_1, Geometry, DX_4_1)
+PROFILE(DX_Geometry_5_0, gs_5_0, Geometry, DX_5_0)
+PROFILE(DX_Hull_5_0, hs_5_0, Hull, DX_5_0)
+PROFILE(DX_Fragment_4_0, ps_4_0, Fragment, DX_4_0)
+PROFILE(DX_Fragment_4_0_Level_9_0, ps_4_0_level_9_0, Fragment, DX_4_0_Level_9_0)
+PROFILE(DX_Fragment_4_0_Level_9_1, ps_4_0_level_9_1, Fragment, DX_4_0_Level_9_1)
+PROFILE(DX_Fragment_4_0_Level_9_3, ps_4_0_level_9_3, Fragment, DX_4_0_Level_9_3)
+PROFILE(DX_Fragment_4_1, ps_4_1, Fragment, DX_4_1)
+PROFILE(DX_Fragment_5_0, ps_5_0, Fragment, DX_5_0)
+PROFILE(DX_Vertex_4_0, vs_4_0, Vertex, DX_4_0)
+PROFILE(DX_Vertex_4_0_Level_9_0, vs_4_0_level_9_0, Vertex, DX_4_0_Level_9_0)
+PROFILE(DX_Vertex_4_0_Level_9_1, vs_4_0_level_9_1, Vertex, DX_4_0_Level_9_1)
+PROFILE(DX_Vertex_4_0_Level_9_3, vs_4_0_level_9_3, Vertex, DX_4_0_Level_9_3)
+PROFILE(DX_Vertex_4_1, vs_4_1, Vertex, DX_4_1)
+PROFILE(DX_Vertex_5_0, vs_5_0, Vertex, DX_5_0)
+
+//
+
+PROFILE(GLSL_Compute, glsl_compute, Compute, GLSL)
+PROFILE(GLSL_Vertex, glsl_vertex, Vertex, GLSL)
+PROFILE(GLSL_Fragment, glsl_fragment, Fragment, GLSL)
+PROFILE(GLSL_Geometry, glsl_geometry, Geometry, GLSL)
+PROFILE(GLSL_TessControl, glsl_tess_control, Hull, GLSL)
+PROFILE(GLSL_TessEval, glsl_tess_eval, Domain, GLSL)
+
+#undef LANGUAGE
+#undef LANGUAGE_ALIAS
+#undef PROFILE_FAMILY
+#undef PROFILE_VERSION
+#undef PROFILE_STAGE
+#undef PROFILE_STAGE_ALIAS
+#undef PROFILE
+#undef PROFILE_ALIAS
diff --git a/source/slang/profile.cpp b/source/slang/profile.cpp
new file mode 100644
index 000000000..923dc2841
--- /dev/null
+++ b/source/slang/profile.cpp
@@ -0,0 +1,20 @@
+// profile.cpp
+#include "Profile.h"
+
+
+namespace Slang {
+namespace Compiler {
+
+
+ProfileFamily getProfileFamily(ProfileVersion version)
+{
+ switch( version )
+ {
+ default: return ProfileFamily::Unknown;
+
+#define PROFILE_VERSION(TAG, FAMILY) case ProfileVersion::TAG: return ProfileFamily::FAMILY;
+#include "profile-defs.h"
+ }
+}
+
+}}
diff --git a/source/slang/profile.h b/source/slang/profile.h
new file mode 100644
index 000000000..31465c38c
--- /dev/null
+++ b/source/slang/profile.h
@@ -0,0 +1,84 @@
+#ifndef SLANG_PROFILE_H_INCLUDED
+#define SLANG_PROFILE_H_INCLUDED
+
+#include "../core/basic.h"
+#include "../../slang.h"
+
+namespace Slang
+{
+ namespace Compiler
+ {
+ // Flavors of translation unit
+ enum class SourceLanguage : SlangSourceLanguage
+ {
+ Unknown = SLANG_SOURCE_LANGUAGE_UNKNOWN, // should not occur
+ Slang = SLANG_SOURCE_LANGUAGE_SLANG,
+ HLSL = SLANG_SOURCE_LANGUAGE_HLSL,
+ GLSL = SLANG_SOURCE_LANGUAGE_GLSL,
+
+ // A separate PACKAGE of Slang code that has been imported
+ ImportedSlangCode,
+ };
+
+ // TODO(tfoley): This should merge with the above...
+ enum class Language
+ {
+ Unknown,
+#define LANGUAGE(TAG, NAME) TAG,
+#include "profile-defs.h"
+ };
+
+ enum class ProfileFamily
+ {
+ Unknown,
+#define PROFILE_FAMILY(TAG) TAG,
+#include "profile-defs.h"
+ };
+
+ enum class ProfileVersion
+ {
+ Unknown,
+#define PROFILE_VERSION(TAG, FAMILY) TAG,
+#include "profile-defs.h"
+ };
+
+ enum class Stage : SlangStage
+ {
+ Unknown = SLANG_STAGE_NONE,
+#define PROFILE_STAGE(TAG, NAME, VAL) TAG = VAL,
+#include "profile-defs.h"
+ };
+
+ ProfileFamily getProfileFamily(ProfileVersion version);
+
+ struct Profile
+ {
+ typedef uint32_t RawVal;
+ enum : RawVal
+ {
+ Unknown,
+
+#define PROFILE(TAG, NAME, STAGE, VERSION) TAG = (uint32_t(Stage::STAGE) << 16) | uint32_t(ProfileVersion::VERSION),
+#include "profile-defs.h"
+ };
+
+ Profile() {}
+ Profile(RawVal raw)
+ : raw(raw)
+ {}
+
+ bool operator==(Profile const& other) const { return raw == other.raw; }
+ bool operator!=(Profile const& other) const { return raw != other.raw; }
+
+ Stage GetStage() const { return Stage((uint32_t(raw) >> 16) & 0xFFFF); }
+ ProfileVersion GetVersion() const { return ProfileVersion(uint32_t(raw) & 0xFFFF); }
+ ProfileFamily getFamily() const { return getProfileFamily(GetVersion()); }
+
+ static Profile LookUp(char const* name);
+
+ RawVal raw = Unknown;
+ };
+ }
+}
+
+#endif
diff --git a/source/slang/reflection.cpp b/source/slang/reflection.cpp
new file mode 100644
index 000000000..1a56a8c1f
--- /dev/null
+++ b/source/slang/reflection.cpp
@@ -0,0 +1,1404 @@
+// reflection.cpp
+#include "reflection.h"
+
+#include "compiler.h"
+#include "type-layout.h"
+
+#include <assert.h>
+
+// Implementation to back public-facing reflection API
+
+using namespace Slang;
+using namespace Slang::Compiler;
+
+
+// Conversion routines to help with strongly-typed reflection API
+
+static inline ExpressionType* convert(SlangReflectionType* type)
+{
+ return (ExpressionType*) type;
+}
+
+static inline SlangReflectionType* convert(ExpressionType* type)
+{
+ return (SlangReflectionType*) type;
+}
+
+static inline TypeLayout* convert(SlangReflectionTypeLayout* type)
+{
+ return (TypeLayout*) type;
+}
+
+static inline SlangReflectionTypeLayout* convert(TypeLayout* type)
+{
+ return (SlangReflectionTypeLayout*) type;
+}
+
+static inline VarDeclBase* convert(SlangReflectionVariable* var)
+{
+ return (VarDeclBase*) var;
+}
+
+static inline SlangReflectionVariable* convert(VarDeclBase* var)
+{
+ return (SlangReflectionVariable*) var;
+}
+
+static inline VarLayout* convert(SlangReflectionVariableLayout* var)
+{
+ return (VarLayout*) var;
+}
+
+static inline SlangReflectionVariableLayout* convert(VarLayout* var)
+{
+ return (SlangReflectionVariableLayout*) var;
+}
+
+static inline EntryPointLayout* convert(SlangReflectionEntryPoint* entryPoint)
+{
+ return (EntryPointLayout*) entryPoint;
+}
+
+static inline SlangReflectionEntryPoint* convert(EntryPointLayout* entryPoint)
+{
+ return (SlangReflectionEntryPoint*) entryPoint;
+}
+
+
+static inline ProgramLayout* convert(SlangReflection* program)
+{
+ return (ProgramLayout*) program;
+}
+
+static inline SlangReflection* convert(ProgramLayout* program)
+{
+ return (SlangReflection*) program;
+}
+
+// Type Reflection
+
+
+SLANG_API SlangTypeKind spReflectionType_GetKind(SlangReflectionType* inType)
+{
+ auto type = convert(inType);
+ if(!type) return SLANG_TYPE_KIND_NONE;
+
+ // TODO(tfoley: Don't emit the same type more than once...
+
+ if (auto basicType = type->As<BasicExpressionType>())
+ {
+ return SLANG_TYPE_KIND_SCALAR;
+ }
+ else if (auto vectorType = type->As<VectorExpressionType>())
+ {
+ return SLANG_TYPE_KIND_VECTOR;
+ }
+ else if (auto matrixType = type->As<MatrixExpressionType>())
+ {
+ return SLANG_TYPE_KIND_MATRIX;
+ }
+ else if (auto constantBufferType = type->As<ConstantBufferType>())
+ {
+ return SLANG_TYPE_KIND_CONSTANT_BUFFER;
+ }
+ else if (auto samplerStateType = type->As<SamplerStateType>())
+ {
+ return SLANG_TYPE_KIND_SAMPLER_STATE;
+ }
+ else if (auto textureType = type->As<TextureType>())
+ {
+ return SLANG_TYPE_KIND_RESOURCE;
+ }
+
+ // TODO: need a better way to handle this stuff...
+#define CASE(TYPE) \
+ else if(type->As<TYPE>()) do { \
+ return SLANG_TYPE_KIND_RESOURCE; \
+ } while(0)
+
+ CASE(HLSLBufferType);
+ CASE(HLSLRWBufferType);
+ CASE(HLSLBufferType);
+ CASE(HLSLRWBufferType);
+ CASE(HLSLStructuredBufferType);
+ CASE(HLSLRWStructuredBufferType);
+ CASE(HLSLAppendStructuredBufferType);
+ CASE(HLSLConsumeStructuredBufferType);
+ CASE(HLSLByteAddressBufferType);
+ CASE(HLSLRWByteAddressBufferType);
+ CASE(UntypedBufferResourceType);
+#undef CASE
+
+ else if (auto arrayType = type->As<ArrayExpressionType>())
+ {
+ return SLANG_TYPE_KIND_ARRAY;
+ }
+ else if( auto declRefType = type->As<DeclRefType>() )
+ {
+ auto declRef = declRefType->declRef;
+ if( auto structDeclRef = declRef.As<StructDeclRef>() )
+ {
+ return SLANG_TYPE_KIND_STRUCT;
+ }
+ }
+
+ assert(!"unexpected");
+ return SLANG_TYPE_KIND_NONE;
+}
+
+SLANG_API unsigned int spReflectionType_GetFieldCount(SlangReflectionType* inType)
+{
+ auto type = convert(inType);
+ if(!type) return 0;
+
+ // TODO: maybe filter based on kind
+
+ if(auto declRefType = dynamic_cast<DeclRefType*>(type))
+ {
+ auto declRef = declRefType->declRef;
+ if( auto structDeclRef = declRef.As<StructDeclRef>())
+ {
+ return structDeclRef.GetFields().Count();
+ }
+ }
+
+ return 0;
+}
+
+SLANG_API SlangReflectionVariable* spReflectionType_GetFieldByIndex(SlangReflectionType* inType, unsigned index)
+{
+ auto type = convert(inType);
+ if(!type) return nullptr;
+
+ // TODO: maybe filter based on kind
+
+ if(auto declRefType = dynamic_cast<DeclRefType*>(type))
+ {
+ auto declRef = declRefType->declRef;
+ if( auto structDeclRef = declRef.As<StructDeclRef>())
+ {
+ auto fieldDeclRef = structDeclRef.GetFields().ToArray()[index];
+ return (SlangReflectionVariable*) fieldDeclRef.GetDecl();
+ }
+ }
+
+ return nullptr;
+}
+
+SLANG_API size_t spReflectionType_GetElementCount(SlangReflectionType* inType)
+{
+ auto type = convert(inType);
+ if(!type) return 0;
+
+ if(auto arrayType = dynamic_cast<ArrayExpressionType*>(type))
+ {
+ return GetIntVal(arrayType->ArrayLength);
+ }
+ else if( auto vectorType = dynamic_cast<VectorExpressionType*>(type))
+ {
+ return GetIntVal(vectorType->elementCount);
+ }
+
+ return 0;
+}
+
+SLANG_API SlangReflectionType* spReflectionType_GetElementType(SlangReflectionType* inType)
+{
+ auto type = convert(inType);
+ if(!type) return nullptr;
+
+ if(auto arrayType = dynamic_cast<ArrayExpressionType*>(type))
+ {
+ return (SlangReflectionType*) arrayType->BaseType.Ptr();
+ }
+ else if( auto constantBufferType = dynamic_cast<ConstantBufferType*>(type))
+ {
+ return convert(constantBufferType->elementType.Ptr());
+ }
+ else if( auto vectorType = dynamic_cast<VectorExpressionType*>(type))
+ {
+ return convert(vectorType->elementType.Ptr());
+ }
+ else if( auto matrixType = dynamic_cast<MatrixExpressionType*>(type))
+ {
+ return convert(matrixType->getElementType());
+ }
+
+ return nullptr;
+}
+
+SLANG_API unsigned int spReflectionType_GetRowCount(SlangReflectionType* inType)
+{
+ auto type = convert(inType);
+ if(!type) return 0;
+
+ if(auto matrixType = dynamic_cast<MatrixExpressionType*>(type))
+ {
+ return GetIntVal(matrixType->getRowCount());
+ }
+ else if(auto vectorType = dynamic_cast<VectorExpressionType*>(type))
+ {
+ return 1;
+ }
+ else if( auto basicType = dynamic_cast<BasicExpressionType*>(type) )
+ {
+ return 1;
+ }
+
+ return 0;
+}
+
+SLANG_API unsigned int spReflectionType_GetColumnCount(SlangReflectionType* inType)
+{
+ auto type = convert(inType);
+ if(!type) return 0;
+
+ if(auto matrixType = dynamic_cast<MatrixExpressionType*>(type))
+ {
+ return GetIntVal(matrixType->getColumnCount());
+ }
+ else if(auto vectorType = dynamic_cast<VectorExpressionType*>(type))
+ {
+ return GetIntVal(vectorType->elementCount);
+ }
+ else if( auto basicType = dynamic_cast<BasicExpressionType*>(type) )
+ {
+ return 1;
+ }
+
+ return 0;
+}
+
+SLANG_API SlangScalarType spReflectionType_GetScalarType(SlangReflectionType* inType)
+{
+ auto type = convert(inType);
+ if(!type) return 0;
+
+ if(auto matrixType = dynamic_cast<MatrixExpressionType*>(type))
+ {
+ type = matrixType->getElementType();
+ }
+ else if(auto vectorType = dynamic_cast<VectorExpressionType*>(type))
+ {
+ type = vectorType->elementType.Ptr();
+ }
+
+ if(auto basicType = dynamic_cast<BasicExpressionType*>(type))
+ {
+ switch (basicType->BaseType)
+ {
+#define CASE(BASE, TAG) \
+ case BaseType::BASE: return SLANG_SCALAR_TYPE_##TAG
+
+ CASE(Void, VOID);
+ CASE(Int, INT32);
+ CASE(Float, FLOAT32);
+ CASE(UInt, UINT32);
+ CASE(Bool, BOOL);
+ CASE(UInt64, UINT64);
+
+#undef CASE
+
+ default:
+ assert(!"unexpected");
+ return SLANG_SCALAR_TYPE_NONE;
+ break;
+ }
+ }
+
+ return SLANG_SCALAR_TYPE_NONE;
+}
+
+SLANG_API SlangResourceShape spReflectionType_GetResourceShape(SlangReflectionType* inType)
+{
+ auto type = convert(inType);
+ if(!type) return 0;
+
+ while(auto arrayType = type->As<ArrayExpressionType>())
+ {
+ type = arrayType->BaseType.Ptr();
+ }
+
+ if(auto textureType = type->As<TextureType>())
+ {
+ return textureType->getShape();
+ }
+
+ // TODO: need a better way to handle this stuff...
+#define CASE(TYPE, SHAPE, ACCESS) \
+ else if(type->As<TYPE>()) do { \
+ return SHAPE; \
+ } while(0)
+
+ CASE(HLSLBufferType, SLANG_TEXTURE_BUFFER, SLANG_RESOURCE_ACCESS_READ);
+ CASE(HLSLRWBufferType, SLANG_TEXTURE_BUFFER, SLANG_RESOURCE_ACCESS_READ_WRITE);
+ CASE(HLSLBufferType, SLANG_TEXTURE_BUFFER, SLANG_RESOURCE_ACCESS_READ);
+ CASE(HLSLRWBufferType, SLANG_TEXTURE_BUFFER, SLANG_RESOURCE_ACCESS_READ_WRITE);
+ CASE(HLSLStructuredBufferType, SLANG_STRUCTURED_BUFFER, SLANG_RESOURCE_ACCESS_READ);
+ CASE(HLSLRWStructuredBufferType, SLANG_STRUCTURED_BUFFER, SLANG_RESOURCE_ACCESS_READ_WRITE);
+ CASE(HLSLAppendStructuredBufferType, SLANG_STRUCTURED_BUFFER, SLANG_RESOURCE_ACCESS_APPEND);
+ CASE(HLSLConsumeStructuredBufferType, SLANG_STRUCTURED_BUFFER, SLANG_RESOURCE_ACCESS_CONSUME);
+ CASE(HLSLByteAddressBufferType, SLANG_BYTE_ADDRESS_BUFFER, SLANG_RESOURCE_ACCESS_READ);
+ CASE(HLSLRWByteAddressBufferType, SLANG_BYTE_ADDRESS_BUFFER, SLANG_RESOURCE_ACCESS_READ_WRITE);
+ CASE(UntypedBufferResourceType, SLANG_BYTE_ADDRESS_BUFFER, SLANG_RESOURCE_ACCESS_READ);
+#undef CASE
+
+ return SLANG_RESOURCE_NONE;
+}
+
+SLANG_API SlangResourceAccess spReflectionType_GetResourceAccess(SlangReflectionType* inType)
+{
+ auto type = convert(inType);
+ if(!type) return 0;
+
+ while(auto arrayType = type->As<ArrayExpressionType>())
+ {
+ type = arrayType->BaseType.Ptr();
+ }
+
+ if(auto textureType = type->As<TextureType>())
+ {
+ return textureType->getAccess();
+ }
+
+ // TODO: need a better way to handle this stuff...
+#define CASE(TYPE, SHAPE, ACCESS) \
+ else if(type->As<TYPE>()) do { \
+ return ACCESS; \
+ } while(0)
+
+ CASE(HLSLBufferType, SLANG_TEXTURE_BUFFER, SLANG_RESOURCE_ACCESS_READ);
+ CASE(HLSLRWBufferType, SLANG_TEXTURE_BUFFER, SLANG_RESOURCE_ACCESS_READ_WRITE);
+ CASE(HLSLBufferType, SLANG_TEXTURE_BUFFER, SLANG_RESOURCE_ACCESS_READ);
+ CASE(HLSLRWBufferType, SLANG_TEXTURE_BUFFER, SLANG_RESOURCE_ACCESS_READ_WRITE);
+ CASE(HLSLStructuredBufferType, SLANG_STRUCTURED_BUFFER, SLANG_RESOURCE_ACCESS_READ);
+ CASE(HLSLRWStructuredBufferType, SLANG_STRUCTURED_BUFFER, SLANG_RESOURCE_ACCESS_READ_WRITE);
+ CASE(HLSLAppendStructuredBufferType, SLANG_STRUCTURED_BUFFER, SLANG_RESOURCE_ACCESS_APPEND);
+ CASE(HLSLConsumeStructuredBufferType, SLANG_STRUCTURED_BUFFER, SLANG_RESOURCE_ACCESS_CONSUME);
+ CASE(HLSLByteAddressBufferType, SLANG_BYTE_ADDRESS_BUFFER, SLANG_RESOURCE_ACCESS_READ);
+ CASE(HLSLRWByteAddressBufferType, SLANG_BYTE_ADDRESS_BUFFER, SLANG_RESOURCE_ACCESS_READ_WRITE);
+ CASE(UntypedBufferResourceType, SLANG_BYTE_ADDRESS_BUFFER, SLANG_RESOURCE_ACCESS_READ);
+#undef CASE
+
+ return SLANG_RESOURCE_ACCESS_NONE;
+}
+
+SLANG_API SlangReflectionType* spReflectionType_GetResourceResultType(SlangReflectionType* inType)
+{
+ auto type = convert(inType);
+ if(!type) return nullptr;
+
+ while(auto arrayType = type->As<ArrayExpressionType>())
+ {
+ type = arrayType->BaseType.Ptr();
+ }
+
+ if (auto textureType = type->As<TextureType>())
+ {
+ return convert(textureType->elementType.Ptr());
+ }
+
+ // TODO: need a better way to handle this stuff...
+#define CASE(TYPE, SHAPE, ACCESS) \
+ else if(type->As<TYPE>()) do { \
+ return convert(type->As<TYPE>()->elementType.Ptr()); \
+ } while(0)
+
+ CASE(HLSLBufferType, SLANG_TEXTURE_BUFFER, SLANG_RESOURCE_ACCESS_READ);
+ CASE(HLSLRWBufferType, SLANG_TEXTURE_BUFFER, SLANG_RESOURCE_ACCESS_READ_WRITE);
+ CASE(HLSLBufferType, SLANG_TEXTURE_BUFFER, SLANG_RESOURCE_ACCESS_READ);
+ CASE(HLSLRWBufferType, SLANG_TEXTURE_BUFFER, SLANG_RESOURCE_ACCESS_READ_WRITE);
+
+ // TODO: structured buffer needs to expose type layout!
+
+ CASE(HLSLStructuredBufferType, SLANG_STRUCTURED_BUFFER, SLANG_RESOURCE_ACCESS_READ);
+ CASE(HLSLRWStructuredBufferType, SLANG_STRUCTURED_BUFFER, SLANG_RESOURCE_ACCESS_READ_WRITE);
+ CASE(HLSLAppendStructuredBufferType, SLANG_STRUCTURED_BUFFER, SLANG_RESOURCE_ACCESS_APPEND);
+ CASE(HLSLConsumeStructuredBufferType, SLANG_STRUCTURED_BUFFER, SLANG_RESOURCE_ACCESS_CONSUME);
+#undef CASE
+
+ return nullptr;
+}
+
+// Type Layout Reflection
+
+SLANG_API SlangReflectionType* spReflectionTypeLayout_GetType(SlangReflectionTypeLayout* inTypeLayout)
+{
+ auto typeLayout = convert(inTypeLayout);
+ if(!typeLayout) return nullptr;
+
+ return (SlangReflectionType*) typeLayout->type.Ptr();
+}
+
+SLANG_API size_t spReflectionTypeLayout_GetSize(SlangReflectionTypeLayout* inTypeLayout, SlangParameterCategory category)
+{
+ auto typeLayout = convert(inTypeLayout);
+ if(!typeLayout) return 0;
+
+ auto info = typeLayout->FindResourceInfo(LayoutResourceKind(category));
+ if(!info) return 0;
+
+ return info->count;
+}
+
+SLANG_API SlangReflectionVariableLayout* spReflectionTypeLayout_GetFieldByIndex(SlangReflectionTypeLayout* inTypeLayout, unsigned index)
+{
+ auto typeLayout = convert(inTypeLayout);
+ if(!typeLayout) return nullptr;
+
+ if(auto structTypeLayout = dynamic_cast<StructTypeLayout*>(typeLayout))
+ {
+ return (SlangReflectionVariableLayout*) structTypeLayout->fields[index].Ptr();
+ }
+
+ return nullptr;
+}
+
+SLANG_API size_t spReflectionTypeLayout_GetElementStride(SlangReflectionTypeLayout* inTypeLayout, SlangParameterCategory category)
+{
+ auto typeLayout = convert(inTypeLayout);
+ if(!typeLayout) return 0;
+
+ if( auto arrayTypeLayout = dynamic_cast<ArrayTypeLayout*>(typeLayout))
+ {
+ if(category == SLANG_PARAMETER_CATEGORY_UNIFORM)
+ {
+ return arrayTypeLayout->uniformStride;
+ }
+ else
+ {
+ auto elementTypeLayout = arrayTypeLayout->elementTypeLayout;
+ auto info = elementTypeLayout->FindResourceInfo(LayoutResourceKind(category));
+ if(!info) return 0;
+ return info->count;
+ }
+ }
+
+ return 0;
+}
+
+SLANG_API SlangReflectionTypeLayout* spReflectionTypeLayout_GetElementTypeLayout(SlangReflectionTypeLayout* inTypeLayout)
+{
+ auto typeLayout = convert(inTypeLayout);
+ if(!typeLayout) return nullptr;
+
+ if( auto arrayTypeLayout = dynamic_cast<ArrayTypeLayout*>(typeLayout))
+ {
+ return (SlangReflectionTypeLayout*) arrayTypeLayout->elementTypeLayout.Ptr();
+ }
+ else if( auto constantBufferTypeLayout = dynamic_cast<ParameterBlockTypeLayout*>(typeLayout))
+ {
+ return convert(constantBufferTypeLayout->elementTypeLayout.Ptr());
+ }
+ else if( auto structuredBufferTypeLayout = dynamic_cast<StructuredBufferTypeLayout*>(typeLayout))
+ {
+ return convert(structuredBufferTypeLayout->elementTypeLayout.Ptr());
+ }
+
+ return nullptr;
+}
+
+static SlangParameterCategory getParameterCategory(
+ LayoutResourceKind kind)
+{
+ return SlangParameterCategory(kind);
+}
+
+static SlangParameterCategory getParameterCategory(
+ TypeLayout* typeLayout)
+{
+ auto resourceInfoCount = typeLayout->resourceInfos.Count();
+ if(resourceInfoCount == 1)
+ {
+ return getParameterCategory(typeLayout->resourceInfos[0].kind);
+ }
+ else if(resourceInfoCount == 0)
+ {
+ // TODO: can this ever happen?
+ return SLANG_PARAMETER_CATEGORY_NONE;
+ }
+ return SLANG_PARAMETER_CATEGORY_MIXED;
+}
+
+SLANG_API SlangParameterCategory spReflectionTypeLayout_GetParameterCategory(SlangReflectionTypeLayout* inTypeLayout)
+{
+ auto typeLayout = convert(inTypeLayout);
+ if(!typeLayout) return SLANG_PARAMETER_CATEGORY_NONE;
+
+ return getParameterCategory(typeLayout);
+}
+
+SLANG_API unsigned spReflectionTypeLayout_GetCategoryCount(SlangReflectionTypeLayout* inTypeLayout)
+{
+ auto typeLayout = convert(inTypeLayout);
+ if(!typeLayout) return 0;
+
+ return (unsigned) typeLayout->resourceInfos.Count();
+}
+
+SLANG_API SlangParameterCategory spReflectionTypeLayout_GetCategoryByIndex(SlangReflectionTypeLayout* inTypeLayout, unsigned index)
+{
+ auto typeLayout = convert(inTypeLayout);
+ if(!typeLayout) return SLANG_PARAMETER_CATEGORY_NONE;
+
+ return typeLayout->resourceInfos[index].kind;
+}
+
+// Variable Reflection
+
+SLANG_API char const* spReflectionVariable_GetName(SlangReflectionVariable* inVar)
+{
+ auto var = convert(inVar);
+ if(!var) return nullptr;
+
+ // If the variable is one that has an "external" name that is supposed
+ // to be exposed for reflection, then report it here
+ if(auto reflectionNameMod = var->FindModifier<ParameterBlockReflectionName>())
+ return reflectionNameMod->nameToken.Content.Buffer();
+
+ return var->getName().Buffer();
+}
+
+SLANG_API SlangReflectionType* spReflectionVariable_GetType(SlangReflectionVariable* inVar)
+{
+ auto var = convert(inVar);
+ if(!var) return nullptr;
+
+ return convert(var->getType());
+}
+
+// Variable Layout Reflection
+
+SLANG_API SlangReflectionVariable* spReflectionVariableLayout_GetVariable(SlangReflectionVariableLayout* inVarLayout)
+{
+ auto varLayout = convert(inVarLayout);
+ if(!varLayout) return nullptr;
+
+ return convert(varLayout->varDecl.GetDecl());
+}
+
+SLANG_API SlangReflectionTypeLayout* spReflectionVariableLayout_GetTypeLayout(SlangReflectionVariableLayout* inVarLayout)
+{
+ auto varLayout = convert(inVarLayout);
+ if(!varLayout) return nullptr;
+
+ return convert(varLayout->getTypeLayout());
+}
+
+SLANG_API size_t spReflectionVariableLayout_GetOffset(SlangReflectionVariableLayout* inVarLayout, SlangParameterCategory category)
+{
+ auto varLayout = convert(inVarLayout);
+ if(!varLayout) return 0;
+
+ auto info = varLayout->FindResourceInfo(LayoutResourceKind(category));
+ if(!info) return 0;
+
+ return info->index;
+}
+
+SLANG_API size_t spReflectionVariableLayout_GetSpace(SlangReflectionVariableLayout* inVarLayout, SlangParameterCategory category)
+{
+ auto varLayout = convert(inVarLayout);
+ if(!varLayout) return 0;
+
+ auto info = varLayout->FindResourceInfo(LayoutResourceKind(category));
+ if(!info) return 0;
+
+ return info->space;
+}
+
+
+// Shader Parameter Reflection
+
+SLANG_API unsigned spReflectionParameter_GetBindingIndex(SlangReflectionParameter* inVarLayout)
+{
+ auto varLayout = convert(inVarLayout);
+ if(!varLayout) return 0;
+
+ if(varLayout->resourceInfos.Count() > 0)
+ {
+ return (unsigned) varLayout->resourceInfos[0].index;
+ }
+
+ return 0;
+}
+
+SLANG_API unsigned spReflectionParameter_GetBindingSpace(SlangReflectionParameter* inVarLayout)
+{
+ auto varLayout = convert(inVarLayout);
+ if(!varLayout) return 0;
+
+ if(varLayout->resourceInfos.Count() > 0)
+ {
+ return (unsigned) varLayout->resourceInfos[0].space;
+ }
+
+ return 0;
+}
+
+// Entry Point Reflection
+
+SLANG_API SlangStage spReflectionEntryPoint_getStage(SlangReflectionEntryPoint* inEntryPoint)
+{
+ auto entryPointLayout = convert(inEntryPoint);
+
+ if(!entryPointLayout) return SLANG_STAGE_NONE;
+
+ return SlangStage(entryPointLayout->profile.GetStage());
+}
+
+SLANG_API void spReflectionEntryPoint_getComputeThreadGroupSize(
+ SlangReflectionEntryPoint* inEntryPoint,
+ SlangUInt axisCount,
+ SlangUInt* outSizeAlongAxis)
+{
+ auto entryPointLayout = convert(inEntryPoint);
+
+ if(!entryPointLayout) return;
+ if(!axisCount) return;
+ if(!outSizeAlongAxis) return;
+
+ auto entryPointFunc = entryPointLayout->entryPoint;
+ if(!entryPointFunc) return;
+
+ auto numThreadsAttribute = entryPointFunc->FindModifier<HLSLNumThreadsAttribute>();
+ if(!numThreadsAttribute) return;
+
+ if(axisCount > 0) outSizeAlongAxis[0] = numThreadsAttribute->x;
+ if(axisCount > 1) outSizeAlongAxis[1] = numThreadsAttribute->y;
+ if(axisCount > 2) outSizeAlongAxis[2] = numThreadsAttribute->z;
+ for( SlangUInt aa = 3; aa < axisCount; ++aa )
+ {
+ outSizeAlongAxis[aa] = 1;
+ }
+}
+
+
+// Shader Reflection
+
+SLANG_API unsigned spReflection_GetParameterCount(SlangReflection* inProgram)
+{
+ auto program = convert(inProgram);
+ if(!program) return 0;
+
+ auto globalLayout = program->globalScopeLayout;
+ if(auto globalConstantBufferLayout = globalLayout.As<ParameterBlockTypeLayout>())
+ {
+ globalLayout = globalConstantBufferLayout->elementTypeLayout;
+ }
+
+ if(auto globalStructLayout = globalLayout.As<StructTypeLayout>())
+ {
+ return globalStructLayout->fields.Count();
+ }
+
+ return 0;
+}
+
+SLANG_API SlangReflectionParameter* spReflection_GetParameterByIndex(SlangReflection* inProgram, unsigned index)
+{
+ auto program = convert(inProgram);
+ if(!program) return nullptr;
+
+ auto globalLayout = program->globalScopeLayout;
+ if(auto globalConstantBufferLayout = globalLayout.As<ParameterBlockTypeLayout>())
+ {
+ globalLayout = globalConstantBufferLayout->elementTypeLayout;
+ }
+
+ if(auto globalStructLayout = globalLayout.As<StructTypeLayout>())
+ {
+ return convert(globalStructLayout->fields[index].Ptr());
+ }
+
+ return nullptr;
+}
+
+SLANG_API SlangUInt spReflection_getEntryPointCount(SlangReflection* inProgram)
+{
+ auto program = convert(inProgram);
+ if(!program) return 0;
+
+ return SlangUInt(program->entryPoints.Count());
+}
+
+SLANG_API SlangReflectionEntryPoint* spReflection_getEntryPointByIndex(SlangReflection* inProgram, SlangUInt index)
+{
+ auto program = convert(inProgram);
+ if(!program) return 0;
+
+ return convert(program->entryPoints[(int) index].Ptr());
+}
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+namespace Slang {
+namespace Compiler {
+
+
+
+
+
+
+
+// Debug helper code: dump reflection data after generation
+
+struct PrettyWriter
+{
+ StringBuilder sb;
+ bool startOfLine = true;
+ int indent = 0;
+};
+
+static void adjust(PrettyWriter& writer)
+{
+ if (!writer.startOfLine)
+ return;
+
+ int indent = writer.indent;
+ for (int ii = 0; ii < indent; ++ii)
+ writer.sb << " ";
+
+ writer.startOfLine = false;
+}
+
+static void indent(PrettyWriter& writer)
+{
+ writer.indent++;
+}
+
+static void dedent(PrettyWriter& writer)
+{
+ writer.indent--;
+}
+
+static void write(PrettyWriter& writer, char const* text)
+{
+ // TODO: can do this more efficiently...
+ char const* cursor = text;
+ for(;;)
+ {
+ char c = *cursor++;
+ if (!c) break;
+
+ if (c == '\n')
+ {
+ writer.startOfLine = true;
+ }
+ else
+ {
+ adjust(writer);
+ }
+
+ writer.sb << c;
+ }
+}
+
+static void write(PrettyWriter& writer, UInt val)
+{
+ adjust(writer);
+ writer.sb << ((unsigned int) val);
+}
+
+static void emitReflectionVarInfoJSON(PrettyWriter& writer, slang::VariableReflection* var);
+static void emitReflectionTypeLayoutJSON(PrettyWriter& writer, slang::TypeLayoutReflection* type);
+static void emitReflectionTypeJSON(PrettyWriter& writer, slang::TypeReflection* type);
+
+static void emitReflectionVarBindingInfoJSON(
+ PrettyWriter& writer,
+ SlangParameterCategory category,
+ UInt index,
+ UInt count,
+ UInt space = 0)
+{
+ if( category == SLANG_PARAMETER_CATEGORY_UNIFORM )
+ {
+ write(writer,"\"kind\": \"uniform\"");
+ write(writer, ", ");
+ write(writer,"\"offset\": ");
+ write(writer, index);
+ write(writer, ", ");
+ write(writer, "\"size\": ");
+ write(writer, count);
+ }
+ else
+ {
+ write(writer, "\"kind\": \"");
+ switch( category )
+ {
+ #define CASE(NAME, KIND) case SLANG_PARAMETER_CATEGORY_##NAME: write(writer, #KIND); break
+ CASE(CONSTANT_BUFFER, constantBuffer);
+ CASE(SHADER_RESOURCE, shaderResource);
+ CASE(UNORDERED_ACCESS, unorderedAccess);
+ CASE(VERTEX_INPUT, vertexInput);
+ CASE(FRAGMENT_OUTPUT, fragmentOutput);
+ CASE(SAMPLER_STATE, samplerState);
+ #undef CASE
+
+ default:
+ write(writer, "unknown");
+ assert(!"unexpected");
+ break;
+ }
+ write(writer, "\"");
+ if( space )
+ {
+ write(writer, ", ");
+ write(writer, "\"space\": ");
+ write(writer, space);
+ }
+ write(writer, ", ");
+ write(writer, "\"index\": ");
+ write(writer, index);
+ if( count != 1)
+ {
+ write(writer, ", ");
+ write(writer, "\"count\": ");
+ write(writer, count);
+ }
+ }
+}
+
+static void emitReflectionVarBindingInfoJSON(
+ PrettyWriter& writer,
+ slang::VariableLayoutReflection* var)
+{
+ auto typeLayout = var->getTypeLayout();
+ auto categoryCount = var->getCategoryCount();
+
+ if( categoryCount != 1 )
+ {
+ write(writer,"\"bindings\": [\n");
+ }
+ else
+ {
+ write(writer,"\"binding\": ");
+ }
+ indent(writer);
+
+ for(uint32_t cc = 0; cc < categoryCount; ++cc )
+ {
+ auto category = var->getCategoryByIndex(cc);
+ auto index = var->getOffset(category);
+ auto space = var->getBindingSpace(category);
+ auto count = typeLayout->getSize(category);
+
+ if (cc != 0) write(writer, ",\n");
+
+ write(writer,"{");
+ emitReflectionVarBindingInfoJSON(
+ writer,
+ category,
+ index,
+ count,
+ space);
+ write(writer,"}");
+ }
+
+ dedent(writer);
+ if( categoryCount != 1 )
+ {
+ write(writer,"\n]");
+ }
+}
+
+static void emitReflectionNameInfoJSON(
+ PrettyWriter& writer,
+ char const* name)
+{
+ // TODO: deal with escaping special characters if/when needed
+ write(writer, "\"name\": \"");
+ write(writer, name);
+ write(writer, "\"");
+}
+
+static void emitReflectionVarLayoutJSON(
+ PrettyWriter& writer,
+ slang::VariableLayoutReflection* var)
+{
+ write(writer, "{\n");
+ indent(writer);
+
+ emitReflectionNameInfoJSON(writer, var->getName());
+ write(writer, ",\n");
+
+ write(writer, "\"type\": ");
+ emitReflectionTypeLayoutJSON(writer, var->getTypeLayout());
+ write(writer, ",\n");
+
+ emitReflectionVarBindingInfoJSON(writer, var);
+
+ dedent(writer);
+ write(writer, "\n}");
+}
+
+static void emitReflectionScalarTypeInfoJSON(
+ PrettyWriter& writer,
+ SlangScalarType scalarType)
+{
+ write(writer, "\"scalarType\": \"");
+ switch (scalarType)
+ {
+ default:
+ write(writer, "unknown");
+ assert(!"unexpected");
+ break;
+#define CASE(TAG, ID) case slang::TypeReflection::ScalarType::TAG: write(writer, #ID); break
+ CASE(Void, void);
+ CASE(Bool, bool);
+ CASE(Int32, int32);
+ CASE(UInt32, uint32);
+ CASE(Int64, int64);
+ CASE(UInt64, uint64);
+ CASE(Float16, float16);
+ CASE(Float32, float32);
+ CASE(Float64, float64);
+#undef CASE
+ }
+ write(writer, "\"");
+}
+
+static void emitReflectionTypeInfoJSON(
+ PrettyWriter& writer,
+ slang::TypeReflection* type)
+{
+ switch( type->getKind() )
+ {
+ case SLANG_TYPE_KIND_SAMPLER_STATE:
+ write(writer, "\"kind\": \"samplerState\"");
+ break;
+
+ case SLANG_TYPE_KIND_RESOURCE:
+ {
+ auto shape = type->getResourceShape();
+ auto access = type->getResourceAccess();
+ write(writer, "\"kind\": \"resource\"");
+ write(writer, ",\n");
+ write(writer, "\"baseShape\": \"");
+ switch (shape & SLANG_RESOURCE_BASE_SHAPE_MASK)
+ {
+ default:
+ write(writer, "unknown");
+ assert(!"unexpected");
+ break;
+
+#define CASE(SHAPE, NAME) case SLANG_##SHAPE: write(writer, #NAME); break
+ CASE(TEXTURE_1D, texture1D);
+ CASE(TEXTURE_2D, texture2D);
+ CASE(TEXTURE_3D, texture3D);
+ CASE(TEXTURE_CUBE, textureCube);
+ CASE(TEXTURE_BUFFER, textureBuffer);
+ CASE(STRUCTURED_BUFFER, structuredBuffer);
+ CASE(BYTE_ADDRESS_BUFFER, byteAddressBuffer);
+#undef CASE
+ }
+ write(writer, "\"");
+ if (shape & SLANG_TEXTURE_ARRAY_FLAG)
+ {
+ write(writer, ",\n");
+ write(writer, "\"array\": true");
+ }
+ if (shape & SLANG_TEXTURE_MULTISAMPLE_FLAG)
+ {
+ write(writer, ",\n");
+ write(writer, "\"multisample\": true");
+ }
+
+ if( access != SLANG_RESOURCE_ACCESS_READ )
+ {
+ write(writer, ",\n\"access\": \"");
+ switch(access)
+ {
+ default:
+ write(writer, "unknown");
+ assert(!"unexpected");
+ break;
+
+ case SLANG_RESOURCE_ACCESS_READ:
+ break;
+
+ case SLANG_RESOURCE_ACCESS_READ_WRITE: write(writer, "readWrite"); break;
+ case SLANG_RESOURCE_ACCESS_RASTER_ORDERED: write(writer, "rasterOrdered"); break;
+ case SLANG_RESOURCE_ACCESS_APPEND: write(writer, "append"); break;
+ case SLANG_RESOURCE_ACCESS_CONSUME: write(writer, "consume"); break;
+ }
+ write(writer, "\"");
+ }
+ }
+ break;
+
+ case SLANG_TYPE_KIND_CONSTANT_BUFFER:
+ write(writer, "\"kind\": \"constantBuffer\"");
+ write(writer, ",\n");
+ write(writer, "\"elementType\": ");
+ emitReflectionTypeJSON(
+ writer,
+ type->getElementType());
+ break;
+
+ case SLANG_TYPE_KIND_SCALAR:
+ write(writer, "\"kind\": \"scalar\"");
+ write(writer, ",\n");
+ emitReflectionScalarTypeInfoJSON(
+ writer,
+ type->getScalarType());
+ break;
+
+ case SLANG_TYPE_KIND_VECTOR:
+ write(writer, "\"kind\": \"vector\"");
+ write(writer, ",\n");
+ write(writer, "\"elementCount\": ");
+ write(writer, type->getElementCount());
+ write(writer, ",\n");
+ write(writer, "\"elementType\": ");
+ emitReflectionTypeJSON(
+ writer,
+ type->getElementType());
+ break;
+
+ case SLANG_TYPE_KIND_MATRIX:
+ write(writer, "\"kind\": \"matrix\"");
+ write(writer, ",\n");
+ write(writer, "\"rowCount\": ");
+ write(writer, type->getRowCount());
+ write(writer, ",\n");
+ write(writer, "\"columnCount\": ");
+ write(writer, type->getColumnCount());
+ write(writer, ",\n");
+ write(writer, "\"elementType\": ");
+ emitReflectionTypeJSON(
+ writer,
+ type->getElementType());
+ break;
+
+ case SLANG_TYPE_KIND_ARRAY:
+ {
+ auto arrayType = type;
+ write(writer, "\"kind\": \"array\"");
+ write(writer, ",\n");
+ write(writer, "\"elementCount\": ");
+ write(writer, arrayType->getElementCount());
+ write(writer, ",\n");
+ write(writer, "\"elementType\": ");
+ emitReflectionTypeJSON(writer, arrayType->getElementType());
+ }
+ break;
+
+ case SLANG_TYPE_KIND_STRUCT:
+ {
+ write(writer, "\"kind\": \"struct\",\n");
+ write(writer, "\"fields\": [\n");
+ indent(writer);
+
+ auto structType = type;
+ auto fieldCount = structType->getFieldCount();
+ for( uint32_t ff = 0; ff < fieldCount; ++ff )
+ {
+ if (ff != 0) write(writer, ",\n");
+ emitReflectionVarInfoJSON(
+ writer,
+ structType->getFieldByIndex(ff));
+ }
+ dedent(writer);
+ write(writer, "\n]");
+ }
+ break;
+
+ default:
+ assert(!"unimplemented");
+ break;
+ }
+}
+
+static void emitReflectionTypeLayoutInfoJSON(
+ PrettyWriter& writer,
+ slang::TypeLayoutReflection* typeLayout)
+{
+ switch( typeLayout->getKind() )
+ {
+ default:
+ emitReflectionTypeInfoJSON(writer, typeLayout->getType());
+ break;
+
+ case SLANG_TYPE_KIND_ARRAY:
+ {
+ auto arrayTypeLayout = typeLayout;
+ auto elementTypeLayout = arrayTypeLayout->getElementTypeLayout();
+ write(writer, "\"kind\": \"array\"");
+ write(writer, ",\n");
+ write(writer, "\"elementCount\": ");
+ write(writer, arrayTypeLayout->getElementCount());
+ write(writer, ",\n");
+ write(writer, "\"elementType\": ");
+ emitReflectionTypeLayoutJSON(
+ writer,
+ elementTypeLayout);
+ if (arrayTypeLayout->getSize(SLANG_PARAMETER_CATEGORY_UNIFORM) != 0)
+ {
+ write(writer, ",\n");
+ write(writer, "\"uniformStride\": ");
+ write(writer, arrayTypeLayout->getElementStride(SLANG_PARAMETER_CATEGORY_UNIFORM));
+ }
+ }
+ break;
+
+ case SLANG_TYPE_KIND_STRUCT:
+ {
+ write(writer, "\"kind\": \"struct\",\n");
+ write(writer, "\"fields\": [\n");
+ indent(writer);
+
+ auto structTypeLayout = typeLayout;
+ auto fieldCount = structTypeLayout->getFieldCount();
+ for( uint32_t ff = 0; ff < fieldCount; ++ff )
+ {
+ if (ff != 0) write(writer, ",\n");
+ emitReflectionVarLayoutJSON(
+ writer,
+ structTypeLayout->getFieldByIndex(ff));
+ }
+ dedent(writer);
+ write(writer, "\n]");
+ }
+ break;
+
+ case SLANG_TYPE_KIND_CONSTANT_BUFFER:
+ write(writer, "\"kind\": \"constantBuffer\"");
+ write(writer, ",\n");
+ write(writer, "\"elementType\": ");
+ emitReflectionTypeLayoutJSON(
+ writer,
+ typeLayout->getElementTypeLayout());
+ break;
+
+ }
+
+ // TODO: emit size info for types
+}
+
+static void emitReflectionTypeLayoutJSON(
+ PrettyWriter& writer,
+ slang::TypeLayoutReflection* typeLayout)
+{
+ write(writer, "{\n");
+ indent(writer);
+ emitReflectionTypeLayoutInfoJSON(writer, typeLayout);
+ dedent(writer);
+ write(writer, "\n}");
+}
+
+static void emitReflectionTypeJSON(
+ PrettyWriter& writer,
+ slang::TypeReflection* type)
+{
+ write(writer, "{\n");
+ indent(writer);
+ emitReflectionTypeInfoJSON(writer, type);
+ dedent(writer);
+ write(writer, "\n}");
+}
+
+static void emitReflectionVarInfoJSON(
+ PrettyWriter& writer,
+ slang::VariableReflection* var)
+{
+ emitReflectionNameInfoJSON(writer, var->getName());
+ write(writer, ",\n");
+
+ write(writer, "\"type\": ");
+ emitReflectionTypeJSON(writer, var->getType());
+}
+
+#if 0
+static void emitReflectionBindingInfoJSON(
+ PrettyWriter& writer,
+
+ ReflectionParameterNode* param)
+{
+ auto info = &param->binding;
+
+ if( info->category == SLANG_PARAMETER_CATEGORY_MIXED )
+ {
+ write(writer,"\"bindings\": [\n");
+ indent(writer);
+
+ ReflectionSize bindingCount = info->bindingCount;
+ assert(bindingCount);
+ ReflectionParameterBindingInfo* bindings = info->bindings;
+ for( ReflectionSize bb = 0; bb < bindingCount; ++bb )
+ {
+ if (bb != 0) write(writer, ",\n");
+
+ write(writer,"{");
+ auto& binding = bindings[bb];
+ emitReflectionVarBindingInfoJSON(
+ writer,
+ binding.category,
+ binding.index,
+ (ReflectionSize) param->GetTypeLayout()->GetSize(binding.category),
+ binding.space);
+
+ write(writer,"}");
+ }
+ dedent(writer);
+ write(writer,"\n]");
+ }
+ else
+ {
+ write(writer,"\"binding\": {");
+ indent(writer);
+
+ emitReflectionVarBindingInfoJSON(
+ writer,
+ info->category,
+ info->index,
+ (ReflectionSize) param->GetTypeLayout()->GetSize(info->category),
+ info->space);
+
+ dedent(writer);
+ write(writer,"}");
+ }
+}
+#endif
+
+static void emitReflectionParamJSON(
+ PrettyWriter& writer,
+ slang::VariableLayoutReflection* param)
+{
+ write(writer, "{\n");
+ indent(writer);
+
+ emitReflectionNameInfoJSON(writer, param->getName());
+ write(writer, ",\n");
+
+ emitReflectionVarBindingInfoJSON(writer, param);
+ write(writer, ",\n");
+
+ write(writer, "\"type\": ");
+ emitReflectionTypeLayoutJSON(writer, param->getTypeLayout());
+
+ dedent(writer);
+ write(writer, "\n}");
+}
+
+template<typename T>
+struct Range
+{
+public:
+ Range(
+ T begin,
+ T end)
+ : mBegin(begin)
+ , mEnd(end)
+ {}
+
+ struct Iterator
+ {
+ public:
+ explicit Iterator(T value)
+ : mValue(value)
+ {}
+
+ T operator*() const { return mValue; }
+ void operator++() { mValue++; }
+
+ bool operator!=(Iterator const& other)
+ {
+ return mValue != other.mValue;
+ }
+
+ private:
+ T mValue;
+ };
+
+ Iterator begin() const { return Iterator(mBegin); }
+ Iterator end() const { return Iterator(mEnd); }
+
+private:
+ T mBegin;
+ T mEnd;
+};
+
+template<typename T>
+Range<T> range(T begin, T end)
+{
+ return Range<T>(begin, end);
+}
+
+template<typename T>
+Range<T> range(T end)
+{
+ return Range<T>(T(0), end);
+}
+
+static void emitReflectionJSON(
+ PrettyWriter& writer,
+ slang::ShaderReflection* programReflection)
+{
+ write(writer, "{\n");
+ indent(writer);
+ write(writer, "\"parameters\": [\n");
+ indent(writer);
+
+ auto parameterCount = programReflection->getParameterCount();
+ for( auto pp : range(parameterCount) )
+ {
+ if(pp != 0) write(writer, ",\n");
+
+ auto parameter = programReflection->getParameterByIndex(pp);
+ emitReflectionParamJSON(writer, parameter);
+ }
+
+ dedent(writer);
+ write(writer, "\n]");
+ dedent(writer);
+ write(writer, "\n}\n");
+}
+
+#if 0
+ReflectionBlob* ReflectionBlob::Create(
+ CollectionOfTranslationUnits* program)
+{
+ ReflectionGenerationContext context;
+ ReflectionBlob* blob = GenerateReflectionBlob(&context, program);
+#if 0
+ String debugDump = blob->emitAsJSON();
+ OutputDebugStringA("REFLECTION BLOB\n");
+ OutputDebugStringA(debugDump.begin());
+#endif
+ return blob;
+}
+#endif
+
+// JSON emit logic
+
+
+
+String emitReflectionJSON(
+ ProgramLayout* programLayout)
+{
+ auto programReflection = (slang::ShaderReflection*) programLayout;
+
+ PrettyWriter writer;
+ emitReflectionJSON(writer, programReflection);
+ return writer.sb.ProduceString();
+}
+
+}}
diff --git a/source/slang/reflection.h b/source/slang/reflection.h
new file mode 100644
index 000000000..4d2c53084
--- /dev/null
+++ b/source/slang/reflection.h
@@ -0,0 +1,39 @@
+#ifndef SLANG_REFLECTION_H
+#define SLANG_REFLECTION_H
+
+#include "../core/basic.h"
+#include "syntax.h"
+
+#include "../../slang.h"
+
+namespace Slang {
+
+// TODO(tfoley): Need to move these somewhere universal
+
+typedef intptr_t Int;
+typedef int64_t Int64;
+
+typedef uintptr_t UInt;
+typedef uint64_t UInt64;
+
+namespace Compiler {
+
+class ProgramLayout;
+class TypeLayout;
+
+String emitReflectionJSON(
+ ProgramLayout* programLayout);
+
+//
+
+SlangTypeKind getReflectionTypeKind(ExpressionType* type);
+
+SlangTypeKind getReflectionParameterCategory(TypeLayout* typeLayout);
+
+UInt getReflectionFieldCount(ExpressionType* type);
+UInt getReflectionFieldByIndex(ExpressionType* type, UInt index);
+UInt getReflectionFieldByIndex(TypeLayout* typeLayout, UInt index);
+
+}}
+
+#endif // SLANG_REFLECTION_H
diff --git a/source/slang/slang-stdlib.cpp b/source/slang/slang-stdlib.cpp
new file mode 100644
index 000000000..f74fcd603
--- /dev/null
+++ b/source/slang/slang-stdlib.cpp
@@ -0,0 +1,1855 @@
+// slang-stdlib.cpp
+
+#include "slang-stdlib.h"
+#include "syntax.h"
+
+#define STRINGIZE(x) STRINGIZE2(x)
+#define STRINGIZE2(x) #x
+#define LINE_STRING STRINGIZE(__LINE__)
+
+enum { kLibIncludeStringLine = __LINE__+1 };
+const char * LibIncludeStringChunks[] = { R"(
+
+typedef uint UINT;
+
+__generic<T> __intrinsic(Assign) T operator=(out T left, T right);
+
+__generic<T,U> __intrinsic(Sequence) U operator,(T left, U right);
+
+__generic<T> __intrinsic(Select) T operator?:(bool condition, T ifTrue, T ifFalse);
+__generic<T, let N : int> __intrinsic(Select) vector<T,N> operator?:(vector<bool,N> condition, vector<T,N> ifTrue, vector<T,N> ifFalse);
+
+__generic<T> __magic_type(HLSLAppendStructuredBufferType) struct AppendStructuredBuffer
+{
+ __intrinsic void Append(T value);
+
+ __intrinsic void GetDimensions(
+ out uint numStructs,
+ out uint stride);
+};
+
+__generic<T> __magic_type(HLSLBufferType) struct Buffer
+{
+ __intrinsic void GetDimensions(
+ out uint dim);
+
+ __intrinsic T Load(int location);
+ __intrinsic T Load(int location, out uint status);
+
+ __intrinsic __subscript(uint index) -> T;
+};
+
+__magic_type(HLSLByteAddressBufferType) struct ByteAddressBuffer
+{
+ __intrinsic void GetDimensions(
+ out uint dim);
+
+ __intrinsic uint Load(int location);
+ __intrinsic uint Load(int location, out uint status);
+
+ __intrinsic uint2 Load2(int location);
+ __intrinsic uint2 Load2(int location, out uint status);
+
+ __intrinsic uint3 Load3(int location);
+ __intrinsic uint3 Load3(int location, out uint status);
+
+ __intrinsic uint4 Load4(int location);
+ __intrinsic uint4 Load4(int location, out uint status);
+};
+
+__generic<T> __magic_type(HLSLStructuredBufferType) struct StructuredBuffer
+{
+ __intrinsic void GetDimensions(
+ out uint numStructs,
+ out uint stride);
+
+ __intrinsic T Load(int location);
+ __intrinsic T Load(int location, out uint status);
+
+ __intrinsic __subscript(uint index) -> T;
+};
+
+__generic<T> __magic_type(HLSLConsumeStructuredBufferType) struct ConsumeStructuredBuffer
+{
+ __intrinsic T Consume();
+
+ __intrinsic void GetDimensions(
+ out uint numStructs,
+ out uint stride);
+};
+
+__generic<T> __magic_type(HLSLInputPatchType) struct InputPatch
+{
+ __intrinsic __subscript(uint index) -> T;
+};
+
+__generic<T> __magic_type(HLSLOutputPatchType) struct OutputPatch
+{
+ __intrinsic __subscript(uint index) -> T { set; }
+};
+
+__generic<T> __magic_type(HLSLRWBufferType) struct RWBuffer
+{
+ // Note(tfoley): duplication with declaration of `Buffer`
+
+ __intrinsic void GetDimensions(
+ out uint dim);
+
+ __intrinsic T Load(int location);
+ __intrinsic T Load(int location, out uint status);
+
+ __intrinsic __subscript(uint index) -> T { get; set; }
+};
+
+__magic_type(HLSLRWByteAddressBufferType) struct RWByteAddressBuffer
+{
+ // Note(tfoley): supports alll operations from `ByteAddressBuffer`
+ // TODO(tfoley): can this be made a sub-type?
+
+ __intrinsic void GetDimensions(
+ out uint dim);
+
+ __intrinsic uint Load(int location);
+ __intrinsic uint Load(int location, out uint status);
+
+ __intrinsic uint2 Load2(int location);
+ __intrinsic uint2 Load2(int location, out uint status);
+
+ __intrinsic uint3 Load3(int location);
+ __intrinsic uint3 Load3(int location, out uint status);
+
+ __intrinsic uint4 Load4(int location);
+ __intrinsic uint4 Load4(int location, out uint status);
+
+ // Added operations:
+
+ __intrinsic void InterlockedAdd(
+ UINT dest,
+ UINT value,
+ out UINT original_value);
+ __intrinsic void InterlockedAdd(
+ UINT dest,
+ UINT value);
+
+ __intrinsic void InterlockedAnd(
+ UINT dest,
+ UINT value,
+ out UINT original_value);
+ __intrinsic void InterlockedAnd(
+ UINT dest,
+ UINT value);
+
+ __intrinsic void InterlockedCompareExchange(
+ UINT dest,
+ UINT compare_value,
+ UINT value,
+ out UINT original_value);
+ __intrinsic void InterlockedCompareExchange(
+ UINT dest,
+ UINT compare_value,
+ UINT value);
+
+ __intrinsic void InterlockedCompareStore(
+ UINT dest,
+ UINT compare_value,
+ UINT value);
+ __intrinsic void InterlockedCompareStore(
+ UINT dest,
+ UINT compare_value);
+
+ __intrinsic void InterlockedExchange(
+ UINT dest,
+ UINT value,
+ out UINT original_value);
+ __intrinsic void InterlockedExchange(
+ UINT dest,
+ UINT value);
+
+ __intrinsic void InterlockedMax(
+ UINT dest,
+ UINT value,
+ out UINT original_value);
+ __intrinsic void InterlockedMax(
+ UINT dest,
+ UINT value);
+
+ __intrinsic void InterlockedMin(
+ UINT dest,
+ UINT value,
+ out UINT original_value);
+ __intrinsic void InterlockedMin(
+ UINT dest,
+ UINT value);
+
+ __intrinsic void InterlockedOr(
+ UINT dest,
+ UINT value,
+ out UINT original_value);
+ __intrinsic void InterlockedOr(
+ UINT dest,
+ UINT value);
+
+ __intrinsic void InterlockedXor(
+ UINT dest,
+ UINT value,
+ out UINT original_value);
+ __intrinsic void InterlockedXor(
+ UINT dest,
+ UINT value);
+
+ __intrinsic void Store(
+ uint address,
+ uint value);
+
+ __intrinsic void Store2(
+ uint address,
+ uint2 value);
+
+ __intrinsic void Store3(
+ uint address,
+ uint3 value);
+
+ __intrinsic void Store4(
+ uint address,
+ uint4 value);
+};
+
+__generic<T> __magic_type(HLSLRWStructuredBufferType) struct RWStructuredBuffer
+{
+ __intrinsic uint DecrementCounter();
+
+ __intrinsic void GetDimensions(
+ out uint numStructs,
+ out uint stride);
+
+ __intrinsic void IncrementCounter();
+
+ __intrinsic T Load(int location);
+ __intrinsic T Load(int location, out uint status);
+
+ __intrinsic __subscript(uint index) -> T { get; set; }
+};
+
+__generic<T> __magic_type(HLSLPointStreamType) struct PointStream {};
+__generic<T> __magic_type(HLSLLineStreamType) struct LineStream {};
+__generic<T> __magic_type(HLSLLineStreamType) struct TriangleStream {};
+
+)", R"(
+
+// Note(tfoley): Trying to systematically add all the HLSL builtins
+
+// A type that can be used as an operand for builtins
+__trait __BuiltinType {}
+
+// A type that can be used for arithmetic operations
+__trait __BuiltinArithmeticType : __BuiltinType {}
+
+// A type that logically has a sign (positive/negative/zero)
+__trait __BuiltinSignedArithmeticType : __BuiltinArithmeticType {}
+
+// A type that can represent integers
+__trait __BuiltinIntegerType : __BuiltinArithmeticType {}
+
+// A type that can represent non-integers
+__trait __BuiltinRealType : __BuiltinArithmeticType {}
+
+// A type that uses a floating-point representation
+__trait __BuiltinFloatingPointType : __BuiltinRealType, __BuiltinSignedType {}
+
+// Try to terminate the current draw or dispatch call (HLSL SM 4.0)
+__intrinsic void abort();
+
+// Absolute value (HLSL SM 1.0)
+__generic<T : __BuiltinSignedArithmeticType> __intrinsic T abs(T x);
+__generic<T : __BuiltinSignedArithmeticType, let N : int> __intrinsic vector<T,N> abs(vector<T,N> x);
+__generic<T : __BuiltinSignedArithmeticType, let N : int, let M : int> __intrinsic matrix<T,N,M> abs(matrix<T,N,M> x);
+
+// Inverse cosine (HLSL SM 1.0)
+__generic<T : __BuiltinFloatingPointType> __intrinsic T acos(T x);
+__generic<T : __BuiltinFloatingPointType, let N : int> __intrinsic vector<T,N> acos(vector<T,N> x);
+__generic<T : __BuiltinFloatingPointType, let N : int, let M : int> __intrinsic matrix<T,N,M> acos(matrix<T,N,M> x);
+
+// Test if all components are non-zero (HLSL SM 1.0)
+__generic<T : __BuiltinType> __intrinsic T all(T x);
+__generic<T : __BuiltinType, let N : int> __intrinsic vector<T,N> all(vector<T,N> x);
+__generic<T : __BuiltinType, let N : int, let M : int> __intrinsic matrix<T,N,M> all(matrix<T,N,M> x);
+
+// Barrier for writes to all memory spaces (HLSL SM 5.0)
+__intrinsic void AllMemoryBarrier();
+
+// Thread-group sync and barrier for writes to all memory spaces (HLSL SM 5.0)
+__intrinsic void AllMemoryBarrierWithGroupSync();
+
+// Test if any components is non-zero (HLSL SM 1.0)
+__generic<T : __BuiltinType> __intrinsic T any(T x);
+__generic<T : __BuiltinType, let N : int> __intrinsic vector<T,N> any(vector<T,N> x);
+__generic<T : __BuiltinType, let N : int, let M : int> __intrinsic matrix<T,N,M> any(matrix<T,N,M> x);
+
+
+// Reinterpret bits as a double (HLSL SM 5.0)
+__intrinsic double asdouble(uint lowbits, uint highbits);
+
+// Reinterpret bits as a float (HLSL SM 4.0)
+__intrinsic float asfloat( int x);
+__intrinsic float asfloat(uint x);
+__generic<let N : int> __intrinsic vector<float,N> asfloat(vector< int,N> x);
+__generic<let N : int> __intrinsic vector<float,N> asfloat(vector<uint,N> x);
+__generic<let N : int, let M : int> __intrinsic matrix<float,N,M> asfloat(matrix< int,N,M> x);
+__generic<let N : int, let M : int> __intrinsic matrix<float,N,M> asfloat(matrix<uint,N,M> x);
+
+
+// Inverse sine (HLSL SM 1.0)
+__generic<T : __BuiltinFloatingPointType> __intrinsic T asin(T x);
+__generic<T : __BuiltinFloatingPointType, let N : int> __intrinsic vector<T,N> asin(vector<T,N> x);
+__generic<T : __BuiltinFloatingPointType, let N : int, let M : int> __intrinsic matrix<T,N,M> asin(matrix<T,N,M> x);
+
+// Reinterpret bits as an int (HLSL SM 4.0)
+__intrinsic int asint(float x);
+__intrinsic int asint(uint x);
+__generic<let N : int> __intrinsic vector<int,N> asint(vector<float,N> x);
+__generic<let N : int> __intrinsic vector<int,N> asint(vector<uint,N> x);
+__generic<let N : int, let M : int> __intrinsic matrix<int,N,M> asint(matrix<float,N,M> x);
+__generic<let N : int, let M : int> __intrinsic matrix<int,N,M> asint(matrix<uint,N,M> x);
+
+// Reinterpret bits of double as a uint (HLSL SM 5.0)
+__intrinsic void asuint(double value, out uint lowbits, out uint highbits);
+
+// Reinterpret bits as a uint (HLSL SM 4.0)
+__intrinsic uint asuint(float x);
+__intrinsic uint asuint(int x);
+__generic<let N : int> __intrinsic vector<uint,N> asuint(vector<float,N> x);
+__generic<let N : int> __intrinsic vector<uint,N> asuint(vector<int,N> x);
+__generic<let N : int, let M : int> __intrinsic matrix<uint,N,M> asuint(matrix<float,N,M> x);
+__generic<let N : int, let M : int> __intrinsic matrix<uint,N,M> asuint(matrix<int,N,M> x);
+
+// Inverse tangent (HLSL SM 1.0)
+__generic<T : __BuiltinFloatingPointType> __intrinsic T atan(T x);
+__generic<T : __BuiltinFloatingPointType, let N : int> __intrinsic vector<T,N> atan(vector<T,N> x);
+__generic<T : __BuiltinFloatingPointType, let N : int, let M : int> __intrinsic matrix<T,N,M> atan(matrix<T,N,M> x);
+
+__generic<T : __BuiltinFloatingPointType> __intrinsic T atan2(T y, T x);
+__generic<T : __BuiltinFloatingPointType, let N : int> __intrinsic vector<T,N> atan2(vector<T,N> y, vector<T,N> x);
+__generic<T : __BuiltinFloatingPointType, let N : int, let M : int> __intrinsic matrix<T,N,M> atan2(matrix<T,N,M> y, matrix<T,N,M> x);
+
+// Ceiling (HLSL SM 1.0)
+__generic<T : __BuiltinFloatingPointType> __intrinsic T ceil(T x);
+__generic<T : __BuiltinFloatingPointType, let N : int> __intrinsic vector<T,N> ceil(vector<T,N> x);
+__generic<T : __BuiltinFloatingPointType, let N : int, let M : int> __intrinsic matrix<T,N,M> ceil(matrix<T,N,M> x);
+
+
+// Check access status to tiled resource
+__intrinsic bool CheckAccessFullyMapped(uint status);
+
+// Clamp (HLSL SM 1.0)
+__generic<T : __BuiltinArithmeticType> __intrinsic T clamp(T x, T min, T max);
+__generic<T : __BuiltinArithmeticType, let N : int> __intrinsic vector<T,N> clamp(vector<T,N> x, vector<T,N> min, vector<T,N> max);
+__generic<T : __BuiltinArithmeticType, let N : int, let M : int> __intrinsic matrix<T,N,M> clamp(matrix<T,N,M> x, matrix<T,N,M> min, matrix<T,N,M> max);
+
+// Clip (discard) fragment conditionally
+__generic<T : __BuiltinFloatingPointType> __intrinsic void clip(T x);
+__generic<T : __BuiltinFloatingPointType, let N : int> __intrinsic void clip(vector<T,N> x);
+__generic<T : __BuiltinFloatingPointType, let N : int, let M : int> __intrinsic void clip(matrix<T,N,M> x);
+
+// Cosine
+__generic<T : __BuiltinFloatingPointType> __intrinsic T cos(T x);
+__generic<T : __BuiltinFloatingPointType, let N : int> __intrinsic vector<T,N> cos(vector<T,N> x);
+__generic<T : __BuiltinFloatingPointType, let N : int, let M : int> __intrinsic matrix<T,N,M> cos(matrix<T,N,M> x);
+
+// Hyperbolic cosine
+__generic<T : __BuiltinFloatingPointType> __intrinsic T cosh(T x);
+__generic<T : __BuiltinFloatingPointType, let N : int> __intrinsic vector<T,N> cosh(vector<T,N> x);
+__generic<T : __BuiltinFloatingPointType, let N : int, let M : int> __intrinsic matrix<T,N,M> cosh(matrix<T,N,M> x);
+
+// Population count
+__intrinsic uint countbits(uint value);
+
+// Cross product
+__generic<T : __BuiltinArithmeticType> __intrinsic vector<T,3> cross(vector<T,3> x, vector<T,3> y);
+
+// Convert encoded color
+__intrinsic int4 D3DCOLORtoUBYTE4(float4 x);
+
+// Partial-difference derivatives
+__generic<T : __BuiltinFloatingPointType> __intrinsic T ddx(T x);
+__generic<T : __BuiltinFloatingPointType, let N : int> __intrinsic vector<T,N> ddx(vector<T,N> x);
+__generic<T : __BuiltinFloatingPointType, let N : int, let M : int> __intrinsic matrix<T,N,M> ddx(matrix<T,N,M> x);
+
+__generic<T : __BuiltinFloatingPointType> __intrinsic T ddx_coarse(T x);
+__generic<T : __BuiltinFloatingPointType, let N : int> __intrinsic vector<T,N> ddx_coarse(vector<T,N> x);
+__generic<T : __BuiltinFloatingPointType, let N : int, let M : int> __intrinsic matrix<T,N,M> ddx_coarse(matrix<T,N,M> x);
+
+__generic<T : __BuiltinFloatingPointType> __intrinsic T ddx_fine(T x);
+__generic<T : __BuiltinFloatingPointType, let N : int> __intrinsic vector<T,N> ddx_fine(vector<T,N> x);
+__generic<T : __BuiltinFloatingPointType, let N : int, let M : int> __intrinsic matrix<T,N,M> ddx_fine(matrix<T,N,M> x);
+
+__generic<T : __BuiltinFloatingPointType> __intrinsic T ddy(T x);
+__generic<T : __BuiltinFloatingPointType, let N : int> __intrinsic vector<T,N> ddy(vector<T,N> x);
+__generic<T : __BuiltinFloatingPointType, let N : int, let M : int> __intrinsic matrix<T,N,M> ddy(matrix<T,N,M> x);
+
+__generic<T : __BuiltinFloatingPointType> __intrinsic T ddy_coarse(T x);
+__generic<T : __BuiltinFloatingPointType, let N : int> __intrinsic vector<T,N> ddy_coarse(vector<T,N> x);
+__generic<T : __BuiltinFloatingPointType, let N : int, let M : int> __intrinsic matrix<T,N,M> ddy_coarse(matrix<T,N,M> x);
+
+__generic<T : __BuiltinFloatingPointType> __intrinsic T ddy_fine(T x);
+__generic<T : __BuiltinFloatingPointType, let N : int> __intrinsic vector<T,N> ddy_fine(vector<T,N> x);
+__generic<T : __BuiltinFloatingPointType, let N : int, let M : int> __intrinsic matrix<T,N,M> ddy_fine(matrix<T,N,M> x);
+
+
+// Radians to degrees
+__generic<T : __BuiltinFloatingPointType> __intrinsic T degrees(T x);
+__generic<T : __BuiltinFloatingPointType, let N : int> __intrinsic vector<T,N> degrees(vector<T,N> x);
+__generic<T : __BuiltinFloatingPointType, let N : int, let M : int> __intrinsic matrix<T,N,M> degrees(matrix<T,N,M> x);
+
+// Matrix determinant
+
+__generic<T : __BuiltinFloatingPointType, let N : int> __intrinsic T determinant(matrix<T,N,N> m);
+
+// Barrier for device memory
+__intrinsic void DeviceMemoryBarrier();
+__intrinsic void DeviceMemoryBarrierWithGroupSync();
+
+// Vector distance
+
+__generic<T : __BuiltinFloatingPointType, let N : int> __intrinsic T distance(vector<T,N> x, vector<T,N> y);
+
+// Vector dot product
+
+__generic<T : __BuiltinArithmeticType, let N : int> __intrinsic T dot(vector<T,N> x, vector<T,N> y);
+
+// Helper for computing distance terms for lighting (obsolete)
+
+__generic<T : __BuiltinFloatingPointType> __intrinsic vector<T,4> dst(vector<T,4> x, vector<T,4> y);
+
+// Error message
+
+// __intrinsic void errorf( string format, ... );
+
+// Attribute evaluation
+
+__generic<T : __BuiltinArithmeticType> __intrinsic T EvaluateAttributeAtCentroid(T x);
+__generic<T : __BuiltinArithmeticType, let N : int> __intrinsic vector<T,N> EvaluateAttributeAtCentroid(vector<T,N> x);
+__generic<T : __BuiltinArithmeticType, let N : int, let M : int> __intrinsic matrix<T,N,M> EvaluateAttributeAtCentroid(matrix<T,N,M> x);
+
+__generic<T : __BuiltinArithmeticType> __intrinsic T EvaluateAttributeAtSample(T x, uint sampleindex);
+__generic<T : __BuiltinArithmeticType, let N : int> __intrinsic vector<T,N> EvaluateAttributeAtSample(vector<T,N> x, uint sampleindex);
+__generic<T : __BuiltinArithmeticType, let N : int, let M : int> __intrinsic matrix<T,N,M> EvaluateAttributeAtSample(matrix<T,N,M> x, uint sampleindex);
+
+__generic<T : __BuiltinArithmeticType> __intrinsic T EvaluateAttributeSnapped(T x, int2 offset);
+__generic<T : __BuiltinArithmeticType, let N : int> __intrinsic vector<T,N> EvaluateAttributeSnapped(vector<T,N> x, int2 offset);
+__generic<T : __BuiltinArithmeticType, let N : int, let M : int> __intrinsic matrix<T,N,M> EvaluateAttributeSnapped(matrix<T,N,M> x, int2 offset);
+
+// Base-e exponent
+__generic<T : __BuiltinFloatingPointType> __intrinsic T exp(T x);
+__generic<T : __BuiltinFloatingPointType, let N : int> __intrinsic vector<T,N> exp(vector<T,N> x);
+__generic<T : __BuiltinFloatingPointType, let N : int, let M : int> __intrinsic matrix<T,N,M> exp(matrix<T,N,M> x);
+
+// Base-2 exponent
+__generic<T : __BuiltinFloatingPointType> __intrinsic T exp2(T x);
+__generic<T : __BuiltinFloatingPointType, let N : int> __intrinsic vector<T,N> exp2(vector<T,N> x);
+__generic<T : __BuiltinFloatingPointType, let N : int, let M : int> __intrinsic matrix<T,N,M> exp2(matrix<T,N,M> x);
+
+// Convert 16-bit float stored in low bits of integer
+__intrinsic float f16tof32(uint value);
+__generic<let N : int> __intrinsic vector<float,N> f16tof32(vector<uint,N> value);
+
+// Convert to 16-bit float stored in low bits of integer
+__intrinsic uint f32tof16(float value);
+__generic<let N : int> __intrinsic vector<uint,N> f32tof16(vector<float,N> value);
+
+// Flip surface normal to face forward, if needed
+__generic<T : __BuiltinFloatingPointType, let N : int> __intrinsic vector<T,N> faceforward(vector<T,N> n, vector<T,N> i, vector<T,N> ng);
+
+// Find first set bit starting at high bit and working down
+__intrinsic int firstbithigh(int value);
+__generic<let N : int> __intrinsic vector<int,N> firstbithigh(vector<int,N> value);
+
+__intrinsic uint firstbithigh(uint value);
+__generic<let N : int> __intrinsic vector<uint,N> firstbithigh(vector<uint,N> value);
+
+// Find first set bit starting at low bit and working up
+__intrinsic int firstbitlow(int value);
+__generic<let N : int> __intrinsic vector<int,N> firstbitlow(vector<int,N> value);
+
+__intrinsic uint firstbitlow(uint value);
+__generic<let N : int> __intrinsic vector<uint,N> firstbitlow(vector<uint,N> value);
+
+// Floor (HLSL SM 1.0)
+__generic<T : __BuiltinFloatingPointType> __intrinsic T floor(T x);
+__generic<T : __BuiltinFloatingPointType, let N : int> __intrinsic vector<T,N> floor(vector<T,N> x);
+__generic<T : __BuiltinFloatingPointType, let N : int, let M : int> __intrinsic matrix<T,N,M> floor(matrix<T,N,M> x);
+
+// Fused multiply-add for doubles
+__intrinsic double fma(double a, double b, double c);
+__generic<let N : int> __intrinsic vector<double, N> fma(vector<double, N> a, vector<double, N> b, vector<double, N> c);
+__generic<let N : int, let M : int> __intrinsic matrix<double,N,M> fma(matrix<double,N,M> a, matrix<double,N,M> b, matrix<double,N,M> c);
+
+// Floating point remainder of x/y
+__generic<T : __BuiltinFloatingPointType> __intrinsic T fmod(T x, T y);
+__generic<T : __BuiltinFloatingPointType, let N : int> __intrinsic vector<T,N> fmod(vector<T,N> x, vector<T,N> y);
+__generic<T : __BuiltinFloatingPointType, let N : int, let M : int> __intrinsic matrix<T,N,M> fmod(matrix<T,N,M> x, matrix<T,N,M> y);
+
+// Fractional part
+__generic<T : __BuiltinFloatingPointType> __intrinsic T frac(T x);
+__generic<T : __BuiltinFloatingPointType, let N : int> __intrinsic vector<T,N> frac(vector<T,N> x);
+__generic<T : __BuiltinFloatingPointType, let N : int, let M : int> __intrinsic matrix<T,N,M> frac(matrix<T,N,M> x);
+
+// Split float into mantissa and exponent
+__generic<T : __BuiltinFloatingPointType> __intrinsic T frexp(T x, out T exp);
+__generic<T : __BuiltinFloatingPointType, let N : int> __intrinsic vector<T,N> frexp(vector<T,N> x, out vector<T,N> exp);
+__generic<T : __BuiltinFloatingPointType, let N : int, let M : int> __intrinsic matrix<T,N,M> frexp(matrix<T,N,M> x, out matrix<T,N,M> exp);
+
+// Texture filter width
+__generic<T : __BuiltinFloatingPointType> __intrinsic T fwidth(T x);
+__generic<T : __BuiltinFloatingPointType, let N : int> __intrinsic vector<T,N> fwidth(vector<T,N> x);
+__generic<T : __BuiltinFloatingPointType, let N : int, let M : int> __intrinsic matrix<T,N,M> fwidth(matrix<T,N,M> x);
+
+)", R"(
+
+// Get number of samples in render target
+__intrinsic uint GetRenderTargetSampleCount();
+
+// Get position of given sample
+__intrinsic float2 GetRenderTargetSamplePosition(int Index);
+
+// Group memory barrier
+__intrinsic void GroupMemoryBarrier();
+__intrinsic void GroupMemoryBarrierWithGroupSync();
+
+// Atomics
+__intrinsic void InterlockedAdd(in out int dest, int value, out int original_value);
+__intrinsic void InterlockedAdd(in out uint dest, uint value, out uint original_value);
+
+__intrinsic void InterlockedAnd(in out int dest, int value, out int original_value);
+__intrinsic void InterlockedAnd(in out uint dest, uint value, out uint original_value);
+
+__intrinsic void InterlockedCompareExchange(in out int dest, int compare_value, int value, out int original_value);
+__intrinsic void InterlockedCompareExchange(in out uint dest, uint compare_value, uint value, out uint original_value);
+
+__intrinsic void InterlockedCompareStore(in out int dest, int compare_value, int value);
+__intrinsic void InterlockedCompareStore(in out uint dest, uint compare_value, uint value);
+
+__intrinsic void InterlockedExchange(in out int dest, int value, out int original_value);
+__intrinsic void InterlockedExchange(in out uint dest, uint value, out uint original_value);
+
+__intrinsic void InterlockedMax(in out int dest, int value, out int original_value);
+__intrinsic void InterlockedMax(in out uint dest, uint value, out uint original_value);
+
+__intrinsic void InterlockedMin(in out int dest, int value, out int original_value);
+__intrinsic void InterlockedMin(in out uint dest, uint value, out uint original_value);
+
+__intrinsic void InterlockedOr(in out int dest, int value, out int original_value);
+__intrinsic void InterlockedOr(in out uint dest, uint value, out uint original_value);
+
+__intrinsic void InterlockedXor(in out int dest, int value, out int original_value);
+__intrinsic void InterlockedXor(in out uint dest, uint value, out uint original_value);
+
+// Is floating-point value finite?
+__generic<T : __BuiltinFloatingPointType> __intrinsic bool isfinite(T x);
+__generic<T : __BuiltinFloatingPointType, let N : int> __intrinsic vector<bool,N> isfinite(vector<T,N> x);
+__generic<T : __BuiltinFloatingPointType, let N : int, let M : int> __intrinsic matrix<bool,N,M> isfinite(matrix<T,N,M> x);
+
+// Is floating-point value infinite?
+__generic<T : __BuiltinFloatingPointType> __intrinsic bool isinf(T x);
+__generic<T : __BuiltinFloatingPointType, let N : int> __intrinsic vector<bool,N> isinf(vector<T,N> x);
+__generic<T : __BuiltinFloatingPointType, let N : int, let M : int> __intrinsic matrix<bool,N,M> isinf(matrix<T,N,M> x);
+
+// Is floating-point value not-a-number?
+__generic<T : __BuiltinFloatingPointType> __intrinsic bool isnan(T x);
+__generic<T : __BuiltinFloatingPointType, let N : int> __intrinsic vector<bool,N> isnan(vector<T,N> x);
+__generic<T : __BuiltinFloatingPointType, let N : int, let M : int> __intrinsic matrix<bool,N,M> isnan(matrix<T,N,M> x);
+
+// Construct float from mantissa and exponent
+__generic<T : __BuiltinFloatingPointType> __intrinsic T ldexp(T x, T exp);
+__generic<T : __BuiltinFloatingPointType, let N : int> __intrinsic vector<T,N> ldexp(vector<T,N> x, vector<T,N> exp);
+__generic<T : __BuiltinFloatingPointType, let N : int, let M : int> __intrinsic matrix<T,N,M> ldexp(matrix<T,N,M> x, matrix<T,N,M> exp);
+
+// Vector length
+__generic<T : __BuiltinFloatingPointType, let N : int> __intrinsic T length(vector<T,N> x);
+
+// Linear interpolation
+__generic<T : __BuiltinFloatingPointType> __intrinsic T lerp(T x, T y, T s);
+__generic<T : __BuiltinFloatingPointType, let N : int> __intrinsic vector<T,N> lerp(vector<T,N> x, vector<T,N> y, vector<T,N> s);
+__generic<T : __BuiltinFloatingPointType, let N : int, let M : int> __intrinsic matrix<T,N,M> lerp(matrix<T,N,M> x, matrix<T,N,M> y, matrix<T,N,M> s);
+
+// Legacy lighting function (obsolete)
+__intrinsic float4 lit(float n_dot_l, float n_dot_h, float m);
+
+// Base-e logarithm
+__generic<T : __BuiltinFloatingPointType> __intrinsic T log(T x);
+__generic<T : __BuiltinFloatingPointType, let N : int> __intrinsic vector<T,N> log(vector<T,N> x);
+__generic<T : __BuiltinFloatingPointType, let N : int, let M : int> __intrinsic matrix<T,N,M> log(matrix<T,N,M> x);
+
+// Base-10 logarithm
+__generic<T : __BuiltinFloatingPointType> __intrinsic T log10(T x);
+__generic<T : __BuiltinFloatingPointType, let N : int> __intrinsic vector<T,N> log10(vector<T,N> x);
+__generic<T : __BuiltinFloatingPointType, let N : int, let M : int> __intrinsic matrix<T,N,M> log10(matrix<T,N,M> x);
+
+// Base-2 logarithm
+__generic<T : __BuiltinFloatingPointType> __intrinsic T log2(T x);
+__generic<T : __BuiltinFloatingPointType, let N : int> __intrinsic vector<T,N> log2(vector<T,N> x);
+__generic<T : __BuiltinFloatingPointType, let N : int, let M : int> __intrinsic matrix<T,N,M> log2(matrix<T,N,M> x);
+
+// multiply-add
+__generic<T : __BuiltinArithmeticType> __intrinsic T mad(T mvalue, T avalue, T bvalue);
+__generic<T : __BuiltinArithmeticType, let N : int> __intrinsic vector<T,N> mad(vector<T,N> mvalue, vector<T,N> avalue, vector<T,N> bvalue);
+__generic<T : __BuiltinArithmeticType, let N : int, let M : int> __intrinsic matrix<T,N,M> mad(matrix<T,N,M> mvalue, matrix<T,N,M> avalue, matrix<T,N,M> bvalue);
+
+// maximum
+__generic<T : __BuiltinArithmeticType> __intrinsic T max(T x, T y);
+__generic<T : __BuiltinArithmeticType, let N : int> __intrinsic vector<T,N> max(vector<T,N> x, vector<T,N> y);
+__generic<T : __BuiltinArithmeticType, let N : int, let M : int> __intrinsic matrix<T,N,M> max(matrix<T,N,M> x, matrix<T,N,M> y);
+
+// minimum
+__generic<T : __BuiltinArithmeticType> __intrinsic T min(T x, T y);
+__generic<T : __BuiltinArithmeticType, let N : int> __intrinsic vector<T,N> min(vector<T,N> x, vector<T,N> y);
+__generic<T : __BuiltinArithmeticType, let N : int, let M : int> __intrinsic matrix<T,N,M> min(matrix<T,N,M> x, matrix<T,N,M> y);
+
+// split into integer and fractional parts (both with same sign)
+__generic<T : __BuiltinFloatingPointType> __intrinsic T modf(T x, out T ip);
+__generic<T : __BuiltinFloatingPointType, let N : int> __intrinsic vector<T,N> modf(vector<T,N> x, out vector<T,N> ip);
+__generic<T : __BuiltinFloatingPointType, let N : int, let M : int> __intrinsic matrix<T,N,M> modf(matrix<T,N,M> x, out matrix<T,N,M> ip);
+
+// msad4 (whatever that is)
+__intrinsic uint4 msad4(uint reference, uint2 source, uint4 accum);
+
+// General inner products
+
+// scalar-scalar
+__generic<T : __BuiltinArithmeticType> __intrinsic(Mul_Scalar_Scalar) T mul(T x, T y);
+
+// scalar-vector and vector-scalar
+__generic<T : __BuiltinArithmeticType, let N : int> __intrinsic(Mul_Vector_Scalar) vector<T,N> mul(vector<T,N> x, T y);
+__generic<T : __BuiltinArithmeticType, let N : int> __intrinsic(Mul_Scalar_Vector) vector<T,N> mul(T x, vector<T,N> y);
+
+// scalar-matrix and matrix-scalar
+__generic<T : __BuiltinArithmeticType, let N : int, let M :int> __intrinsic(Mul_Matrix_Scalar) matrix<T,N,M> mul(matrix<T,N,M> x, T y);
+__generic<T : __BuiltinArithmeticType, let N : int, let M :int> __intrinsic(Mul_Scalar_Matrix) matrix<T,N,M> mul(T x, matrix<T,N,M> y);
+
+// vector-vector (dot product)
+__generic<T : __BuiltinArithmeticType, let N : int> __intrinsic(InnerProduct_Vector_Vector) T mul(vector<T,N> x, vector<T,N> y);
+
+// vector-matrix
+__generic<T : __BuiltinArithmeticType, let N : int, let M : int> __intrinsic(InnerProduct_Vector_Matrix) vector<T,M> mul(vector<T,N> x, matrix<T,N,M> y);
+
+// matrix-vector
+__generic<T : __BuiltinArithmeticType, let N : int, let M : int> __intrinsic(InnerProduct_Matrix_Vector) vector<T,N> mul(matrix<T,N,M> x, vector<T,M> y);
+
+// matrix-matrix
+__generic<T : __BuiltinArithmeticType, let R : int, let N : int, let C : int> __intrinsic(InnerProduct_Matrix_Matrix) matrix<T,R,C> mul(matrix<T,R,N> x, matrix<T,N,C> y);
+
+// noise (deprecated)
+__intrinsic float noise(float x);
+__generic<let N : int> __intrinsic float noise(vector<float, N> x);
+
+// Normalize a vector
+__generic<T : __BuiltinFloatingPointType, let N : int> __intrinsic vector<T,N> normalize(vector<T,N> x);
+
+// Raise to a power
+__generic<T : __BuiltinFloatingPointType> __intrinsic T pow(T x, T y);
+__generic<T : __BuiltinFloatingPointType, let N : int> __intrinsic vector<T,N> pow(vector<T,N> x, vector<T,N> y);
+__generic<T : __BuiltinFloatingPointType, let N : int, let M : int> __intrinsic matrix<T,N,M> pow(matrix<T,N,M> x, matrix<T,N,M> y);
+
+// Output message
+
+// __intrinsic void printf( string format, ... );
+
+// Tessellation factor fixup routines
+
+__intrinsic void Process2DQuadTessFactorsAvg(
+ in float4 RawEdgeFactors,
+ in float2 InsideScale,
+ out float4 RoundedEdgeTessFactors,
+ out float2 RoundedInsideTessFactors,
+ out float2 UnroundedInsideTessFactors);
+
+__intrinsic void Process2DQuadTessFactorsMax(
+ in float4 RawEdgeFactors,
+ in float2 InsideScale,
+ out float4 RoundedEdgeTessFactors,
+ out float2 RoundedInsideTessFactors,
+ out float2 UnroundedInsideTessFactors);
+
+__intrinsic void Process2DQuadTessFactorsMin(
+ in float4 RawEdgeFactors,
+ in float2 InsideScale,
+ out float4 RoundedEdgeTessFactors,
+ out float2 RoundedInsideTessFactors,
+ out float2 UnroundedInsideTessFactors);
+
+__intrinsic void ProcessIsolineTessFactors(
+ in float RawDetailFactor,
+ in float RawDensityFactor,
+ out float RoundedDetailFactor,
+ out float RoundedDensityFactor);
+
+__intrinsic void ProcessQuadTessFactorsAvg(
+ in float4 RawEdgeFactors,
+ in float InsideScale,
+ out float4 RoundedEdgeTessFactors,
+ out float2 RoundedInsideTessFactors,
+ out float2 UnroundedInsideTessFactors);
+
+__intrinsic void ProcessQuadTessFactorsMax(
+ in float4 RawEdgeFactors,
+ in float InsideScale,
+ out float4 RoundedEdgeTessFactors,
+ out float2 RoundedInsideTessFactors,
+ out float2 UnroundedInsideTessFactors);
+
+__intrinsic void ProcessQuadTessFactorsMin(
+ in float4 RawEdgeFactors,
+ in float InsideScale,
+ out float4 RoundedEdgeTessFactors,
+ out float2 RoundedInsideTessFactors,
+ out float2 UnroundedInsideTessFactors);
+
+__intrinsic void ProcessTriTessFactorsAvg(
+ in float3 RawEdgeFactors,
+ in float InsideScale,
+ out float3 RoundedEdgeTessFactors,
+ out float RoundedInsideTessFactor,
+ out float UnroundedInsideTessFactor);
+
+__intrinsic void ProcessTriTessFactorsMax(
+ in float3 RawEdgeFactors,
+ in float InsideScale,
+ out float3 RoundedEdgeTessFactors,
+ out float RoundedInsideTessFactor,
+ out float UnroundedInsideTessFactor);
+
+__intrinsic void ProcessTriTessFactorsMin(
+ in float3 RawEdgeFactors,
+ in float InsideScale,
+ out float3 RoundedEdgeTessFactors,
+ out float RoundedInsideTessFactors,
+ out float UnroundedInsideTessFactors);
+
+// Degrees to radians
+__generic<T : __BuiltinFloatingPointType> __intrinsic T radians(T x);
+__generic<T : __BuiltinFloatingPointType, let N : int> __intrinsic vector<T,N> radians(vector<T,N> x);
+__generic<T : __BuiltinFloatingPointType, let N : int, let M : int> __intrinsic matrix<T,N,M> radians(matrix<T,N,M> x);
+
+// Approximate reciprocal
+__generic<T : __BuiltinFloatingPointType> __intrinsic T rcp(T x);
+__generic<T : __BuiltinFloatingPointType, let N : int> __intrinsic vector<T,N> rcp(vector<T,N> x);
+__generic<T : __BuiltinFloatingPointType, let N : int, let M : int> __intrinsic matrix<T,N,M> rcp(matrix<T,N,M> x);
+
+// Reflect incident vector across plane with given normal
+__generic<T : __BuiltinFloatingPointType, let N : int>
+__intrinsic
+vector<T,N> reflect(vector<T,N> i, vector<T,N> n);
+
+// Refract incident vector given surface normal and index of refraction
+__generic<T : __BuiltinFloatingPointType, let N : int>
+__intrinsic
+vector<T,N> refract(vector<T,N> i, vector<T,N> n, float eta);
+
+// Reverse order of bits
+__intrinsic uint reversebits(uint value);
+__generic<let N : int> vector<uint,N> reversebits(vector<uint,N> value);
+
+// Round-to-nearest
+__generic<T : __BuiltinFloatingPointType> __intrinsic T round(T x);
+__generic<T : __BuiltinFloatingPointType, let N : int> __intrinsic vector<T,N> round(vector<T,N> x);
+__generic<T : __BuiltinFloatingPointType, let N : int, let M : int> __intrinsic matrix<T,N,M> round(matrix<T,N,M> x);
+
+// Reciprocal of square root
+__generic<T : __BuiltinFloatingPointType> __intrinsic T rsqrt(T x);
+__generic<T : __BuiltinFloatingPointType, let N : int> __intrinsic vector<T,N> rsqrt(vector<T,N> x);
+__generic<T : __BuiltinFloatingPointType, let N : int, let M : int> __intrinsic matrix<T,N,M> rsqrt(matrix<T,N,M> x);
+
+// Clamp value to [0,1] range
+__generic<T : __BuiltinFloatingPointType> __intrinsic T saturate(T x);
+__generic<T : __BuiltinFloatingPointType, let N : int> __intrinsic vector<T,N> saturate(vector<T,N> x);
+__generic<T : __BuiltinFloatingPointType, let N : int, let M : int> __intrinsic matrix<T,N,M> saturate(matrix<T,N,M> x);
+
+
+// Extract sign of value
+__generic<T : __BuiltinSignedArithmeticType> __intrinsic int sign(T x);
+__generic<T : __BuiltinSignedArithmeticType, let N : int> __intrinsic vector<int,N> sign(vector<T,N> x);
+__generic<T : __BuiltinSignedArithmeticType, let N : int, let M : int> __intrinsic matrix<int,N,M> sign(matrix<T,N,M> x);
+
+)", R"(
+
+
+// Sine
+__generic<T : __BuiltinFloatingPointType> __intrinsic T sin(T x);
+__generic<T : __BuiltinFloatingPointType, let N : int> __intrinsic vector<T,N> sin(vector<T,N> x);
+__generic<T : __BuiltinFloatingPointType, let N : int, let M : int> __intrinsic matrix<T,N,M> sin(matrix<T,N,M> x);
+
+// Sine and cosine
+__generic<T : __BuiltinFloatingPointType, let N : int> __intrinsic void sincos(T x, out T s, out T c);
+__generic<T : __BuiltinFloatingPointType, let N : int> __intrinsic void sincos(vector<T,N> x, out vector<T,N> s, out vector<T,N> c);
+__generic<T : __BuiltinFloatingPointType, let N : int, let M : int> __intrinsic void sincos(matrix<T,N,M> x, out matrix<T,N,M> s, out matrix<T,N,M> c);
+
+// Hyperbolic Sine
+__generic<T : __BuiltinFloatingPointType> __intrinsic T sinh(T x);
+__generic<T : __BuiltinFloatingPointType, let N : int> __intrinsic vector<T,N> sinh(vector<T,N> x);
+__generic<T : __BuiltinFloatingPointType, let N : int, let M : int> __intrinsic matrix<T,N,M> sinh(matrix<T,N,M> x);
+
+// Smooth step (Hermite interpolation)
+__generic<T : __BuiltinFloatingPointType> __intrinsic T smoothstep(T min, T max, T x);
+__generic<T : __BuiltinFloatingPointType, let N : int> __intrinsic vector<T,N> smoothstep(vector<T,N> min, vector<T,N> max, vector<T,N> x);
+__generic<T : __BuiltinFloatingPointType, let N : int, let M : int> __intrinsic matrix<T,N,M> smoothstep(matrix<T,N,M> min, matrix<T,N,M> max, matrix<T,N,M> x);
+
+// Square root
+__generic<T : __BuiltinFloatingPointType> __intrinsic T sqrt(T x);
+__generic<T : __BuiltinFloatingPointType, let N : int> __intrinsic vector<T,N> sqrt(vector<T,N> x);
+__generic<T : __BuiltinFloatingPointType, let N : int, let M : int> __intrinsic matrix<T,N,M> sqrt(matrix<T,N,M> x);
+
+// Step function
+__generic<T : __BuiltinFloatingPointType> __intrinsic T step(T y, T x);
+__generic<T : __BuiltinFloatingPointType, let N : int> __intrinsic vector<T,N> step(vector<T,N> y, vector<T,N> x);
+__generic<T : __BuiltinFloatingPointType, let N : int, let M : int> __intrinsic matrix<T,N,M> step(matrix<T,N,M> y, matrix<T,N,M> x);
+
+// Tangent
+__generic<T : __BuiltinFloatingPointType> __intrinsic T tan(T x);
+__generic<T : __BuiltinFloatingPointType, let N : int> __intrinsic vector<T,N> tan(vector<T,N> x);
+__generic<T : __BuiltinFloatingPointType, let N : int, let M : int> __intrinsic matrix<T,N,M> tan(matrix<T,N,M> x);
+
+// Hyperbolic tangent
+__generic<T : __BuiltinFloatingPointType> __intrinsic T tanh(T x);
+__generic<T : __BuiltinFloatingPointType, let N : int> __intrinsic vector<T,N> tanh(vector<T,N> x);
+__generic<T : __BuiltinFloatingPointType, let N : int, let M : int> __intrinsic matrix<T,N,M> tanh(matrix<T,N,M> x);
+
+// Legacy texture-fetch operations
+
+/*
+__intrinsic float4 tex1D(sampler1D s, float t);
+__intrinsic float4 tex1D(sampler1D s, float t, float ddx, float ddy);
+__intrinsic float4 tex1Dbias(sampler1D s, float4 t);
+__intrinsic float4 tex1Dgrad(sampler1D s, float t, float ddx, float ddy);
+__intrinsic float4 tex1Dlod(sampler1D s, float4 t);
+__intrinsic float4 tex1Dproj(sampler1D s, float4 t);
+
+__intrinsic float4 tex2D(sampler2D s, float2 t);
+__intrinsic float4 tex2D(sampler2D s, float2 t, float2 ddx, float2 ddy);
+__intrinsic float4 tex2Dbias(sampler2D s, float4 t);
+__intrinsic float4 tex2Dgrad(sampler2D s, float2 t, float2 ddx, float2 ddy);
+__intrinsic float4 tex2Dlod(sampler2D s, float4 t);
+__intrinsic float4 tex2Dproj(sampler2D s, float4 t);
+
+__intrinsic float4 tex3D(sampler3D s, float3 t);
+__intrinsic float4 tex3D(sampler3D s, float3 t, float3 ddx, float3 ddy);
+__intrinsic float4 tex3Dbias(sampler3D s, float4 t);
+__intrinsic float4 tex3Dgrad(sampler3D s, float3 t, float3 ddx, float3 ddy);
+__intrinsic float4 tex3Dlod(sampler3D s, float4 t);
+__intrinsic float4 tex3Dproj(sampler3D s, float4 t);
+
+__intrinsic float4 texCUBE(samplerCUBE s, float3 t);
+__intrinsic float4 texCUBE(samplerCUBE s, float3 t, float3 ddx, float3 ddy);
+__intrinsic float4 texCUBEbias(samplerCUBE s, float4 t);
+__intrinsic float4 texCUBEgrad(samplerCUBE s, float3 t, float3 ddx, float3 ddy);
+__intrinsic float4 texCUBElod(samplerCUBE s, float4 t);
+__intrinsic float4 texCUBEproj(samplerCUBE s, float4 t);
+*/
+
+// Matrix transpose
+__generic<T : __BuiltinType, let N : int, let M : int> __intrinsic matrix<T,M,N> transpose(matrix<T,N,M> x);
+
+// Truncate to integer
+__generic<T : __BuiltinFloatingPointType> __intrinsic T trunc(T x);
+__generic<T : __BuiltinFloatingPointType, let N : int> __intrinsic vector<T,N> trunc(vector<T,N> x);
+__generic<T : __BuiltinFloatingPointType, let N : int, let M : int> __intrinsic matrix<T,N,M> trunc(matrix<T,N,M> x);
+
+
+)", R"(
+
+// Shader model 6.0 stuff
+
+__intrinsic uint GlobalOrderedCountIncrement(uint countToAppendForThisLane);
+
+__generic<T : __BuiltinType> __intrinsic T QuadReadLaneAt(T sourceValue, int quadLaneID);
+__generic<T : __BuiltinType, let N : int> __intrinsic vector<T,N> QuadReadLaneAt(vector<T,N> sourceValue, int quadLaneID);
+__generic<T : __BuiltinType, let N : int, let M : int> __intrinsic matrix<T,N,M> QuadReadLaneAt(matrix<T,N,M> sourceValue, int quadLaneID);
+
+__generic<T : __BuiltinType> __intrinsic T QuadSwapX(T localValue);
+__generic<T : __BuiltinType, let N : int> __intrinsic vector<T,N> QuadSwapX(vector<T,N> localValue);
+__generic<T : __BuiltinType, let N : int, let M : int> __intrinsic matrix<T,N,M> QuadSwapX(matrix<T,N,M> localValue);
+
+__generic<T : __BuiltinType> __intrinsic T QuadSwapY(T localValue);
+__generic<T : __BuiltinType, let N : int> __intrinsic vector<T,N> QuadSwapY(vector<T,N> localValue);
+__generic<T : __BuiltinType, let N : int, let M : int> __intrinsic matrix<T,N,M> QuadSwapY(matrix<T,N,M> localValue);
+
+__generic<T : __BuiltinIntegerType> __intrinsic T WaveAllBitAnd(T expr);
+__generic<T : __BuiltinIntegerType, let N : int> __intrinsic vector<T,N> WaveAllBitAnd(vector<T,N> expr);
+__generic<T : __BuiltinIntegerType, let N : int, let M : int> __intrinsic matrix<T,N,M> WaveAllBitAnd(matrix<T,N,M> expr);
+
+__generic<T : __BuiltinIntegerType> __intrinsic T WaveAllBitOr(T expr);
+__generic<T : __BuiltinIntegerType, let N : int> __intrinsic vector<T,N> WaveAllBitOr(vector<T,N> expr);
+__generic<T : __BuiltinIntegerType, let N : int, let M : int> __intrinsic matrix<T,N,M> WaveAllBitOr(matrix<T,N,M> expr);
+
+__generic<T : __BuiltinIntegerType> __intrinsic T WaveAllBitXor(T expr);
+__generic<T : __BuiltinIntegerType, let N : int> __intrinsic vector<T,N> WaveAllBitXor(vector<T,N> expr);
+__generic<T : __BuiltinIntegerType, let N : int, let M : int> __intrinsic matrix<T,N,M> WaveAllBitXor(matrix<T,N,M> expr);
+
+__generic<T : __BuiltinArithmeticType> __intrinsic T WaveAllMax(T expr);
+__generic<T : __BuiltinArithmeticType, let N : int> __intrinsic vector<T,N> WaveAllMax(vector<T,N> expr);
+__generic<T : __BuiltinArithmeticType, let N : int, let M : int> __intrinsic matrix<T,N,M> WaveAllMax(matrix<T,N,M> expr);
+
+__generic<T : __BuiltinArithmeticType> __intrinsic T WaveAllMin(T expr);
+__generic<T : __BuiltinArithmeticType, let N : int> __intrinsic vector<T,N> WaveAllMin(vector<T,N> expr);
+__generic<T : __BuiltinArithmeticType, let N : int, let M : int> __intrinsic matrix<T,N,M> WaveAllMin(matrix<T,N,M> expr);
+
+__generic<T : __BuiltinArithmeticType> __intrinsic T WaveAllProduct(T expr);
+__generic<T : __BuiltinArithmeticType, let N : int> __intrinsic vector<T,N> WaveAllProduct(vector<T,N> expr);
+__generic<T : __BuiltinArithmeticType, let N : int, let M : int> __intrinsic matrix<T,N,M> WaveAllProduct(matrix<T,N,M> expr);
+
+__generic<T : __BuiltinArithmeticType> __intrinsic T WaveAllSum(T expr);
+__generic<T : __BuiltinArithmeticType, let N : int> __intrinsic vector<T,N> WaveAllSum(vector<T,N> expr);
+__generic<T : __BuiltinArithmeticType, let N : int, let M : int> __intrinsic matrix<T,N,M> WaveAllSum(matrix<T,N,M> expr);
+
+__intrinsic bool WaveAllEqual(bool expr);
+__intrinsic bool WaveAllTrue(bool expr);
+__intrinsic bool WaveAnyTrue(bool expr);
+
+uint64_t WaveBallot(bool expr);
+
+uint WaveGetLaneCount();
+uint WaveGetLaneIndex();
+uint WaveGetOrderedIndex();
+
+bool WaveIsHelperLane();
+
+bool WaveOnce();
+
+__generic<T : __BuiltinArithmeticType> __intrinsic T WavePrefixProduct(T expr);
+__generic<T : __BuiltinArithmeticType, let N : int> __intrinsic vector<T,N> WavePrefixProduct(vector<T,N> expr);
+__generic<T : __BuiltinArithmeticType, let N : int, let M : int> __intrinsic matrix<T,N,M> WavePrefixProduct(matrix<T,N,M> expr);
+
+__generic<T : __BuiltinArithmeticType> __intrinsic T WavePrefixSum(T expr);
+__generic<T : __BuiltinArithmeticType, let N : int> __intrinsic vector<T,N> WavePrefixSum(vector<T,N> expr);
+__generic<T : __BuiltinArithmeticType, let N : int, let M : int> __intrinsic matrix<T,N,M> WavePrefixSum(matrix<T,N,M> expr);
+
+__generic<T : __BuiltinType> __intrinsic T WaveReadFirstLane(T expr);
+__generic<T : __BuiltinType, let N : int> __intrinsic vector<T,N> WaveReadFirstLane(vector<T,N> expr);
+__generic<T : __BuiltinType, let N : int, let M : int> __intrinsic matrix<T,N,M> WaveReadFirstLane(matrix<T,N,M> expr);
+
+__generic<T : __BuiltinType> __intrinsic T WaveReadLaneAt(T expr, int laneIndex);
+__generic<T : __BuiltinType, let N : int> __intrinsic vector<T,N> WaveReadLaneAt(vector<T,N> expr, int laneIndex);
+__generic<T : __BuiltinType, let N : int, let M : int> __intrinsic matrix<T,N,M> WaveReadLaneAt(matrix<T,N,M> expr, int laneIndex);
+
+
+)", R"(
+
+// `typedef`s to help with the fact that HLSL has been sorta-kinda case insensitive at various points
+typedef Texture2D texture2D;
+
+#line default
+)" };
+
+
+using namespace CoreLib::Basic;
+
+namespace Slang
+{
+ namespace Compiler
+ {
+ static String stdlibPath;
+
+ String getStdlibPath()
+ {
+ if(stdlibPath.Length() != 0)
+ return stdlibPath;
+
+ StringBuilder pathBuilder;
+ for( auto cc = __FILE__; *cc; ++cc )
+ {
+ switch( *cc )
+ {
+ case '\n':
+ case '\t':
+ case '\\':
+ pathBuilder << "\\";
+ default:
+ pathBuilder << *cc;
+ break;
+ }
+ }
+ stdlibPath = pathBuilder.ProduceString();
+
+ return stdlibPath;
+ }
+
+ String SlangStdLib::code;
+
+ enum
+ {
+ SINT_MASK = 1 << 0,
+ FLOAT_MASK = 1 << 1,
+ COMPARISON = 1 << 2,
+ BOOL_MASK = 1 << 3,
+ UINT_MASK = 1 << 4,
+ ASSIGNMENT = 1 << 5,
+ POSTFIX = 1 << 6,
+
+ INT_MASK = SINT_MASK | UINT_MASK,
+ ARITHMETIC_MASK = INT_MASK | FLOAT_MASK,
+ LOGICAL_MASK = INT_MASK | BOOL_MASK,
+ ANY_MASK = INT_MASK | FLOAT_MASK | BOOL_MASK,
+ };
+
+ String SlangStdLib::GetCode()
+ {
+ if (code.Length() > 0)
+ return code;
+ StringBuilder sb;
+
+ // generate operator overloads
+
+
+
+ struct OpInfo { IntrinsicOp opCode; char const* opName; unsigned flags; };
+
+ OpInfo unaryOps[] = {
+ { IntrinsicOp::Neg, "-", ARITHMETIC_MASK },
+ { IntrinsicOp::Not, "!", ANY_MASK },
+ { IntrinsicOp::Not, "~", INT_MASK },
+ { IntrinsicOp::PreInc, "++", ARITHMETIC_MASK | ASSIGNMENT },
+ { IntrinsicOp::PreDec, "--", ARITHMETIC_MASK | ASSIGNMENT },
+ { IntrinsicOp::PostInc, "++", ARITHMETIC_MASK | ASSIGNMENT | POSTFIX },
+ { IntrinsicOp::PostDec, "--", ARITHMETIC_MASK | ASSIGNMENT | POSTFIX },
+ };
+
+ OpInfo binaryOps[] = {
+ { IntrinsicOp::Add, "+", ARITHMETIC_MASK },
+ { IntrinsicOp::Sub, "-", ARITHMETIC_MASK },
+ { IntrinsicOp::Mul, "*", ARITHMETIC_MASK },
+ { IntrinsicOp::Div, "/", ARITHMETIC_MASK },
+ { IntrinsicOp::Mod, "%", INT_MASK },
+
+ { IntrinsicOp::And, "&&", LOGICAL_MASK },
+ { IntrinsicOp::Or, "||", LOGICAL_MASK },
+
+ { IntrinsicOp::BitAnd, "&", LOGICAL_MASK },
+ { IntrinsicOp::BitOr, "|", LOGICAL_MASK },
+ { IntrinsicOp::BitXor, "^", LOGICAL_MASK },
+
+ { IntrinsicOp::Lsh, "<<", INT_MASK },
+ { IntrinsicOp::Rsh, ">>", INT_MASK },
+
+ { IntrinsicOp::Eql, "==", ANY_MASK | COMPARISON },
+ { IntrinsicOp::Neq, "!=", ANY_MASK | COMPARISON },
+
+ { IntrinsicOp::Greater, ">", ARITHMETIC_MASK | COMPARISON },
+ { IntrinsicOp::Less, "<", ARITHMETIC_MASK | COMPARISON },
+ { IntrinsicOp::Geq, ">=", ARITHMETIC_MASK | COMPARISON },
+ { IntrinsicOp::Leq, "<=", ARITHMETIC_MASK | COMPARISON },
+
+ { IntrinsicOp::AddAssign, "+=", ASSIGNMENT | ARITHMETIC_MASK },
+ { IntrinsicOp::SubAssign, "-=", ASSIGNMENT | ARITHMETIC_MASK },
+ { IntrinsicOp::MulAssign, "*=", ASSIGNMENT | ARITHMETIC_MASK },
+ { IntrinsicOp::DivAssign, "/=", ASSIGNMENT | ARITHMETIC_MASK },
+ { IntrinsicOp::ModAssign, "%=", ASSIGNMENT | ARITHMETIC_MASK },
+ { IntrinsicOp::AndAssign, "&=", ASSIGNMENT | LOGICAL_MASK },
+ { IntrinsicOp::OrAssign, "|=", ASSIGNMENT | LOGICAL_MASK },
+ { IntrinsicOp::XorAssign, "^=", ASSIGNMENT | LOGICAL_MASK },
+ { IntrinsicOp::LshAssign, "<<=", ASSIGNMENT | INT_MASK },
+ { IntrinsicOp::RshAssign, ">>=", ASSIGNMENT | INT_MASK },
+
+
+ };
+
+ /*
+ String floatTypes[] = { "float", "float2", "float3", "float4" };
+ String intTypes[] = { "int", "int2", "int3", "int4" };
+ String uintTypes[] = { "uint", "uint2", "uint3", "uint4" };
+ */
+
+ String path = getStdlibPath();
+
+
+
+#define EMIT_LINE_DIRECTIVE() sb << "#line " << (__LINE__+1) << " \"" << path << "\"\n"
+
+ // Generate declarations for all the base types
+
+ static const struct {
+ char const* name;
+ BaseType tag;
+ unsigned flags;
+ } kBaseTypes[] = {
+ { "void", BaseType::Void, 0 },
+ { "int", BaseType::Int, SINT_MASK },
+ { "float", BaseType::Float, FLOAT_MASK },
+ { "uint", BaseType::UInt, UINT_MASK },
+ { "bool", BaseType::Bool, BOOL_MASK },
+ { "uint64_t", BaseType::UInt64, UINT_MASK },
+ };
+ static const int kBaseTypeCount = sizeof(kBaseTypes) / sizeof(kBaseTypes[0]);
+ for (int tt = 0; tt < kBaseTypeCount; ++tt)
+ {
+ EMIT_LINE_DIRECTIVE();
+ sb << "__builtin_type(" << int(kBaseTypes[tt].tag) << ") struct " << kBaseTypes[tt].name << "\n{\n";
+
+ // Declare trait conformances for this type
+
+ sb << "__conforms __BuiltinType;\n";
+
+ switch( kBaseTypes[tt].tag )
+ {
+ case BaseType::Float:
+ sb << "__conforms __BuiltinFloatingPointType;\n";
+ sb << "__conforms __BuiltinRealType;\n";
+ // fall through to:
+ case BaseType::Int:
+ sb << "__conforms __BuiltinSignedArithmeticType;\n";
+ // fall through to:
+ case BaseType::UInt:
+ case BaseType::UInt64:
+ sb << "__conforms __BuiltinArithmeticType;\n";
+ // fall through to:
+ case BaseType::Bool:
+ sb << "__conforms __BuiltinType;\n";
+ break;
+
+ default:
+ break;
+ }
+
+ // Declare initializers to convert from various other types
+ for( int ss = 0; ss < kBaseTypeCount; ++ss )
+ {
+ if( kBaseTypes[ss].tag == BaseType::Void )
+ continue;
+
+ EMIT_LINE_DIRECTIVE();
+ sb << "__init(" << kBaseTypes[ss].name << " value);\n";
+ }
+
+ sb << "};\n";
+ }
+
+ // Declare ad hoc aliases for some types, just to get things compiling
+ //
+ // TODO(tfoley): At the very least, `double` should be treated as a distinct type.
+ sb << "typedef float double;\n";
+ sb << "typedef float half;\n";
+
+ // Declare vector and matrix types
+
+ sb << "__generic<T = float, let N : int = 4> __magic_type(Vector) struct vector\n{\n";
+ sb << " __init(T value);\n"; // initialize from single scalar
+ sb << "};\n";
+ sb << "__generic<T = float, let R : int = 4, let C : int = 4> __magic_type(Matrix) struct matrix {};\n";
+
+ static const struct {
+ char const* name;
+ char const* glslPrefix;
+ } kTypes[] =
+ {
+ {"float", ""},
+ {"int", "i"},
+ {"uint", "u"},
+ {"bool", "b"},
+ };
+ static const int kTypeCount = sizeof(kTypes) / sizeof(kTypes[0]);
+
+ for (int tt = 0; tt < kTypeCount; ++tt)
+ {
+ // Declare HLSL vector types
+ for (int ii = 1; ii <= 4; ++ii)
+ {
+ sb << "typedef vector<" << kTypes[tt].name << "," << ii << "> " << kTypes[tt].name << ii << ";\n";
+ }
+
+ // Declare HLSL matrix types
+ for (int rr = 2; rr <= 4; ++rr)
+ for (int cc = 2; cc <= 4; ++cc)
+ {
+ sb << "typedef matrix<" << kTypes[tt].name << "," << rr << "," << cc << "> " << kTypes[tt].name << rr << "x" << cc << ";\n";
+ }
+ }
+
+ static const char* kComponentNames[]{ "x", "y", "z", "w" };
+ static const char* kVectorNames[]{ "", "x", "xy", "xyz", "xyzw" };
+
+ // Need to add constructors to the types above
+ for (int N = 2; N <= 4; ++N)
+ {
+ sb << "__generic<T> __extension vector<T, " << N << ">\n{\n";
+
+ // initialize from N scalars
+ sb << "__init(";
+ for (int ii = 0; ii < N; ++ii)
+ {
+ if (ii != 0) sb << ", ";
+ sb << "T " << kComponentNames[ii];
+ }
+ sb << ");\n";
+
+ // Initialize from an M-vector and then scalars
+ for (int M = 2; M < N; ++M)
+ {
+ sb << "__init(vector<T," << M << "> " << kVectorNames[M];
+ for (int ii = M; ii < N; ++ii)
+ {
+ sb << ", T " << kComponentNames[ii];
+ }
+ sb << ");\n";
+ }
+
+ // initialize from another vector of the same size
+ //
+ // TODO(tfoley): this overlaps with implicit conversions.
+ // We should look for a way that we can define implicit
+ // conversions directly in the stdlib instead...
+ sb << "__generic<U> __init(vector<U," << N << ">);\n";
+
+ sb << "}\n";
+ }
+
+ for( int R = 2; R <= 4; ++R )
+ for( int C = 2; C <= 4; ++C )
+ {
+ sb << "__generic<T> __extension matrix<T, " << R << "," << C << ">\n{\n";
+
+ // initialize from R*C scalars
+ sb << "__init(";
+ for( int ii = 0; ii < R; ++ii )
+ for( int jj = 0; jj < C; ++jj )
+ {
+ if ((ii+jj) != 0) sb << ", ";
+ sb << "T m" << ii << jj;
+ }
+ sb << ");\n";
+
+ // Initialize from R C-vectors
+ sb << "__init(";
+ for (int ii = 0; ii < R; ++ii)
+ {
+ if(ii != 0) sb << ", ";
+ sb << "vector<T," << C << "> row" << ii;
+ }
+ sb << ");\n";
+
+
+ // initialize from another matrix of the same size
+ //
+ // TODO(tfoley): See comment about how this overlaps
+ // with implicit conversion, in the `vector` case above
+ sb << "__generic<U> __init(matrix<U," << R << ", " << C << ">);\n";
+
+ sb << "}\n";
+ }
+
+
+ // Declare built-in texture and sampler types
+
+ sb << "__magic_type(SamplerState," << int(SamplerStateType::Flavor::SamplerState) << ") struct SamplerState {};";
+ sb << "__magic_type(SamplerState," << int(SamplerStateType::Flavor::SamplerComparisonState) << ") struct SamplerComparisonState {};";
+
+ // TODO(tfoley): Need to handle `RW*` variants of texture types as well...
+ static const struct {
+ char const* name;
+ TextureType::Shape baseShape;
+ int coordCount;
+ } kBaseTextureTypes[] = {
+ { "Texture1D", TextureType::Shape1D, 1 },
+ { "Texture2D", TextureType::Shape2D, 2 },
+ { "Texture3D", TextureType::Shape3D, 3 },
+ { "TextureCube", TextureType::ShapeCube, 3 },
+ };
+ static const int kBaseTextureTypeCount = sizeof(kBaseTextureTypes) / sizeof(kBaseTextureTypes[0]);
+
+
+ static const struct {
+ char const* name;
+ SlangResourceAccess access;
+ } kBaseTextureAccessLevels[] = {
+ { "", SLANG_RESOURCE_ACCESS_READ },
+ { "RW", SLANG_RESOURCE_ACCESS_READ_WRITE },
+ { "RasterizerOrdered", SLANG_RESOURCE_ACCESS_RASTER_ORDERED },
+ };
+ static const int kBaseTextureAccessLevelCount = sizeof(kBaseTextureAccessLevels) / sizeof(kBaseTextureAccessLevels[0]);
+
+ for (int tt = 0; tt < kBaseTextureTypeCount; ++tt)
+ {
+ char const* name = kBaseTextureTypes[tt].name;
+ TextureType::Shape baseShape = kBaseTextureTypes[tt].baseShape;
+
+ for (int isArray = 0; isArray < 2; ++isArray)
+ {
+ // Arrays of 3D textures aren't allowed
+ if (isArray && baseShape == TextureType::Shape3D) continue;
+
+ for (int isMultisample = 0; isMultisample < 2; ++isMultisample)
+ for (int accessLevel = 0; accessLevel < kBaseTextureAccessLevelCount; ++accessLevel)
+ {
+ auto access = kBaseTextureAccessLevels[accessLevel].access;
+
+ // TODO: any constraints to enforce on what gets to be multisampled?
+
+ unsigned flavor = baseShape;
+ if (isArray) flavor |= TextureType::ArrayFlag;
+ if (isMultisample) flavor |= TextureType::MultisampleFlag;
+// if (isShadow) flavor |= TextureType::ShadowFlag;
+
+ flavor |= (access << 8);
+
+
+ // emit a generic signature
+ // TODO: allow for multisample count to come in as well...
+ sb << "__generic<T = float4> ";
+
+ sb << "__magic_type(Texture," << int(flavor) << ") struct ";
+ sb << kBaseTextureAccessLevels[accessLevel].name;
+ sb << name;
+ if (isMultisample) sb << "MS";
+ if (isArray) sb << "Array";
+// if (isShadow) sb << "Shadow";
+ sb << "\n{";
+
+ if( !isMultisample )
+ {
+ sb << "float CalculateLevelOfDetail(SamplerState s, ";
+ sb << "float" << kBaseTextureTypes[tt].coordCount << " location);\n";
+
+ sb << "float CalculateLevelOfDetailUnclamped(SamplerState s, ";
+ sb << "float" << kBaseTextureTypes[tt].coordCount << " location);\n";
+
+ // TODO: `Gather` operation
+ // (tricky because it returns a 4-vector of the element type
+ // of the texture components...)
+ }
+
+ // TODO: `GetDimensions` operations
+
+ for(int isFloat = 0; isFloat < 2; ++isFloat)
+ for(int includeMipInfo = 0; includeMipInfo < 2; ++includeMipInfo)
+ {
+ char const* t = isFloat ? "out float " : "out UINT ";
+
+ sb << "void GetDimensions(";
+ if(includeMipInfo)
+ sb << "UINT mipLevel, ";
+
+ switch(baseShape)
+ {
+ case TextureType::Shape1D:
+ sb << t << "width";
+ break;
+
+ case TextureType::Shape2D:
+ case TextureType::ShapeCube:
+ sb << t << "width,";
+ sb << t << "height";
+ break;
+
+ case TextureType::Shape3D:
+ sb << t << "width,";
+ sb << t << "height,";
+ sb << t << "depth";
+ break;
+
+ default:
+ assert(!"unexpected");
+ break;
+ }
+
+ if(isArray)
+ {
+ sb << ", " << t << "elements";
+ }
+
+ if(includeMipInfo)
+ sb << ", " << t << "numberOfLevels";
+
+ sb << ");\n";
+ }
+
+ // `GetSamplePosition()`
+ if( isMultisample )
+ {
+ sb << "float2 GetSamplePosition(int s);\n";
+ }
+
+ // `Load()`
+
+ if( kBaseTextureTypes[tt].coordCount + isArray < 4 )
+ {
+ sb << "T Load(";
+ sb << "int" << kBaseTextureTypes[tt].coordCount + isArray + 1 << " location);\n";
+
+ if( !isMultisample )
+ {
+ sb << "T Load(";
+ sb << "int" << kBaseTextureTypes[tt].coordCount + isArray + 1 << " location, ";
+ sb << "int" << kBaseTextureTypes[tt].coordCount << " offset);\n";
+ }
+ else
+ {
+ sb << "T Load(";
+ sb << "int" << kBaseTextureTypes[tt].coordCount + isArray + 1 << " location, ";
+ sb << "int sampleIndex, ";
+ sb << "int" << kBaseTextureTypes[tt].coordCount << " offset);\n";
+ }
+ }
+
+ if(baseShape != TextureType::ShapeCube)
+ {
+ // subscript operator
+ sb << "__intrinsic __subscript(uint" << kBaseTextureTypes[tt].coordCount + isArray << " location) -> T;\n";
+ }
+
+ if( !isMultisample )
+ {
+ // `Sample()`
+
+ sb << "T Sample(SamplerState s, ";
+ sb << "float" << kBaseTextureTypes[tt].coordCount + isArray << " location);\n";
+
+ if( baseShape != TextureType::ShapeCube )
+ {
+ sb << "T Sample(SamplerState s, ";
+ sb << "float" << kBaseTextureTypes[tt].coordCount + isArray << " location, ";
+ sb << "int" << kBaseTextureTypes[tt].coordCount << " offset);\n";
+ }
+
+ sb << "T Sample(SamplerState s, ";
+ sb << "float" << kBaseTextureTypes[tt].coordCount + isArray << " location, ";
+ if( baseShape != TextureType::ShapeCube )
+ {
+ sb << "int" << kBaseTextureTypes[tt].coordCount << " offset, ";
+ }
+ sb << "float clamp);\n";
+
+ sb << "T Sample(SamplerState s, ";
+ sb << "float" << kBaseTextureTypes[tt].coordCount + isArray << " location, ";
+ if( baseShape != TextureType::ShapeCube )
+ {
+ sb << "int" << kBaseTextureTypes[tt].coordCount << " offset, ";
+ }
+ sb << "float clamp, out uint status);\n";
+
+
+ // `SampleBias()`
+ sb << "T SampleBias(SamplerState s, ";
+ sb << "float" << kBaseTextureTypes[tt].coordCount + isArray << " location, float bias);\n";
+
+ if( baseShape != TextureType::ShapeCube )
+ {
+ sb << "T SampleBias(SamplerState s, ";
+ sb << "float" << kBaseTextureTypes[tt].coordCount + isArray << " location, float bias, ";
+ sb << "int" << kBaseTextureTypes[tt].coordCount << " offset);\n";
+ }
+
+ // `SampleCmp()` and `SampleCmpLevelZero`
+ sb << "T SampleCmp(SamplerComparisonState s, ";
+ sb << "float" << kBaseTextureTypes[tt].coordCount + isArray << " location, ";
+ sb << "float compareValue";
+ sb << ");\n";
+
+ sb << "T SampleCmpLevelZero(SamplerComparisonState s, ";
+ sb << "float" << kBaseTextureTypes[tt].coordCount + isArray << " location, ";
+ sb << "float compareValue";
+ sb << ");\n";
+
+ if( baseShape != TextureType::ShapeCube )
+ {
+ // Note(tfoley): MSDN seems confused, and claims that the `offset`
+ // parameter for `SampleCmp` is available for everything but 3D
+ // textures, while `Sample` and `SampleBias` are consistent in
+ // saying they only exclude `offset` for cube maps (which makes
+ // sense). I'm going to assume the documentation for `SampleCmp`
+ // is just wrong.
+
+ sb << "T SampleCmp(SamplerState s, ";
+ sb << "float" << kBaseTextureTypes[tt].coordCount + isArray << " location, ";
+ sb << "float compareValue, ";
+ sb << "int" << kBaseTextureTypes[tt].coordCount << " offset);\n";
+
+ sb << "T SampleCmpLevelZero(SamplerState s, ";
+ sb << "float" << kBaseTextureTypes[tt].coordCount + isArray << " location, ";
+ sb << "float compareValue, ";
+ sb << "int" << kBaseTextureTypes[tt].coordCount << " offset);\n";
+ }
+
+ sb << "T SampleGrad(SamplerState s, ";
+ sb << "float" << kBaseTextureTypes[tt].coordCount + isArray << " location, ";
+ sb << "float" << kBaseTextureTypes[tt].coordCount << " gradX, ";
+ sb << "float" << kBaseTextureTypes[tt].coordCount << " gradY";
+ sb << ");\n";
+
+ if( baseShape != TextureType::ShapeCube )
+ {
+ sb << "T SampleGrad(SamplerState s, ";
+ sb << "float" << kBaseTextureTypes[tt].coordCount + isArray << " location, ";
+ sb << "float" << kBaseTextureTypes[tt].coordCount << " gradX, ";
+ sb << "float" << kBaseTextureTypes[tt].coordCount << " gradY, ";
+ sb << "int" << kBaseTextureTypes[tt].coordCount << " offset);\n";
+ }
+
+ // `SampleLevel`
+
+ sb << "T SampleLevel(SamplerState s, ";
+ sb << "float" << kBaseTextureTypes[tt].coordCount + isArray << " location, ";
+ sb << "float level);\n";
+
+ if( baseShape != TextureType::ShapeCube )
+ {
+ sb << "T SampleLevel(SamplerState s, ";
+ sb << "float" << kBaseTextureTypes[tt].coordCount + isArray << " location, ";
+ sb << "float level, ";
+ sb << "int" << kBaseTextureTypes[tt].coordCount << " offset);\n";
+ }
+ }
+
+ sb << "\n};\n";
+ }
+ }
+ }
+
+ // Declare additional built-in generic types
+
+ sb << "__generic<T> __magic_type(ConstantBuffer) struct ConstantBuffer {};\n";
+ sb << "__generic<T> __magic_type(TextureBuffer) struct TextureBuffer {};\n";
+
+ sb << "__generic<T> __magic_type(PackedBuffer) struct PackedBuffer {};\n";
+ sb << "__generic<T> __magic_type(Uniform) struct Uniform {};\n";
+ sb << "__generic<T> __magic_type(Patch) struct Patch {};\n";
+
+
+ // Stale declarations for GLSL inner-product builtins
+#if 0
+ sb << "__intrinsic vec3 operator * (vec3, mat3);\n";
+ sb << "__intrinsic vec3 operator * (mat3, vec3);\n";
+
+ sb << "__intrinsic vec4 operator * (vec4, mat4);\n";
+ sb << "__intrinsic vec4 operator * (mat4, vec4);\n";
+
+ sb << "__intrinsic mat3 operator * (mat3, mat3);\n";
+ sb << "__intrinsic mat4 operator * (mat4, mat4);\n";
+#endif
+
+#if 0
+ sb << "__intrinsic(And) bool operator && (bool, bool);\n";
+ sb << "__intrinsic(Or) bool operator || (bool, bool);\n";
+
+ for (auto type : intTypes)
+ {
+ sb << "__intrinsic(And) bool operator && (bool, " << type << ");\n";
+ sb << "__intrinsic(Or) bool operator || (bool, " << type << ");\n";
+ sb << "__intrinsic(And) bool operator && (" << type << ", bool);\n";
+ sb << "__intrinsic(Or) bool operator || (" << type << ", bool);\n";
+ }
+#endif
+
+ for (auto op : unaryOps)
+ {
+ for (auto type : kBaseTypes)
+ {
+ if ((type.flags & op.flags) == 0)
+ continue;
+
+ char const* fixity = (op.flags & POSTFIX) != 0 ? "__postfix " : "__prefix ";
+ char const* qual = (op.flags & ASSIGNMENT) != 0 ? "in out " : "";
+
+ // scalar version
+ sb << fixity;
+ sb << "__intrinsic(" << int(op.opCode) << ") " << type.name << " operator" << op.opName << "(" << qual << type.name << " value);\n";
+
+ // vector version
+ sb << "__generic<let N : int> ";
+ sb << fixity;
+ sb << "__intrinsic(" << int(op.opCode) << ") vector<" << type.name << ",N> operator" << op.opName << "(" << qual << "vector<" << type.name << ",N> value);\n";
+
+ // matrix version
+ sb << "__generic<let N : int, let M : int> ";
+ sb << fixity;
+ sb << "__intrinsic(" << int(op.opCode) << ") matrix<" << type.name << ",N,M> operator" << op.opName << "(" << qual << "matrix<" << type.name << ",N,M> value);\n";
+ }
+ }
+
+ for (auto op : binaryOps)
+ {
+ for (auto type : kBaseTypes)
+ {
+ if ((type.flags & op.flags) == 0)
+ continue;
+
+ char const* leftType = type.name;
+ char const* rightType = leftType;
+ char const* resultType = leftType;
+
+ if (op.flags & COMPARISON) resultType = "bool";
+
+ char const* leftQual = "";
+ if(op.flags & ASSIGNMENT) leftQual = "in out ";
+
+ // TODO: handle `SHIFT`
+
+ // scalar version
+ sb << "__intrinsic(" << int(op.opCode) << ") " << resultType << " operator" << op.opName << "(" << leftQual << leftType << " left, " << rightType << " right);\n";
+
+ // vector version
+ sb << "__generic<let N : int> ";
+ sb << "__intrinsic(" << int(op.opCode) << ") vector<" << resultType << ",N> operator" << op.opName << "(" << leftQual << "vector<" << leftType << ",N> left, vector<" << rightType << ",N> right);\n";
+
+ // matrix version
+ sb << "__generic<let N : int, let M : int> ";
+ sb << "__intrinsic(" << int(op.opCode) << ") matrix<" << resultType << ",N,M> operator" << op.opName << "(" << leftQual << "matrix<" << leftType << ",N,M> left, matrix<" << rightType << ",N,M> right);\n";
+ }
+ }
+
+#if 0
+ for (auto op : intUnaryOps)
+ {
+ String opName = GetOperatorFunctionName(op);
+ for (int i = 0; i < 4; i++)
+ {
+ auto itype = intTypes[i];
+ auto utype = uintTypes[i];
+ for (int j = 0; j < 2; j++)
+ {
+ auto retType = (op == Operator::Not) ? "bool" : j == 0 ? itype : utype;
+ sb << "__intrinsic " << retType << " operator " << opName << "(" << (j == 0 ? itype : utype) << ");\n";
+ }
+ }
+ }
+
+ for (auto op : floatUnaryOps)
+ {
+ String opName = GetOperatorFunctionName(op);
+ for (int i = 0; i < 4; i++)
+ {
+ auto type = floatTypes[i];
+ auto retType = (op == Operator::Not) ? "bool" : type;
+ sb << "__intrinsic " << retType << " operator " << opName << "(" << type << ");\n";
+ }
+ }
+
+ for (auto op : floatOps)
+ {
+ String opName = GetOperatorFunctionName(op);
+ for (int i = 0; i < 4; i++)
+ {
+ auto type = floatTypes[i];
+ auto itype = intTypes[i];
+ auto utype = uintTypes[i];
+ auto retType = ((op >= Operator::Eql && op <= Operator::Leq) || op == Operator::And || op == Operator::Or) ? "bool" : type;
+ sb << "__intrinsic " << retType << " operator " << opName << "(" << type << ", " << type << ");\n";
+ sb << "__intrinsic " << retType << " operator " << opName << "(" << itype << ", " << type << ");\n";
+ sb << "__intrinsic " << retType << " operator " << opName << "(" << utype << ", " << type << ");\n";
+ sb << "__intrinsic " << retType << " operator " << opName << "(" << type << ", " << itype << ");\n";
+ sb << "__intrinsic " << retType << " operator " << opName << "(" << type << ", " << utype << ");\n";
+ if (i > 0)
+ {
+ sb << "__intrinsic " << retType << " operator " << opName << "(" << type << ", " << floatTypes[0] << ");\n";
+ sb << "__intrinsic " << retType << " operator " << opName << "(" << floatTypes[0] << ", " << type << ");\n";
+
+ sb << "__intrinsic " << retType << " operator " << opName << "(" << type << ", " << intTypes[0] << ");\n";
+ sb << "__intrinsic " << retType << " operator " << opName << "(" << intTypes[0] << ", " << type << ");\n";
+
+ sb << "__intrinsic " << retType << " operator " << opName << "(" << type << ", " << uintTypes[0] << ");\n";
+ sb << "__intrinsic " << retType << " operator " << opName << "(" << uintTypes[0] << ", " << type << ");\n";
+ }
+ }
+ }
+
+ for (auto op : intOps)
+ {
+ String opName = GetOperatorFunctionName(op);
+ for (int i = 0; i < 4; i++)
+ {
+ auto type = intTypes[i];
+ auto utype = uintTypes[i];
+ auto retType = ((op >= Operator::Eql && op <= Operator::Leq) || op == Operator::And || op == Operator::Or) ? "bool" : type;
+ sb << "__intrinsic " << retType << " operator " << opName << "(" << type << ", " << type << ");\n";
+ sb << "__intrinsic " << retType << " operator " << opName << "(" << utype << ", " << type << ");\n";
+ sb << "__intrinsic " << retType << " operator " << opName << "(" << type << ", " << utype << ");\n";
+ sb << "__intrinsic " << retType << " operator " << opName << "(" << utype << ", " << utype << ");\n";
+ if (i > 0)
+ {
+ sb << "__intrinsic " << retType << " operator " << opName << "(" << type << ", " << intTypes[0] << ");\n";
+ sb << "__intrinsic " << retType << " operator " << opName << "(" << intTypes[0] << ", " << type << ");\n";
+
+ sb << "__intrinsic " << retType << " operator " << opName << "(" << type << ", " << uintTypes[0] << ");\n";
+ sb << "__intrinsic " << retType << " operator " << opName << "(" << uintTypes[0] << ", " << type << ");\n";
+ }
+ }
+ }
+#endif
+
+ // Output a suitable `#line` directive to point at our raw stdlib code above
+ sb << "\n#line " << kLibIncludeStringLine << " \"" << path << "\"\n";
+
+ int chunkCount = sizeof(LibIncludeStringChunks) / sizeof(LibIncludeStringChunks[0]);
+ for (int cc = 0; cc < chunkCount; ++cc)
+ {
+ sb << LibIncludeStringChunks[cc];
+ }
+
+ code = sb.ProduceString();
+ return code;
+ }
+
+
+ // GLSL-specific library code
+
+ String glslLibraryCode;
+
+ String getGLSLLibraryCode()
+ {
+ if(glslLibraryCode.Length() != 0)
+ return glslLibraryCode;
+
+ String path = getStdlibPath();
+
+ StringBuilder sb;
+
+#define RAW(TEXT) \
+ EMIT_LINE_DIRECTIVE(); \
+ sb << TEXT;
+
+ static const struct {
+ char const* name;
+ char const* glslPrefix;
+ } kTypes[] =
+ {
+ {"float", ""},
+ {"int", "i"},
+ {"uint", "u"},
+ {"bool", "b"},
+ };
+ static const int kTypeCount = sizeof(kTypes) / sizeof(kTypes[0]);
+
+ for( int tt = 0; tt < kTypeCount; ++tt )
+ {
+ // Declare GLSL aliases for HLSL types
+ for (int vv = 2; vv <= 4; ++vv)
+ {
+ sb << "typedef " << kTypes[tt].name << vv << " " << kTypes[tt].glslPrefix << "vec" << vv << ";\n";
+ sb << "typedef " << kTypes[tt].name << vv << "x" << vv << " " << kTypes[tt].glslPrefix << "mat" << vv << ";\n";
+ }
+ for (int rr = 2; rr <= 4; ++rr)
+ for (int cc = 2; cc <= 4; ++cc)
+ {
+ sb << "typedef " << kTypes[tt].name << rr << "x" << cc << " " << kTypes[tt].glslPrefix << "mat" << rr << "x" << cc << ";\n";
+ }
+ }
+
+ // TODO(tfoley): Need to handle `RW*` variants of texture types as well...
+ static const struct {
+ char const* name;
+ TextureType::Shape baseShape;
+ int coordCount;
+ } kBaseTextureTypes[] = {
+ { "1D", TextureType::Shape1D, 1 },
+ { "2D", TextureType::Shape2D, 2 },
+ { "3D", TextureType::Shape3D, 3 },
+ { "Cube", TextureType::ShapeCube, 3 },
+ };
+ static const int kBaseTextureTypeCount = sizeof(kBaseTextureTypes) / sizeof(kBaseTextureTypes[0]);
+
+
+ static const struct {
+ char const* name;
+ SlangResourceAccess access;
+ } kBaseTextureAccessLevels[] = {
+ { "", SLANG_RESOURCE_ACCESS_READ },
+ { "RW", SLANG_RESOURCE_ACCESS_READ_WRITE },
+ { "RasterizerOrdered", SLANG_RESOURCE_ACCESS_RASTER_ORDERED },
+ };
+ static const int kBaseTextureAccessLevelCount = sizeof(kBaseTextureAccessLevels) / sizeof(kBaseTextureAccessLevels[0]);
+
+ for (int tt = 0; tt < kBaseTextureTypeCount; ++tt)
+ {
+ char const* shapeName = kBaseTextureTypes[tt].name;
+ TextureType::Shape baseShape = kBaseTextureTypes[tt].baseShape;
+
+ for (int isArray = 0; isArray < 2; ++isArray)
+ {
+ // Arrays of 3D textures aren't allowed
+ if (isArray && baseShape == TextureType::Shape3D) continue;
+
+ for (int isMultisample = 0; isMultisample < 2; ++isMultisample)
+ {
+ auto access = SLANG_RESOURCE_ACCESS_READ;
+
+ // TODO: any constraints to enforce on what gets to be multisampled?
+
+
+ unsigned flavor = baseShape;
+ if (isArray) flavor |= TextureType::ArrayFlag;
+ if (isMultisample) flavor |= TextureType::MultisampleFlag;
+// if (isShadow) flavor |= TextureType::ShadowFlag;
+
+ flavor |= (access << 8);
+
+ StringBuilder nameBuilder;
+ nameBuilder << shapeName;
+ if (isMultisample) nameBuilder << "MS";
+ if (isArray) nameBuilder << "Array";
+ auto name = nameBuilder.ProduceString();
+
+ sb << "__generic<T> ";
+ sb << "__magic_type(TextureSampler," << int(flavor) << ") struct ";
+ sb << "__sampler" << name;
+ sb << " {};\n";
+
+ sb << "__generic<T> ";
+ sb << "__magic_type(Texture," << int(flavor) << ") struct ";
+ sb << "__texture" << name;
+ sb << " {};\n";
+
+ sb << "__generic<T> ";
+ sb << "__magic_type(GLSLImageType," << int(flavor) << ") struct ";
+ sb << "__image" << name;
+ sb << " {};\n";
+
+ // TODO(tfoley): flesh this out for all the available prefixes
+ static const struct
+ {
+ char const* prefix;
+ char const* elementType;
+ } kTextureElementTypes[] = {
+ { "", "vec4" },
+ { "i", "ivec4" },
+ { "u", "uvec4" },
+ { nullptr, nullptr },
+ };
+ for( auto ee = kTextureElementTypes; ee->prefix; ++ee )
+ {
+ sb << "typedef __sampler" << name << "<" << ee->elementType << "> " << ee->prefix << "sampler" << name << ";\n";
+ sb << "typedef __texture" << name << "<" << ee->elementType << "> " << ee->prefix << "texture" << name << ";\n";
+ sb << "typedef __image" << name << "<" << ee->elementType << "> " << ee->prefix << "image" << name << ";\n";
+ }
+ }
+ }
+ }
+
+ sb << "__generic<T> __magic_type(GLSLInputParameterBlockType) struct __GLSLInputParameterBlock {};\n";
+ sb << "__generic<T> __magic_type(GLSLOutputParameterBlockType) struct __GLSLOutputParameterBlock {};\n";
+ sb << "__generic<T> __magic_type(GLSLShaderStorageBufferType) struct __GLSLShaderStorageBuffer {};\n";
+
+ sb << "__magic_type(SamplerState," << int(SamplerStateType::Flavor::SamplerState) << ") struct sampler {};";
+
+ sb << "__magic_type(GLSLInputAttachmentType) struct subpassInput {};";
+
+ // Define additional keywords
+ sb << "__modifier(GLSLBufferModifier) buffer;\n";
+ sb << "__modifier(GLSLWriteOnlyModifier) writeonly;\n";
+ sb << "__modifier(GLSLReadOnlyModifier) readonly;\n";
+ sb << "__modifier(GLSLPatchModifier) patch;\n";
+
+ sb << "__modifier(SimpleModifier) flat;\n";
+
+ glslLibraryCode = sb.ProduceString();
+ return glslLibraryCode;
+ }
+
+
+
+ //
+
+ void SlangStdLib::Finalize()
+ {
+ code = nullptr;
+ stdlibPath = String();
+ glslLibraryCode = String();
+ }
+
+ }
+}
+
diff --git a/source/slang/slang-stdlib.h b/source/slang/slang-stdlib.h
new file mode 100644
index 000000000..65c70ecb5
--- /dev/null
+++ b/source/slang/slang-stdlib.h
@@ -0,0 +1,23 @@
+#ifndef SHADER_COMPILER_STD_LIB_H
+#define SHADER_COMPILER_STD_LIB_H
+
+#include "../core/basic.h"
+
+namespace Slang
+{
+ namespace Compiler
+ {
+ class SlangStdLib
+ {
+ private:
+ static CoreLib::String code;
+ public:
+ static CoreLib::String GetCode();
+ static void Finalize();
+ };
+
+ CoreLib::String getGLSLLibraryCode();
+ }
+}
+
+#endif \ No newline at end of file
diff --git a/source/slang/slang.cpp b/source/slang/slang.cpp
new file mode 100644
index 000000000..2f37981c5
--- /dev/null
+++ b/source/slang/slang.cpp
@@ -0,0 +1,699 @@
+#include "../../slang.h"
+
+#include "../core/slang-io.h"
+#include "../slang/slang-stdlib.h"
+#include "../slang/parser.h"
+#include "../slang/preprocessor.h"
+#include "../slang/reflection.h"
+#include "../slang/type-layout.h"
+
+#ifdef _WIN32
+#define WIN32_LEAN_AND_MEAN
+#define NOMINMAX
+#include <Windows.h>
+#undef WIN32_LEAN_AND_MEAN
+#undef NOMINMAX
+#endif
+
+using namespace CoreLib::Basic;
+using namespace CoreLib::IO;
+using namespace Slang::Compiler;
+
+namespace SlangLib
+{
+ static void stdlibDiagnosticCallback(
+ char const* message,
+ void* userData)
+ {
+ fputs(message, stderr);
+ fflush(stderr);
+#ifdef WIN32
+ OutputDebugStringA(message);
+#endif
+ }
+
+ class Session
+ {
+ public:
+ bool useCache = false;
+ CoreLib::String cacheDir;
+
+ RefPtr<ShaderCompiler> compiler;
+
+ RefPtr<Scope> slangLanguageScope;
+ RefPtr<Scope> hlslLanguageScope;
+ RefPtr<Scope> glslLanguageScope;
+
+ List<RefPtr<ProgramSyntaxNode>> loadedModuleCode;
+
+
+ Session(bool /*pUseCache*/, CoreLib::String /*pCacheDir*/)
+ {
+ compiler = CreateShaderCompiler();
+
+ // Create scopes for various language builtins.
+ //
+ // TODO: load these on-demand to avoid parsing
+ // stdlib code for languages the user won't use.
+
+ slangLanguageScope = new Scope();
+
+ hlslLanguageScope = new Scope();
+ hlslLanguageScope->parent = slangLanguageScope;
+
+ glslLanguageScope = new Scope();
+ glslLanguageScope->parent = slangLanguageScope;
+
+ addBuiltinSource(slangLanguageScope, "stdlib", SlangStdLib::GetCode());
+ addBuiltinSource(glslLanguageScope, "glsl", getGLSLLibraryCode());
+ }
+
+ ~Session()
+ {
+ // We need to clean up the strings for the standard library
+ // code that we might have allocated and loaded into static
+ // variables (TODO: don't use `static` variables for this stuff)
+
+ SlangStdLib::Finalize();
+
+ // Ditto for our type represnetation stuff
+
+ ExpressionType::Finalize();
+ }
+
+ CompileUnit createPredefUnit()
+ {
+ CompileUnit translationUnit;
+
+
+ RefPtr<ProgramSyntaxNode> translationUnitSyntax = new ProgramSyntaxNode();
+
+ TranslationUnitOptions translationUnitOptions;
+ translationUnit.options = translationUnitOptions;
+ translationUnit.SyntaxNode = translationUnitSyntax;
+
+ return translationUnit;
+ }
+
+ void addBuiltinSource(
+ RefPtr<Scope> const& scope,
+ String const& path,
+ String const& source);
+ };
+
+ struct CompileRequest
+ {
+ // Pointer to parent session
+ Session* mSession;
+
+ // Input options
+ CompileOptions Options;
+
+ // Output stuff
+ DiagnosticSink mSink;
+ String mDiagnosticOutput;
+
+ RefPtr<CollectionOfTranslationUnits> mCollectionOfTranslationUnits;
+
+ RefPtr<ProgramLayout> mReflectionData;
+
+ CompileResult mResult;
+
+ List<String> mDependencyFilePaths;
+
+ CompileRequest(Session* session)
+ : mSession(session)
+ {}
+
+ ~CompileRequest()
+ {}
+
+ struct IncludeHandlerImpl : IncludeHandler
+ {
+ CompileRequest* request;
+
+ List<String> searchDirs;
+
+ virtual bool TryToFindIncludeFile(
+ CoreLib::String const& pathToInclude,
+ CoreLib::String const& pathIncludedFrom,
+ CoreLib::String* outFoundPath,
+ CoreLib::String* outFoundSource) override
+ {
+ String path = Path::Combine(Path::GetDirectoryName(pathIncludedFrom), pathToInclude);
+ if (File::Exists(path))
+ {
+ *outFoundPath = path;
+ *outFoundSource = File::ReadAllText(path);
+
+ request->mDependencyFilePaths.Add(path);
+
+ return true;
+ }
+
+ for (auto & dir : searchDirs)
+ {
+ path = Path::Combine(dir, pathToInclude);
+ if (File::Exists(path))
+ {
+ *outFoundPath = path;
+ *outFoundSource = File::ReadAllText(path);
+
+ request->mDependencyFilePaths.Add(path);
+
+ return true;
+ }
+ }
+ return false;
+ }
+ };
+
+
+ CompileUnit parseTranslationUnit(
+ TranslationUnitOptions const& translationUnitOptions)
+ {
+ auto& options = Options;
+
+ IncludeHandlerImpl includeHandler;
+ includeHandler.request = this;
+
+ CompileUnit translationUnit;
+
+ RefPtr<Scope> languageScope;
+ switch( translationUnitOptions.sourceLanguage )
+ {
+ case SourceLanguage::HLSL:
+ languageScope = mSession->hlslLanguageScope;
+ break;
+
+ case SourceLanguage::GLSL:
+ languageScope = mSession->glslLanguageScope;
+ break;
+
+ case SourceLanguage::Slang:
+ default:
+ languageScope = mSession->slangLanguageScope;
+ break;
+ }
+
+
+ auto& preprocesorDefinitions = options.PreprocessorDefinitions;
+
+ RefPtr<ProgramSyntaxNode> translationUnitSyntax = new ProgramSyntaxNode();
+
+ for( auto sourceFile : translationUnitOptions.sourceFiles )
+ {
+ auto sourceFilePath = sourceFile->path;
+
+ auto searchDirs = options.SearchDirectories;
+ searchDirs.Reverse();
+ searchDirs.Add(Path::GetDirectoryName(sourceFilePath));
+ searchDirs.Reverse();
+ includeHandler.searchDirs = searchDirs;
+
+ String source = sourceFile->content;
+
+ auto tokens = preprocessSource(
+ source,
+ sourceFilePath,
+ mResult.GetErrorWriter(),
+ &includeHandler,
+ preprocesorDefinitions,
+ translationUnitSyntax.Ptr());
+
+ parseSourceFile(
+ translationUnitSyntax.Ptr(),
+ options,
+ tokens,
+ mResult.GetErrorWriter(),
+ sourceFilePath,
+ languageScope);
+ }
+
+ translationUnit.options = translationUnitOptions;
+ translationUnit.SyntaxNode = translationUnitSyntax;
+
+ return translationUnit;
+ }
+
+ int executeCompilerDriverActions()
+ {
+ // If we are being asked to do pass-through, then we need to do that here...
+ if (Options.passThrough != PassThroughMode::None)
+ {
+ for( auto& translationUnitOptions : Options.translationUnits )
+ {
+ switch( translationUnitOptions.sourceLanguage )
+ {
+ // We can pass-through code written in a native shading language
+ case SourceLanguage::GLSL:
+ case SourceLanguage::HLSL:
+ break;
+
+ // All other translation units need to be skipped
+ default:
+ continue;
+ }
+
+ auto sourceFile = translationUnitOptions.sourceFiles[0];
+ auto sourceFilePath = sourceFile->path;
+ String source = sourceFile->content;
+
+ mSession->compiler->PassThrough(
+ source,
+ sourceFilePath,
+ Options,
+ translationUnitOptions);
+ }
+ return 0;
+ }
+
+ // TODO: load the stdlib
+
+ mCollectionOfTranslationUnits = new CollectionOfTranslationUnits();
+
+ // Parse everything from the input files requested
+ //
+ // TODO: this may trigger the loading and/or compilation of additional modules.
+ for( auto& translationUnitOptions : Options.translationUnits )
+ {
+ auto translationUnit = parseTranslationUnit(translationUnitOptions);
+ mCollectionOfTranslationUnits->translationUnits.Add(translationUnit);
+ }
+ if( mResult.GetErrorCount() != 0 )
+ return 1;
+
+ // Now perform semantic checks, emit output, etc.
+ mSession->compiler->Compile(
+ mResult, mCollectionOfTranslationUnits.Ptr(), Options);
+ if(mResult.GetErrorCount() != 0)
+ return 1;
+
+ mReflectionData = mCollectionOfTranslationUnits->layout;
+
+ return 0;
+ }
+
+ // Act as expected of the API-based compiler
+ int executeAPIActions()
+ {
+ mResult.mSink = &mSink;
+
+ int err = executeCompilerDriverActions();
+
+ mDiagnosticOutput = mSink.outputBuffer.ProduceString();
+
+ if(mSink.GetErrorCount() != 0)
+ return mSink.GetErrorCount();
+
+ return err;
+ }
+
+ int addTranslationUnit(SourceLanguage language, String const& name)
+ {
+ int result = Options.translationUnits.Count();
+
+ TranslationUnitOptions translationUnit;
+ translationUnit.sourceLanguage = SourceLanguage(language);
+
+ Options.translationUnits.Add(translationUnit);
+
+ return result;
+ }
+
+ void addTranslationUnitSourceString(
+ int translationUnitIndex,
+ String const& path,
+ String const& source)
+ {
+ RefPtr<SourceFile> sourceFile = new SourceFile();
+ sourceFile->path = path;
+ sourceFile->content = source;
+
+ Options.translationUnits[translationUnitIndex].sourceFiles.Add(sourceFile);
+ }
+
+ void addTranslationUnitSourceFile(
+ int translationUnitIndex,
+ String const& path)
+ {
+ String source;
+ try
+ {
+ source = File::ReadAllText(path);
+ }
+ catch( ... )
+ {
+ // Emit a diagnostic!
+ mSink.diagnose(
+ CodePosition(0,0,0,path),
+ Diagnostics::cannotOpenFile,
+ path);
+ return;
+ }
+
+ addTranslationUnitSourceString(
+ translationUnitIndex,
+ path,
+ source);
+
+ mDependencyFilePaths.Add(path);
+ }
+
+ int addTranslationUnitEntryPoint(
+ int translationUnitIndex,
+ String const& name,
+ Profile profile)
+ {
+ EntryPointOption entryPoint;
+ entryPoint.name = name;
+ entryPoint.profile = profile;
+
+ // TODO: realistically want this to be global across all TUs...
+ int result = Options.translationUnits[translationUnitIndex].entryPoints.Count();
+
+ Options.translationUnits[translationUnitIndex].entryPoints.Add(entryPoint);
+ return result;
+ }
+ };
+
+ void Session::addBuiltinSource(
+ RefPtr<Scope> const& scope,
+ String const& path,
+ String const& source)
+ {
+ CompileRequest compileRequest(this);
+
+ auto translationUnitIndex = compileRequest.addTranslationUnit(SourceLanguage::Slang, path);
+
+ compileRequest.addTranslationUnitSourceString(
+ translationUnitIndex,
+ path,
+ source);
+
+ int err = compileRequest.executeAPIActions();
+ if(err)
+ {
+ fprintf(stderr, "%s", compileRequest.mDiagnosticOutput.Buffer());
+
+#ifdef _WIN32
+ OutputDebugStringA(compileRequest.mDiagnosticOutput.Buffer());
+#endif
+
+ assert(!"error in stdlib");
+ }
+
+ // Extract the AST for the code we just parsed
+ auto syntax = compileRequest.mCollectionOfTranslationUnits->translationUnits[translationUnitIndex].SyntaxNode;
+
+ // HACK(tfoley): mark all declarations in the "stdlib" so
+ // that we can detect them later (e.g., so we don't emit them)
+ for (auto m : syntax->Members)
+ {
+ auto fromStdLibModifier = new FromStdLibModifier();
+
+ fromStdLibModifier->next = m->modifiers.first;
+ m->modifiers.first = fromStdLibModifier;
+ }
+
+ // Add the resulting code to the appropriate scope
+ if( !scope->containerDecl )
+ {
+ // We are the first chunk of code to be loaded for this scope
+ scope->containerDecl = syntax.Ptr();
+ }
+ else
+ {
+ // We need to create a new scope to link into the whole thing
+ auto subScope = new Scope();
+ subScope->containerDecl = syntax.Ptr();
+ subScope->nextSibling = scope->nextSibling;
+ scope->nextSibling = subScope;
+ }
+
+ // We need to retain this AST so that we can use it in other code
+ // (Note that the `Scope` type does not retain the AST it points to)
+ loadedModuleCode.Add(syntax);
+ }
+}
+
+using namespace SlangLib;
+
+// implementation of C interface
+
+#define SESSION(x) reinterpret_cast<SlangLib::Session *>(x)
+#define REQ(x) reinterpret_cast<SlangLib::CompileRequest*>(x)
+
+SLANG_API SlangSession* spCreateSession(const char * cacheDir)
+{
+ return reinterpret_cast<SlangSession *>(new SlangLib::Session((cacheDir ? true : false), cacheDir));
+}
+
+SLANG_API void spDestroySession(
+ SlangSession* session)
+{
+ if(!session) return;
+ delete SESSION(session);
+}
+
+SLANG_API void spAddBuiltins(
+ SlangSession* session,
+ char const* sourcePath,
+ char const* sourceString)
+{
+ auto s = SESSION(session);
+ s->addBuiltinSource(
+
+ // TODO(tfoley): Add ability to directly new builtins to the approriate scope
+ s->slangLanguageScope,
+
+ sourcePath,
+ sourceString);
+}
+
+
+SLANG_API SlangCompileRequest* spCreateCompileRequest(
+ SlangSession* session)
+{
+ auto s = SESSION(session);
+ auto req = new SlangLib::CompileRequest(s);
+ return reinterpret_cast<SlangCompileRequest*>(req);
+}
+
+/*!
+@brief Destroy a compile request.
+*/
+SLANG_API void spDestroyCompileRequest(
+ SlangCompileRequest* request)
+{
+ if(!request) return;
+ auto req = REQ(request);
+ delete req;
+}
+
+SLANG_API void spSetCompileFlags(
+ SlangCompileRequest* request,
+ SlangCompileFlags flags)
+{
+ REQ(request)->Options.flags = flags;
+}
+
+SLANG_API void spSetCodeGenTarget(
+ SlangCompileRequest* request,
+ int target)
+{
+ REQ(request)->Options.Target = (CodeGenTarget)target;
+}
+
+SLANG_API void spSetPassThrough(
+ SlangCompileRequest* request,
+ SlangPassThrough passThrough)
+{
+ REQ(request)->Options.passThrough = PassThroughMode(passThrough);
+}
+
+SLANG_API void spSetDiagnosticCallback(
+ SlangCompileRequest* request,
+ SlangDiagnosticCallback callback,
+ void const* userData)
+{
+ if(!request) return;
+ auto req = REQ(request);
+
+ req->mSink.callback = callback;
+ req->mSink.callbackUserData = (void*) userData;
+}
+
+SLANG_API void spAddSearchPath(
+ SlangCompileRequest* request,
+ const char* searchDir)
+{
+ REQ(request)->Options.SearchDirectories.Add(searchDir);
+}
+
+SLANG_API void spAddPreprocessorDefine(
+ SlangCompileRequest* request,
+ const char* key,
+ const char* value)
+{
+ REQ(request)->Options.PreprocessorDefinitions[key] = value;
+}
+
+SLANG_API char const* spGetDiagnosticOutput(
+ SlangCompileRequest* request)
+{
+ if(!request) return 0;
+ auto req = REQ(request);
+ return req->mDiagnosticOutput.begin();
+}
+
+// New-fangled compilation API
+
+SLANG_API int spAddTranslationUnit(
+ SlangCompileRequest* request,
+ SlangSourceLanguage language,
+ char const* name)
+{
+ auto req = REQ(request);
+
+ return req->addTranslationUnit(
+ SourceLanguage(language),
+ name ? name : "");
+}
+
+SLANG_API void spAddTranslationUnitSourceFile(
+ SlangCompileRequest* request,
+ int translationUnitIndex,
+ char const* path)
+{
+ if(!request) return;
+ auto req = REQ(request);
+ if(!path) return;
+ if(translationUnitIndex < 0) return;
+ if(translationUnitIndex >= req->Options.translationUnits.Count()) return;
+
+ req->addTranslationUnitSourceFile(
+ translationUnitIndex,
+ path);
+}
+
+// Add a source string to the given translation unit
+SLANG_API void spAddTranslationUnitSourceString(
+ SlangCompileRequest* request,
+ int translationUnitIndex,
+ char const* path,
+ char const* source)
+{
+ if(!request) return;
+ auto req = REQ(request);
+ if(!source) return;
+ if(translationUnitIndex < 0) return;
+ if(translationUnitIndex >= req->Options.translationUnits.Count()) return;
+
+ if(!path) path = "";
+
+ req->addTranslationUnitSourceString(
+ translationUnitIndex,
+ path,
+ source);
+
+}
+
+SLANG_API SlangProfileID spFindProfile(
+ SlangSession* session,
+ char const* name)
+{
+ return Profile::LookUp(name).raw;
+}
+
+SLANG_API int spAddTranslationUnitEntryPoint(
+ SlangCompileRequest* request,
+ int translationUnitIndex,
+ char const* name,
+ SlangProfileID profile)
+{
+ if(!request) return -1;
+ auto req = REQ(request);
+ if(!name) return -1;
+ if(translationUnitIndex < 0) return -1;
+ if(translationUnitIndex >= req->Options.translationUnits.Count()) return -1;
+
+
+ return req->addTranslationUnitEntryPoint(
+ translationUnitIndex,
+ name,
+ Profile(Profile::RawVal(profile)));
+}
+
+
+// Compile in a context that already has its translation units specified
+SLANG_API int spCompile(
+ SlangCompileRequest* request)
+{
+ auto req = REQ(request);
+
+ int anyErrors = req->executeAPIActions();
+ return anyErrors;
+}
+
+SLANG_API int
+spGetDependencyFileCount(
+ SlangCompileRequest* request)
+{
+ if(!request) return 0;
+ auto req = REQ(request);
+ return req->mDependencyFilePaths.Count();
+}
+
+/** Get the path to a file this compilation dependend on.
+*/
+SLANG_API char const*
+spGetDependencyFilePath(
+ SlangCompileRequest* request,
+ int index)
+{
+ if(!request) return 0;
+ auto req = REQ(request);
+ return req->mDependencyFilePaths[index].begin();
+}
+
+SLANG_API int
+spGetTranslationUnitCount(
+ SlangCompileRequest* request)
+{
+ auto req = REQ(request);
+ return req->mResult.translationUnits.Count();
+}
+
+// Get the output code associated with a specific translation unit
+SLANG_API char const* spGetTranslationUnitSource(
+ SlangCompileRequest* request,
+ int translationUnitIndex)
+{
+ auto req = REQ(request);
+ return req->mResult.translationUnits[translationUnitIndex].outputSource.Buffer();
+}
+
+SLANG_API char const* spGetEntryPointSource(
+ SlangCompileRequest* request,
+ int translationUnitIndex,
+ int entryPointIndex)
+{
+ auto req = REQ(request);
+ return req->mResult.translationUnits[translationUnitIndex].entryPoints[entryPointIndex].outputSource.Buffer();
+
+}
+
+// Reflection API
+
+SLANG_API SlangReflection* spGetReflection(
+ SlangCompileRequest* request)
+{
+ if( !request ) return 0;
+
+ auto req = REQ(request);
+ return (SlangReflection*) req->mReflectionData.Ptr();
+}
+
+
+// ... rest of reflection API implementation is in `Reflection.cpp`
diff --git a/source/slang/slang.natvis b/source/slang/slang.natvis
new file mode 100644
index 000000000..bfc1e7317
--- /dev/null
+++ b/source/slang/slang.natvis
@@ -0,0 +1,14 @@
+<?xml version="1.0" encoding="utf-8"?>
+<AutoVisualizer xmlns="http://schemas.microsoft.com/vstudio/debugger/natvis/2010">
+ <Type Name="Slang::Compiler::CFGNode">
+ <DisplayString>{{CFG Basic Block}}</DisplayString>
+ <Expand>
+ <LinkedListItems>
+ <Size>kvPairs.FCount</Size>
+ <HeadPointer>kvPairs.FHead</HeadPointer>
+ <NextPointer>pNext</NextPointer>
+ <ValueNode>Value</ValueNode>
+ </LinkedListItems>
+ </Expand>
+ </Type>
+</AutoVisualizer> \ No newline at end of file
diff --git a/source/slang/slang.vcxproj b/source/slang/slang.vcxproj
new file mode 100644
index 000000000..df34e40dc
--- /dev/null
+++ b/source/slang/slang.vcxproj
@@ -0,0 +1,427 @@
+<?xml version="1.0" encoding="utf-8"?>
+<Project DefaultTargets="Build" ToolsVersion="14.0" xmlns="http://schemas.microsoft.com/developer/msbuild/2003">
+ <ItemGroup Label="ProjectConfigurations">
+ <ProjectConfiguration Include="DebugClang|Win32">
+ <Configuration>DebugClang</Configuration>
+ <Platform>Win32</Platform>
+ </ProjectConfiguration>
+ <ProjectConfiguration Include="DebugClang|x64">
+ <Configuration>DebugClang</Configuration>
+ <Platform>x64</Platform>
+ </ProjectConfiguration>
+ <ProjectConfiguration Include="Debug_VS2013|Win32">
+ <Configuration>Debug_VS2013</Configuration>
+ <Platform>Win32</Platform>
+ </ProjectConfiguration>
+ <ProjectConfiguration Include="Debug_VS2013|x64">
+ <Configuration>Debug_VS2013</Configuration>
+ <Platform>x64</Platform>
+ </ProjectConfiguration>
+ <ProjectConfiguration Include="Debug|Win32">
+ <Configuration>Debug</Configuration>
+ <Platform>Win32</Platform>
+ </ProjectConfiguration>
+ <ProjectConfiguration Include="Debug|x64">
+ <Configuration>Debug</Configuration>
+ <Platform>x64</Platform>
+ </ProjectConfiguration>
+ <ProjectConfiguration Include="Release_VS2013|Win32">
+ <Configuration>Release_VS2013</Configuration>
+ <Platform>Win32</Platform>
+ </ProjectConfiguration>
+ <ProjectConfiguration Include="Release_VS2013|x64">
+ <Configuration>Release_VS2013</Configuration>
+ <Platform>x64</Platform>
+ </ProjectConfiguration>
+ <ProjectConfiguration Include="Release|Win32">
+ <Configuration>Release</Configuration>
+ <Platform>Win32</Platform>
+ </ProjectConfiguration>
+ <ProjectConfiguration Include="Release|x64">
+ <Configuration>Release</Configuration>
+ <Platform>x64</Platform>
+ </ProjectConfiguration>
+ </ItemGroup>
+ <PropertyGroup Label="Globals">
+ <ProjectGuid>{DB00DA62-0533-4AFD-B59F-A67D5B3A0808}</ProjectGuid>
+ <Keyword>Win32Proj</Keyword>
+ <RootNamespace>SpireCore</RootNamespace>
+ <ProjectName>slang</ProjectName>
+ <WindowsTargetPlatformVersion>8.1</WindowsTargetPlatformVersion>
+ </PropertyGroup>
+ <Import Project="$(VCTargetsPath)\Microsoft.Cpp.Default.props" />
+ <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|Win32'" Label="Configuration">
+ <ConfigurationType>StaticLibrary</ConfigurationType>
+ <UseDebugLibraries>true</UseDebugLibraries>
+ <PlatformToolset>v140</PlatformToolset>
+ <CharacterSet>Unicode</CharacterSet>
+ </PropertyGroup>
+ <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug_VS2013|Win32'" Label="Configuration">
+ <ConfigurationType>StaticLibrary</ConfigurationType>
+ <UseDebugLibraries>true</UseDebugLibraries>
+ <PlatformToolset>v120</PlatformToolset>
+ <CharacterSet>Unicode</CharacterSet>
+ </PropertyGroup>
+ <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='DebugClang|Win32'" Label="Configuration">
+ <ConfigurationType>StaticLibrary</ConfigurationType>
+ <UseDebugLibraries>true</UseDebugLibraries>
+ <PlatformToolset>v140_clang_3_7</PlatformToolset>
+ <CharacterSet>Unicode</CharacterSet>
+ </PropertyGroup>
+ <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'" Label="Configuration">
+ <ConfigurationType>StaticLibrary</ConfigurationType>
+ <UseDebugLibraries>true</UseDebugLibraries>
+ <PlatformToolset>v140</PlatformToolset>
+ <CharacterSet>Unicode</CharacterSet>
+ </PropertyGroup>
+ <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug_VS2013|x64'" Label="Configuration">
+ <ConfigurationType>StaticLibrary</ConfigurationType>
+ <UseDebugLibraries>true</UseDebugLibraries>
+ <PlatformToolset>v120</PlatformToolset>
+ <CharacterSet>Unicode</CharacterSet>
+ </PropertyGroup>
+ <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='DebugClang|x64'" Label="Configuration">
+ <ConfigurationType>StaticLibrary</ConfigurationType>
+ <UseDebugLibraries>true</UseDebugLibraries>
+ <PlatformToolset>v140_Clang_3_7</PlatformToolset>
+ <CharacterSet>Unicode</CharacterSet>
+ </PropertyGroup>
+ <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|Win32'" Label="Configuration">
+ <ConfigurationType>StaticLibrary</ConfigurationType>
+ <UseDebugLibraries>false</UseDebugLibraries>
+ <PlatformToolset>v140</PlatformToolset>
+ <WholeProgramOptimization>true</WholeProgramOptimization>
+ <CharacterSet>Unicode</CharacterSet>
+ </PropertyGroup>
+ <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release_VS2013|Win32'" Label="Configuration">
+ <ConfigurationType>StaticLibrary</ConfigurationType>
+ <UseDebugLibraries>false</UseDebugLibraries>
+ <PlatformToolset>v120</PlatformToolset>
+ <WholeProgramOptimization>true</WholeProgramOptimization>
+ <CharacterSet>Unicode</CharacterSet>
+ </PropertyGroup>
+ <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'" Label="Configuration">
+ <ConfigurationType>StaticLibrary</ConfigurationType>
+ <UseDebugLibraries>false</UseDebugLibraries>
+ <PlatformToolset>v140</PlatformToolset>
+ <WholeProgramOptimization>true</WholeProgramOptimization>
+ <CharacterSet>Unicode</CharacterSet>
+ </PropertyGroup>
+ <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release_VS2013|x64'" Label="Configuration">
+ <ConfigurationType>StaticLibrary</ConfigurationType>
+ <UseDebugLibraries>false</UseDebugLibraries>
+ <PlatformToolset>v120</PlatformToolset>
+ <WholeProgramOptimization>true</WholeProgramOptimization>
+ <CharacterSet>Unicode</CharacterSet>
+ </PropertyGroup>
+ <Import Project="$(VCTargetsPath)\Microsoft.Cpp.props" />
+ <ImportGroup Label="ExtensionSettings">
+ </ImportGroup>
+ <ImportGroup Label="PropertySheets" Condition="'$(Configuration)|$(Platform)'=='Debug|Win32'">
+ <Import Project="$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props" Condition="exists('$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props')" Label="LocalAppDataPlatform" />
+ <Import Project="..\..\build\slang-build.props" />
+ </ImportGroup>
+ <ImportGroup Condition="'$(Configuration)|$(Platform)'=='Debug_VS2013|Win32'" Label="PropertySheets">
+ <Import Project="$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props" Condition="exists('$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props')" Label="LocalAppDataPlatform" />
+ <Import Project="..\..\build\slang-build.props" />
+ </ImportGroup>
+ <ImportGroup Condition="'$(Configuration)|$(Platform)'=='DebugClang|Win32'" Label="PropertySheets">
+ <Import Project="$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props" Condition="exists('$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props')" Label="LocalAppDataPlatform" />
+ <Import Project="..\..\build\slang-build.props" />
+ </ImportGroup>
+ <ImportGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'" Label="PropertySheets">
+ <Import Project="$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props" Condition="exists('$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props')" Label="LocalAppDataPlatform" />
+ <Import Project="..\..\build\slang-build.props" />
+ </ImportGroup>
+ <ImportGroup Condition="'$(Configuration)|$(Platform)'=='Debug_VS2013|x64'" Label="PropertySheets">
+ <Import Project="$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props" Condition="exists('$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props')" Label="LocalAppDataPlatform" />
+ <Import Project="..\..\build\slang-build.props" />
+ </ImportGroup>
+ <ImportGroup Condition="'$(Configuration)|$(Platform)'=='DebugClang|x64'" Label="PropertySheets">
+ <Import Project="$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props" Condition="exists('$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props')" Label="LocalAppDataPlatform" />
+ <Import Project="..\..\build\slang-build.props" />
+ </ImportGroup>
+ <ImportGroup Label="PropertySheets" Condition="'$(Configuration)|$(Platform)'=='Release|Win32'">
+ <Import Project="$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props" Condition="exists('$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props')" Label="LocalAppDataPlatform" />
+ <Import Project="..\..\build\slang-build.props" />
+ </ImportGroup>
+ <ImportGroup Condition="'$(Configuration)|$(Platform)'=='Release_VS2013|Win32'" Label="PropertySheets">
+ <Import Project="$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props" Condition="exists('$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props')" Label="LocalAppDataPlatform" />
+ <Import Project="..\..\build\slang-build.props" />
+ </ImportGroup>
+ <ImportGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'" Label="PropertySheets">
+ <Import Project="$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props" Condition="exists('$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props')" Label="LocalAppDataPlatform" />
+ <Import Project="..\..\build\slang-build.props" />
+ </ImportGroup>
+ <ImportGroup Condition="'$(Configuration)|$(Platform)'=='Release_VS2013|x64'" Label="PropertySheets">
+ <Import Project="$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props" Condition="exists('$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props')" Label="LocalAppDataPlatform" />
+ <Import Project="..\..\build\slang-build.props" />
+ </ImportGroup>
+ <PropertyGroup Label="UserMacros" />
+ <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|Win32'">
+ <LinkIncremental>true</LinkIncremental>
+ </PropertyGroup>
+ <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug_VS2013|Win32'">
+ <LinkIncremental>true</LinkIncremental>
+ </PropertyGroup>
+ <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='DebugClang|Win32'">
+ <LinkIncremental>true</LinkIncremental>
+ </PropertyGroup>
+ <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">
+ <LinkIncremental>true</LinkIncremental>
+ </PropertyGroup>
+ <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug_VS2013|x64'">
+ <LinkIncremental>true</LinkIncremental>
+ </PropertyGroup>
+ <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='DebugClang|x64'">
+ <LinkIncremental>true</LinkIncremental>
+ </PropertyGroup>
+ <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|Win32'">
+ <LinkIncremental>false</LinkIncremental>
+ </PropertyGroup>
+ <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release_VS2013|Win32'">
+ <LinkIncremental>false</LinkIncremental>
+ </PropertyGroup>
+ <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'">
+ <LinkIncremental>false</LinkIncremental>
+ </PropertyGroup>
+ <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release_VS2013|x64'">
+ <LinkIncremental>false</LinkIncremental>
+ </PropertyGroup>
+ <ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Debug|Win32'">
+ <ClCompile>
+ <PrecompiledHeader>
+ </PrecompiledHeader>
+ <WarningLevel>Level4</WarningLevel>
+ <Optimization>Disabled</Optimization>
+ <PreprocessorDefinitions>WIN32;_DEBUG;_CONSOLE;_LIB;%(PreprocessorDefinitions)</PreprocessorDefinitions>
+ <AdditionalIncludeDirectories>../</AdditionalIncludeDirectories>
+ <RuntimeLibrary>MultiThreadedDebug</RuntimeLibrary>
+ <MultiProcessorCompilation>false</MultiProcessorCompilation>
+ </ClCompile>
+ <Link>
+ <SubSystem>Console</SubSystem>
+ <GenerateDebugInformation>true</GenerateDebugInformation>
+ </Link>
+ </ItemDefinitionGroup>
+ <ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Debug_VS2013|Win32'">
+ <ClCompile>
+ <PrecompiledHeader>
+ </PrecompiledHeader>
+ <WarningLevel>Level4</WarningLevel>
+ <Optimization>Disabled</Optimization>
+ <PreprocessorDefinitions>WIN32;_DEBUG;_CONSOLE;_LIB;%(PreprocessorDefinitions)</PreprocessorDefinitions>
+ <AdditionalIncludeDirectories>../</AdditionalIncludeDirectories>
+ <RuntimeLibrary>MultiThreadedDebug</RuntimeLibrary>
+ <MultiProcessorCompilation>false</MultiProcessorCompilation>
+ </ClCompile>
+ <Link>
+ <SubSystem>Console</SubSystem>
+ <GenerateDebugInformation>true</GenerateDebugInformation>
+ </Link>
+ </ItemDefinitionGroup>
+ <ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='DebugClang|Win32'">
+ <ClCompile>
+ <PrecompiledHeader>
+ </PrecompiledHeader>
+ <WarningLevel>EnableAllWarnings</WarningLevel>
+ <Optimization>Disabled</Optimization>
+ <PreprocessorDefinitions>WIN32;_DEBUG;_CONSOLE;_LIB;%(PreprocessorDefinitions)</PreprocessorDefinitions>
+ <AdditionalIncludeDirectories>../</AdditionalIncludeDirectories>
+ <RuntimeLibrary>MultiThreadedDebug</RuntimeLibrary>
+ <MultiProcessorCompilation>false</MultiProcessorCompilation>
+ <RuntimeTypeInfo>true</RuntimeTypeInfo>
+ </ClCompile>
+ <Link>
+ <SubSystem>Console</SubSystem>
+ <GenerateDebugInformation>true</GenerateDebugInformation>
+ </Link>
+ </ItemDefinitionGroup>
+ <ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">
+ <ClCompile>
+ <PrecompiledHeader>
+ </PrecompiledHeader>
+ <WarningLevel>Level4</WarningLevel>
+ <Optimization>Disabled</Optimization>
+ <PreprocessorDefinitions>WIN32;_DEBUG;_CONSOLE;_LIB;%(PreprocessorDefinitions)</PreprocessorDefinitions>
+ <AdditionalIncludeDirectories>../</AdditionalIncludeDirectories>
+ <RuntimeLibrary>MultiThreadedDebug</RuntimeLibrary>
+ <BrowseInformation>true</BrowseInformation>
+ <MultiProcessorCompilation>false</MultiProcessorCompilation>
+ </ClCompile>
+ <Link>
+ <SubSystem>Console</SubSystem>
+ <GenerateDebugInformation>true</GenerateDebugInformation>
+ </Link>
+ <Bscmake>
+ <PreserveSbr>true</PreserveSbr>
+ </Bscmake>
+ </ItemDefinitionGroup>
+ <ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Debug_VS2013|x64'">
+ <ClCompile>
+ <PrecompiledHeader>
+ </PrecompiledHeader>
+ <WarningLevel>Level4</WarningLevel>
+ <Optimization>Disabled</Optimization>
+ <PreprocessorDefinitions>WIN32;_DEBUG;_CONSOLE;_LIB;%(PreprocessorDefinitions)</PreprocessorDefinitions>
+ <AdditionalIncludeDirectories>../</AdditionalIncludeDirectories>
+ <RuntimeLibrary>MultiThreadedDebug</RuntimeLibrary>
+ <BrowseInformation>true</BrowseInformation>
+ <MultiProcessorCompilation>false</MultiProcessorCompilation>
+ </ClCompile>
+ <Link>
+ <SubSystem>Console</SubSystem>
+ <GenerateDebugInformation>true</GenerateDebugInformation>
+ </Link>
+ <Bscmake>
+ <PreserveSbr>true</PreserveSbr>
+ </Bscmake>
+ </ItemDefinitionGroup>
+ <ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='DebugClang|x64'">
+ <ClCompile>
+ <PrecompiledHeader>
+ </PrecompiledHeader>
+ <WarningLevel>Level4</WarningLevel>
+ <Optimization>Disabled</Optimization>
+ <PreprocessorDefinitions>WIN32;_DEBUG;_CONSOLE;_LIB;%(PreprocessorDefinitions)</PreprocessorDefinitions>
+ <AdditionalIncludeDirectories>../</AdditionalIncludeDirectories>
+ <RuntimeLibrary>MultiThreadedDebug</RuntimeLibrary>
+ <BrowseInformation>true</BrowseInformation>
+ <MultiProcessorCompilation>false</MultiProcessorCompilation>
+ </ClCompile>
+ <Link>
+ <SubSystem>Console</SubSystem>
+ <GenerateDebugInformation>true</GenerateDebugInformation>
+ </Link>
+ <Bscmake>
+ <PreserveSbr>true</PreserveSbr>
+ </Bscmake>
+ </ItemDefinitionGroup>
+ <ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Release|Win32'">
+ <ClCompile>
+ <WarningLevel>Level3</WarningLevel>
+ <PrecompiledHeader>
+ </PrecompiledHeader>
+ <Optimization>MaxSpeed</Optimization>
+ <FunctionLevelLinking>true</FunctionLevelLinking>
+ <IntrinsicFunctions>true</IntrinsicFunctions>
+ <PreprocessorDefinitions>WIN32;NDEBUG;_CONSOLE;_LIB;%(PreprocessorDefinitions)</PreprocessorDefinitions>
+ <AdditionalIncludeDirectories>../</AdditionalIncludeDirectories>
+ <RuntimeLibrary>MultiThreaded</RuntimeLibrary>
+ <MultiProcessorCompilation>false</MultiProcessorCompilation>
+ </ClCompile>
+ <Link>
+ <SubSystem>Console</SubSystem>
+ <GenerateDebugInformation>true</GenerateDebugInformation>
+ <EnableCOMDATFolding>true</EnableCOMDATFolding>
+ <OptimizeReferences>true</OptimizeReferences>
+ </Link>
+ </ItemDefinitionGroup>
+ <ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Release_VS2013|Win32'">
+ <ClCompile>
+ <WarningLevel>Level3</WarningLevel>
+ <PrecompiledHeader>
+ </PrecompiledHeader>
+ <Optimization>MaxSpeed</Optimization>
+ <FunctionLevelLinking>true</FunctionLevelLinking>
+ <IntrinsicFunctions>true</IntrinsicFunctions>
+ <PreprocessorDefinitions>WIN32;NDEBUG;_CONSOLE;_LIB;%(PreprocessorDefinitions)</PreprocessorDefinitions>
+ <AdditionalIncludeDirectories>../</AdditionalIncludeDirectories>
+ <RuntimeLibrary>MultiThreaded</RuntimeLibrary>
+ <MultiProcessorCompilation>false</MultiProcessorCompilation>
+ </ClCompile>
+ <Link>
+ <SubSystem>Console</SubSystem>
+ <GenerateDebugInformation>true</GenerateDebugInformation>
+ <EnableCOMDATFolding>true</EnableCOMDATFolding>
+ <OptimizeReferences>true</OptimizeReferences>
+ </Link>
+ </ItemDefinitionGroup>
+ <ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'">
+ <ClCompile>
+ <WarningLevel>Level4</WarningLevel>
+ <PrecompiledHeader>
+ </PrecompiledHeader>
+ <Optimization>MaxSpeed</Optimization>
+ <FunctionLevelLinking>true</FunctionLevelLinking>
+ <IntrinsicFunctions>true</IntrinsicFunctions>
+ <PreprocessorDefinitions>WIN32;NDEBUG;_CONSOLE;_LIB;%(PreprocessorDefinitions)</PreprocessorDefinitions>
+ <AdditionalIncludeDirectories>../</AdditionalIncludeDirectories>
+ <RuntimeLibrary>MultiThreaded</RuntimeLibrary>
+ <MultiProcessorCompilation>false</MultiProcessorCompilation>
+ </ClCompile>
+ <Link>
+ <SubSystem>Console</SubSystem>
+ <GenerateDebugInformation>true</GenerateDebugInformation>
+ <EnableCOMDATFolding>true</EnableCOMDATFolding>
+ <OptimizeReferences>true</OptimizeReferences>
+ </Link>
+ </ItemDefinitionGroup>
+ <ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Release_VS2013|x64'">
+ <ClCompile>
+ <WarningLevel>Level4</WarningLevel>
+ <PrecompiledHeader>
+ </PrecompiledHeader>
+ <Optimization>MaxSpeed</Optimization>
+ <FunctionLevelLinking>true</FunctionLevelLinking>
+ <IntrinsicFunctions>true</IntrinsicFunctions>
+ <PreprocessorDefinitions>WIN32;NDEBUG;_CONSOLE;_LIB;%(PreprocessorDefinitions)</PreprocessorDefinitions>
+ <AdditionalIncludeDirectories>../</AdditionalIncludeDirectories>
+ <RuntimeLibrary>MultiThreaded</RuntimeLibrary>
+ <MultiProcessorCompilation>false</MultiProcessorCompilation>
+ </ClCompile>
+ <Link>
+ <SubSystem>Console</SubSystem>
+ <GenerateDebugInformation>true</GenerateDebugInformation>
+ <EnableCOMDATFolding>true</EnableCOMDATFolding>
+ <OptimizeReferences>true</OptimizeReferences>
+ </Link>
+ </ItemDefinitionGroup>
+ <ItemGroup>
+ <Natvis Include="slang.natvis" />
+ </ItemGroup>
+ <ItemGroup>
+ <ClInclude Include="compiled-program.h" />
+ <ClInclude Include="compiler.h" />
+ <ClInclude Include="diagnostic-defs.h" />
+ <ClInclude Include="diagnostics.h" />
+ <ClInclude Include="emit.h" />
+ <ClInclude Include="intrinsic-defs.h" />
+ <ClInclude Include="lexer.h" />
+ <ClInclude Include="lookup.h" />
+ <ClInclude Include="parameter-binding.h" />
+ <ClInclude Include="parser.h" />
+ <ClInclude Include="preprocessor.h" />
+ <ClInclude Include="profile-defs.h" />
+ <ClInclude Include="profile.h" />
+ <ClInclude Include="reflection.h" />
+ <ClInclude Include="slang-stdlib.h" />
+ <ClInclude Include="source-loc.h" />
+ <ClInclude Include="syntax-visitors.h" />
+ <ClInclude Include="syntax.h" />
+ <ClInclude Include="token-defs.h" />
+ <ClInclude Include="token.h" />
+ <ClInclude Include="type-layout.h" />
+ </ItemGroup>
+ <ItemGroup>
+ <ClCompile Include="check.cpp" />
+ <ClCompile Include="compiler.cpp" />
+ <ClCompile Include="diagnostics.cpp" />
+ <ClCompile Include="emit.cpp" />
+ <ClCompile Include="lexer.cpp" />
+ <ClCompile Include="lookup.cpp" />
+ <ClCompile Include="parameter-binding.cpp" />
+ <ClCompile Include="parser.cpp" />
+ <ClCompile Include="preprocessor.cpp" />
+ <ClCompile Include="profile.cpp" />
+ <ClCompile Include="reflection.cpp" />
+ <ClCompile Include="slang-stdlib.cpp" />
+ <ClCompile Include="slang.cpp" />
+ <ClCompile Include="syntax.cpp" />
+ <ClCompile Include="token.cpp" />
+ <ClCompile Include="type-layout.cpp" />
+ </ItemGroup>
+ <Import Project="$(VCTargetsPath)\Microsoft.Cpp.targets" />
+ <ImportGroup Label="ExtensionTargets">
+ </ImportGroup>
+</Project> \ No newline at end of file
diff --git a/source/slang/slang.vcxproj.filters b/source/slang/slang.vcxproj.filters
new file mode 100644
index 000000000..dac58bbf7
--- /dev/null
+++ b/source/slang/slang.vcxproj.filters
@@ -0,0 +1,48 @@
+<?xml version="1.0" encoding="utf-8"?>
+<Project ToolsVersion="4.0" xmlns="http://schemas.microsoft.com/developer/msbuild/2003">
+ <ItemGroup>
+ <Natvis Include="NatvisFile.natvis" />
+ <Natvis Include="slang.natvis" />
+ </ItemGroup>
+ <ItemGroup>
+ <ClInclude Include="compiled-program.h" />
+ <ClInclude Include="compiler.h" />
+ <ClInclude Include="diagnostic-defs.h" />
+ <ClInclude Include="diagnostics.h" />
+ <ClInclude Include="emit.h" />
+ <ClInclude Include="intrinsic-defs.h" />
+ <ClInclude Include="lexer.h" />
+ <ClInclude Include="lookup.h" />
+ <ClInclude Include="parameter-binding.h" />
+ <ClInclude Include="parser.h" />
+ <ClInclude Include="preprocessor.h" />
+ <ClInclude Include="profile.h" />
+ <ClInclude Include="profile-defs.h" />
+ <ClInclude Include="reflection.h" />
+ <ClInclude Include="slang-stdlib.h" />
+ <ClInclude Include="source-loc.h" />
+ <ClInclude Include="syntax.h" />
+ <ClInclude Include="syntax-visitors.h" />
+ <ClInclude Include="token.h" />
+ <ClInclude Include="token-defs.h" />
+ <ClInclude Include="type-layout.h" />
+ </ItemGroup>
+ <ItemGroup>
+ <ClCompile Include="check.cpp" />
+ <ClCompile Include="compiler.cpp" />
+ <ClCompile Include="diagnostics.cpp" />
+ <ClCompile Include="emit.cpp" />
+ <ClCompile Include="lexer.cpp" />
+ <ClCompile Include="lookup.cpp" />
+ <ClCompile Include="parameter-binding.cpp" />
+ <ClCompile Include="parser.cpp" />
+ <ClCompile Include="preprocessor.cpp" />
+ <ClCompile Include="profile.cpp" />
+ <ClCompile Include="reflection.cpp" />
+ <ClCompile Include="slang.cpp" />
+ <ClCompile Include="slang-stdlib.cpp" />
+ <ClCompile Include="syntax.cpp" />
+ <ClCompile Include="token.cpp" />
+ <ClCompile Include="type-layout.cpp" />
+ </ItemGroup>
+</Project> \ No newline at end of file
diff --git a/source/slang/source-loc.h b/source/slang/source-loc.h
new file mode 100644
index 000000000..dc353f402
--- /dev/null
+++ b/source/slang/source-loc.h
@@ -0,0 +1,47 @@
+// source-loc.h
+#ifndef SLANG_SOURCE_LOC_H_INCLUDED
+#define SLANG_SOURCE_LOC_H_INCLUDED
+
+#include "../core/basic.h"
+
+namespace Slang {
+namespace Compiler {
+
+using namespace CoreLib::Basic;
+
+class CodePosition
+{
+public:
+ int Line = -1, Col = -1, Pos = -1;
+ String FileName;
+ String ToString()
+ {
+ StringBuilder sb(100);
+ sb << FileName;
+ if (Line != -1)
+ sb << "(" << Line << ")";
+ return sb.ProduceString();
+ }
+ CodePosition() = default;
+ CodePosition(int line, int col, int pos, String fileName)
+ {
+ Line = line;
+ Col = col;
+ Pos = pos;
+ this->FileName = fileName;
+ }
+ bool operator < (const CodePosition & pos) const
+ {
+ return FileName < pos.FileName || (FileName == pos.FileName && Line < pos.Line) ||
+ (FileName == pos.FileName && Line == pos.Line && Col < pos.Col);
+ }
+ bool operator == (const CodePosition & pos) const
+ {
+ return FileName == pos.FileName && Line == pos.Line && Col == pos.Col;
+ }
+};
+
+
+}}
+
+#endif
diff --git a/source/slang/syntax-visitors.h b/source/slang/syntax-visitors.h
new file mode 100644
index 000000000..565ac3ace
--- /dev/null
+++ b/source/slang/syntax-visitors.h
@@ -0,0 +1,21 @@
+#ifndef RASTER_RENDERER_SYNTAX_PRINTER_H
+#define RASTER_RENDERER_SYNTAX_PRINTER_H
+
+#include "diagnostics.h"
+#include "syntax.h"
+#include "compiled-program.h"
+
+namespace Slang
+{
+ namespace Compiler
+ {
+ class CompileOptions;
+ class ShaderCompiler;
+ class ShaderLinkInfo;
+ class ShaderSymbol;
+
+ SyntaxVisitor * CreateSemanticsVisitor(DiagnosticSink * err, CompileOptions const& options);
+ }
+}
+
+#endif \ No newline at end of file
diff --git a/source/slang/syntax.cpp b/source/slang/syntax.cpp
new file mode 100644
index 000000000..3050afff6
--- /dev/null
+++ b/source/slang/syntax.cpp
@@ -0,0 +1,1484 @@
+#include "syntax.h"
+#include "syntax-visitors.h"
+#include <typeinfo>
+#include <assert.h>
+
+namespace Slang
+{
+ namespace Compiler
+ {
+ // BasicExpressionType
+
+ bool BasicExpressionType::EqualsImpl(ExpressionType * type)
+ {
+ auto basicType = dynamic_cast<const BasicExpressionType*>(type);
+ if (basicType == nullptr)
+ return false;
+ return basicType->BaseType == BaseType;
+ }
+
+ ExpressionType* BasicExpressionType::CreateCanonicalType()
+ {
+ // A basic type is already canonical, in our setup
+ return this;
+ }
+
+ CoreLib::Basic::String BasicExpressionType::ToString()
+ {
+ CoreLib::Basic::StringBuilder res;
+
+ switch (BaseType)
+ {
+ case Compiler::BaseType::Int:
+ res.Append("int");
+ break;
+ case Compiler::BaseType::UInt:
+ res.Append("uint");
+ break;
+ case Compiler::BaseType::UInt64:
+ res.Append("uint64_t");
+ break;
+ case Compiler::BaseType::Bool:
+ res.Append("bool");
+ break;
+ case Compiler::BaseType::Float:
+ res.Append("float");
+ break;
+ case Compiler::BaseType::Void:
+ res.Append("void");
+ break;
+ default:
+ break;
+ }
+ return res.ProduceString();
+ }
+
+ RefPtr<SyntaxNode> ProgramSyntaxNode::Accept(SyntaxVisitor * visitor)
+ {
+ return visitor->VisitProgram(this);
+ }
+
+ RefPtr<SyntaxNode> FunctionSyntaxNode::Accept(SyntaxVisitor * visitor)
+ {
+ return visitor->VisitFunction(this);
+ }
+
+ //
+
+ RefPtr<SyntaxNode> ScopeDecl::Accept(SyntaxVisitor * visitor)
+ {
+ return visitor->VisitScopeDecl(this);
+ }
+
+ //
+
+ RefPtr<SyntaxNode> BlockStatementSyntaxNode::Accept(SyntaxVisitor * visitor)
+ {
+ return visitor->VisitBlockStatement(this);
+ }
+
+ RefPtr<SyntaxNode> BreakStatementSyntaxNode::Accept(SyntaxVisitor * visitor)
+ {
+ return visitor->VisitBreakStatement(this);
+ }
+
+ RefPtr<SyntaxNode> ContinueStatementSyntaxNode::Accept(SyntaxVisitor * visitor)
+ {
+ return visitor->VisitContinueStatement(this);
+ }
+
+ RefPtr<SyntaxNode> DoWhileStatementSyntaxNode::Accept(SyntaxVisitor * visitor)
+ {
+ return visitor->VisitDoWhileStatement(this);
+ }
+
+ RefPtr<SyntaxNode> EmptyStatementSyntaxNode::Accept(SyntaxVisitor * visitor)
+ {
+ return visitor->VisitEmptyStatement(this);
+ }
+
+ RefPtr<SyntaxNode> ForStatementSyntaxNode::Accept(SyntaxVisitor * visitor)
+ {
+ return visitor->VisitForStatement(this);
+ }
+
+ RefPtr<SyntaxNode> IfStatementSyntaxNode::Accept(SyntaxVisitor * visitor)
+ {
+ return visitor->VisitIfStatement(this);
+ }
+
+ RefPtr<SyntaxNode> ReturnStatementSyntaxNode::Accept(SyntaxVisitor * visitor)
+ {
+ return visitor->VisitReturnStatement(this);
+ }
+
+ RefPtr<SyntaxNode> VarDeclrStatementSyntaxNode::Accept(SyntaxVisitor * visitor)
+ {
+ return visitor->VisitVarDeclrStatement(this);
+ }
+
+ RefPtr<SyntaxNode> Variable::Accept(SyntaxVisitor * visitor)
+ {
+ return visitor->VisitDeclrVariable(this);
+ }
+
+ RefPtr<SyntaxNode> WhileStatementSyntaxNode::Accept(SyntaxVisitor * visitor)
+ {
+ return visitor->VisitWhileStatement(this);
+ }
+
+ RefPtr<SyntaxNode> ExpressionStatementSyntaxNode::Accept(SyntaxVisitor * visitor)
+ {
+ return visitor->VisitExpressionStatement(this);
+ }
+
+ RefPtr<SyntaxNode> ConstantExpressionSyntaxNode::Accept(SyntaxVisitor * visitor)
+ {
+ return visitor->VisitConstantExpression(this);
+ }
+
+ RefPtr<SyntaxNode> IndexExpressionSyntaxNode::Accept(SyntaxVisitor * visitor)
+ {
+ return visitor->VisitIndexExpression(this);
+ }
+ RefPtr<SyntaxNode> MemberExpressionSyntaxNode::Accept(SyntaxVisitor * visitor)
+ {
+ return visitor->VisitMemberExpression(this);
+ }
+
+ // SwizzleExpr
+
+ RefPtr<SyntaxNode> SwizzleExpr::Accept(SyntaxVisitor * visitor)
+ {
+ return visitor->VisitSwizzleExpression(this);
+ }
+
+ // DerefExpr
+
+ RefPtr<SyntaxNode> DerefExpr::Accept(SyntaxVisitor * /*visitor*/)
+ {
+ // throw "unimplemented";
+ return this;
+ }
+
+ //
+
+ RefPtr<SyntaxNode> InvokeExpressionSyntaxNode::Accept(SyntaxVisitor * visitor)
+ {
+ return visitor->VisitInvokeExpression(this);
+ }
+
+ RefPtr<SyntaxNode> TypeCastExpressionSyntaxNode::Accept(SyntaxVisitor * visitor)
+ {
+ return visitor->VisitTypeCastExpression(this);
+ }
+
+ RefPtr<SyntaxNode> VarExpressionSyntaxNode::Accept(SyntaxVisitor * visitor)
+ {
+ return visitor->VisitVarExpression(this);
+ }
+
+ // OverloadedExpr
+
+ RefPtr<SyntaxNode> OverloadedExpr::Accept(SyntaxVisitor * /*visitor*/)
+ {
+// throw "unimplemented";
+ return this;
+ }
+
+ //
+
+ RefPtr<SyntaxNode> ParameterSyntaxNode::Accept(SyntaxVisitor * visitor)
+ {
+ return visitor->VisitParameter(this);
+ }
+
+ // UsingFileDecl
+
+ RefPtr<SyntaxNode> UsingFileDecl::Accept(SyntaxVisitor * visitor)
+ {
+ return visitor->VisitUsingFileDecl(this);
+ }
+
+ //
+
+ RefPtr<SyntaxNode> StructField::Accept(SyntaxVisitor * visitor)
+ {
+ return visitor->VisitStructField(this);
+ }
+ RefPtr<SyntaxNode> StructSyntaxNode::Accept(SyntaxVisitor * visitor)
+ {
+ return visitor->VisitStruct(this);
+ }
+ RefPtr<SyntaxNode> ClassSyntaxNode::Accept(SyntaxVisitor * visitor)
+ {
+ return visitor->VisitClass(this);
+ }
+ RefPtr<SyntaxNode> TypeDefDecl::Accept(SyntaxVisitor * visitor)
+ {
+ return visitor->VisitTypeDefDecl(this);
+ }
+
+ RefPtr<SyntaxNode> DiscardStatementSyntaxNode::Accept(SyntaxVisitor * visitor)
+ {
+ return visitor->VisitDiscardStatement(this);
+ }
+
+ // BasicExpressionType
+
+ BasicExpressionType* BasicExpressionType::GetScalarType()
+ {
+ return this;
+ }
+
+ //
+
+ bool ExpressionType::Equals(ExpressionType * type)
+ {
+ return GetCanonicalType()->EqualsImpl(type->GetCanonicalType());
+ }
+
+ bool ExpressionType::Equals(RefPtr<ExpressionType> type)
+ {
+ return Equals(type.Ptr());
+ }
+
+ bool ExpressionType::EqualsVal(Val* val)
+ {
+ if (auto type = dynamic_cast<ExpressionType*>(val))
+ return const_cast<ExpressionType*>(this)->Equals(type);
+ return false;
+ }
+
+ NamedExpressionType* ExpressionType::AsNamedType()
+ {
+ return dynamic_cast<NamedExpressionType*>(this);
+ }
+
+ RefPtr<Val> ExpressionType::SubstituteImpl(Substitutions* subst, int* ioDiff)
+ {
+ int diff = 0;
+ auto canSubst = GetCanonicalType()->SubstituteImpl(subst, &diff);
+
+ // If nothing changed, then don't drop any sugar that is applied
+ if (!diff)
+ return this;
+
+ // If the canonical type changed, then we return a canonical type,
+ // rather than try to re-construct any amount of sugar
+ (*ioDiff)++;
+ return canSubst;
+ }
+
+
+ ExpressionType* ExpressionType::GetCanonicalType()
+ {
+ if (!this) return nullptr;
+ ExpressionType* et = const_cast<ExpressionType*>(this);
+ if (!et->canonicalType)
+ {
+ // TODO(tfoley): worry about thread safety here?
+ et->canonicalType = et->CreateCanonicalType();
+ assert(et->canonicalType);
+ }
+ return et->canonicalType;
+ }
+
+ bool ExpressionType::IsTextureOrSampler()
+ {
+ return IsTexture() || IsSampler();
+ }
+ bool ExpressionType::IsStruct()
+ {
+ auto declRefType = AsDeclRefType();
+ if (!declRefType) return false;
+ auto structDeclRef = declRefType->declRef.As<StructDeclRef>();
+ if (!structDeclRef) return false;
+ return true;
+ }
+
+ bool ExpressionType::IsClass()
+ {
+ auto declRefType = AsDeclRefType();
+ if (!declRefType) return false;
+ auto classDeclRef = declRefType->declRef.As<ClassDeclRef>();
+ if (!classDeclRef) return false;
+ return true;
+ }
+
+#if 0
+ RefPtr<ExpressionType> ExpressionType::Bool;
+ RefPtr<ExpressionType> ExpressionType::UInt;
+ RefPtr<ExpressionType> ExpressionType::Int;
+ RefPtr<ExpressionType> ExpressionType::Float;
+ RefPtr<ExpressionType> ExpressionType::Float2;
+ RefPtr<ExpressionType> ExpressionType::Void;
+#endif
+ RefPtr<ExpressionType> ExpressionType::Error;
+ RefPtr<ExpressionType> ExpressionType::initializerListType;
+ RefPtr<ExpressionType> ExpressionType::Overloaded;
+
+ Dictionary<int, RefPtr<ExpressionType>> ExpressionType::sBuiltinTypes;
+ Dictionary<String, Decl*> ExpressionType::sMagicDecls;
+ List<RefPtr<ExpressionType>> ExpressionType::sCanonicalTypes;
+
+ void ExpressionType::Init()
+ {
+ Error = new ErrorType();
+ initializerListType = new InitializerListType();
+ Overloaded = new OverloadGroupType();
+ }
+ void ExpressionType::Finalize()
+ {
+ Error = nullptr;
+ initializerListType = nullptr;
+ Overloaded = nullptr;
+ // Note(tfoley): This seems to be just about the only way to clear out a List<T>
+ sCanonicalTypes = List<RefPtr<ExpressionType>>();
+ sBuiltinTypes = Dictionary<int, RefPtr<ExpressionType>>();
+ sMagicDecls = Dictionary<String, Decl*>();
+ }
+ bool ArrayExpressionType::EqualsImpl(ExpressionType * type)
+ {
+ auto arrType = type->AsArrayType();
+ if (!arrType)
+ return false;
+ return (ArrayLength == arrType->ArrayLength && BaseType->Equals(arrType->BaseType.Ptr()));
+ }
+ ExpressionType* ArrayExpressionType::CreateCanonicalType()
+ {
+ auto canonicalBaseType = BaseType->GetCanonicalType();
+ auto canonicalArrayType = new ArrayExpressionType();
+ sCanonicalTypes.Add(canonicalArrayType);
+ canonicalArrayType->BaseType = canonicalBaseType;
+ canonicalArrayType->ArrayLength = ArrayLength;
+ return canonicalArrayType;
+ }
+ int ArrayExpressionType::GetHashCode()
+ {
+ if (ArrayLength)
+ return (BaseType->GetHashCode() * 16777619) ^ ArrayLength->GetHashCode();
+ else
+ return BaseType->GetHashCode();
+ }
+ CoreLib::Basic::String ArrayExpressionType::ToString()
+ {
+ if (ArrayLength)
+ return BaseType->ToString() + "[" + ArrayLength->ToString() + "]";
+ else
+ return BaseType->ToString() + "[]";
+ }
+ RefPtr<SyntaxNode> GenericAppExpr::Accept(SyntaxVisitor * visitor)
+ {
+ return visitor->VisitGenericApp(this);
+ }
+
+ // DeclRefType
+
+ String DeclRefType::ToString()
+ {
+ return declRef.GetName();
+ }
+
+ int DeclRefType::GetHashCode()
+ {
+ return (declRef.GetHashCode() * 16777619) ^ (int)(typeid(this).hash_code());
+ }
+
+ bool DeclRefType::EqualsImpl(ExpressionType * type)
+ {
+ if (auto declRefType = type->AsDeclRefType())
+ {
+ return declRef.Equals(declRefType->declRef);
+ }
+ return false;
+ }
+
+ ExpressionType* DeclRefType::CreateCanonicalType()
+ {
+ // A declaration reference is already canonical
+ return this;
+ }
+
+ RefPtr<Val> DeclRefType::SubstituteImpl(Substitutions* subst, int* ioDiff)
+ {
+ if (!subst) return this;
+
+ // the case we especially care about is when this type references a declaration
+ // of a generic parameter, since that is what we might be substituting...
+ if (auto genericTypeParamDecl = dynamic_cast<GenericTypeParamDecl*>(declRef.GetDecl()))
+ {
+ // search for a substitution that might apply to us
+ for (auto s = subst; s; s = s->outer.Ptr())
+ {
+ // the generic decl associated with the substitution list must be
+ // the generic decl that declared this parameter
+ auto genericDecl = s->genericDecl;
+ if (genericDecl != genericTypeParamDecl->ParentDecl)
+ continue;
+
+ int index = 0;
+ for (auto m : genericDecl->Members)
+ {
+ if (m.Ptr() == genericTypeParamDecl)
+ {
+ // We've found it, so return the corresponding specialization argument
+ (*ioDiff)++;
+ return s->args[index];
+ }
+ else if(auto typeParam = m.As<GenericTypeParamDecl>())
+ {
+ index++;
+ }
+ else if(auto valParam = m.As<GenericValueParamDecl>())
+ {
+ index++;
+ }
+ else
+ {
+ }
+ }
+
+ }
+ }
+
+
+ int diff = 0;
+ DeclRef substDeclRef = declRef.SubstituteImpl(subst, &diff);
+
+ if (!diff)
+ return this;
+
+ // Make sure to record the difference!
+ *ioDiff += diff;
+
+ // Re-construct the type in case we are using a specialized sub-class
+ return DeclRefType::Create(substDeclRef);
+ }
+
+ static RefPtr<ExpressionType> ExtractGenericArgType(RefPtr<Val> val)
+ {
+ auto type = val.As<ExpressionType>();
+ assert(type.Ptr());
+ return type;
+ }
+
+ static RefPtr<IntVal> ExtractGenericArgInteger(RefPtr<Val> val)
+ {
+ auto intVal = val.As<IntVal>();
+ assert(intVal.Ptr());
+ return intVal;
+ }
+
+ // TODO: need to figure out how to unify this with the logic
+ // in the generic case...
+ DeclRefType* DeclRefType::Create(DeclRef declRef)
+ {
+ if (auto builtinMod = declRef.GetDecl()->FindModifier<BuiltinTypeModifier>())
+ {
+ auto type = new BasicExpressionType(builtinMod->tag);
+ type->declRef = declRef;
+ return type;
+ }
+ else if (auto magicMod = declRef.GetDecl()->FindModifier<MagicTypeModifier>())
+ {
+ Substitutions* subst = declRef.substitutions.Ptr();
+
+ if (magicMod->name == "SamplerState")
+ {
+ auto type = new SamplerStateType();
+ type->declRef = declRef;
+ type->flavor = SamplerStateType::Flavor(magicMod->tag);
+ return type;
+ }
+ else if (magicMod->name == "Vector")
+ {
+ assert(subst && subst->args.Count() == 2);
+ auto vecType = new VectorExpressionType();
+ vecType->declRef = declRef;
+ vecType->elementType = ExtractGenericArgType(subst->args[0]);
+ vecType->elementCount = ExtractGenericArgInteger(subst->args[1]);
+ return vecType;
+ }
+ else if (magicMod->name == "Matrix")
+ {
+ assert(subst && subst->args.Count() == 3);
+ auto matType = new MatrixExpressionType();
+ matType->declRef = declRef;
+ return matType;
+ }
+ else if (magicMod->name == "Texture")
+ {
+ assert(subst && subst->args.Count() >= 1);
+ auto textureType = new TextureType(
+ TextureType::Flavor(magicMod->tag),
+ ExtractGenericArgType(subst->args[0]));
+ textureType->declRef = declRef;
+ return textureType;
+ }
+ else if (magicMod->name == "TextureSampler")
+ {
+ assert(subst && subst->args.Count() >= 1);
+ auto textureType = new TextureSamplerType(
+ TextureType::Flavor(magicMod->tag),
+ ExtractGenericArgType(subst->args[0]));
+ textureType->declRef = declRef;
+ return textureType;
+ }
+ else if (magicMod->name == "GLSLImageType")
+ {
+ assert(subst && subst->args.Count() >= 1);
+ auto textureType = new GLSLImageType(
+ TextureType::Flavor(magicMod->tag),
+ ExtractGenericArgType(subst->args[0]));
+ textureType->declRef = declRef;
+ return textureType;
+ }
+
+ #define CASE(n,T) \
+ else if(magicMod->name == #n) { \
+ assert(subst && subst->args.Count() == 1); \
+ auto type = new T(); \
+ type->elementType = ExtractGenericArgType(subst->args[0]); \
+ type->declRef = declRef; \
+ return type; \
+ }
+
+ CASE(ConstantBuffer, ConstantBufferType)
+ CASE(TextureBuffer, TextureBufferType)
+ CASE(GLSLInputParameterBlockType, GLSLInputParameterBlockType)
+ CASE(GLSLOutputParameterBlockType, GLSLOutputParameterBlockType)
+ CASE(GLSLShaderStorageBufferType, GLSLShaderStorageBufferType)
+
+ CASE(PackedBuffer, PackedBufferType)
+ CASE(Uniform, UniformBufferType)
+ CASE(Patch, PatchType)
+
+ CASE(HLSLBufferType, HLSLBufferType)
+ CASE(HLSLStructuredBufferType, HLSLStructuredBufferType)
+ CASE(HLSLRWBufferType, HLSLRWBufferType)
+ CASE(HLSLRWStructuredBufferType, HLSLRWStructuredBufferType)
+ CASE(HLSLAppendStructuredBufferType, HLSLAppendStructuredBufferType)
+ CASE(HLSLConsumeStructuredBufferType, HLSLConsumeStructuredBufferType)
+ CASE(HLSLInputPatchType, HLSLInputPatchType)
+ CASE(HLSLOutputPatchType, HLSLOutputPatchType)
+
+ CASE(HLSLPointStreamType, HLSLPointStreamType)
+ CASE(HLSLLineStreamType, HLSLPointStreamType)
+ CASE(HLSLTriangleStreamType, HLSLPointStreamType)
+
+ #undef CASE
+
+ // "magic" builtin types which have no generic parameters
+ #define CASE(n,T) \
+ else if(magicMod->name == #n) { \
+ auto type = new T(); \
+ type->declRef = declRef; \
+ return type; \
+ }
+
+ CASE(HLSLByteAddressBufferType, HLSLByteAddressBufferType)
+ CASE(HLSLRWByteAddressBufferType, HLSLRWByteAddressBufferType)
+ CASE(UntypedBufferResourceType, UntypedBufferResourceType)
+
+ CASE(GLSLInputAttachmentType, GLSLInputAttachmentType)
+
+ #undef CASE
+
+ else
+ {
+ throw "unimplemented";
+ }
+ }
+ else
+ {
+ return new DeclRefType(declRef);
+ }
+ }
+
+ // OverloadGroupType
+
+ String OverloadGroupType::ToString()
+ {
+ return "overload group";
+ }
+
+ bool OverloadGroupType::EqualsImpl(ExpressionType * /*type*/)
+ {
+ return false;
+ }
+
+ ExpressionType* OverloadGroupType::CreateCanonicalType()
+ {
+ return this;
+ }
+
+ int OverloadGroupType::GetHashCode()
+ {
+ return (int)(int64_t)(void*)this;
+ }
+
+ // InitializerListType
+
+ String InitializerListType::ToString()
+ {
+ return "initializer list";
+ }
+
+ bool InitializerListType::EqualsImpl(ExpressionType * /*type*/)
+ {
+ return false;
+ }
+
+ ExpressionType* InitializerListType::CreateCanonicalType()
+ {
+ return this;
+ }
+
+ int InitializerListType::GetHashCode()
+ {
+ return (int)(int64_t)(void*)this;
+ }
+
+ // ErrorType
+
+ String ErrorType::ToString()
+ {
+ return "error";
+ }
+
+ bool ErrorType::EqualsImpl(ExpressionType* type)
+ {
+ if (auto errorType = type->As<ErrorType>())
+ return true;
+ return false;
+ }
+
+ ExpressionType* ErrorType::CreateCanonicalType()
+ {
+ return this;
+ }
+
+ int ErrorType::GetHashCode()
+ {
+ return (int)(int64_t)(void*)this;
+ }
+
+
+ // NamedExpressionType
+
+ String NamedExpressionType::ToString()
+ {
+ return declRef.GetName();
+ }
+
+ bool NamedExpressionType::EqualsImpl(ExpressionType * /*type*/)
+ {
+ assert(!"unreachable");
+ return false;
+ }
+
+ ExpressionType* NamedExpressionType::CreateCanonicalType()
+ {
+ return declRef.GetType()->GetCanonicalType();
+ }
+
+ int NamedExpressionType::GetHashCode()
+ {
+ assert(!"unreachable");
+ return 0;
+ }
+
+ // FuncType
+
+ String FuncType::ToString()
+ {
+ // TODO: a better approach than this
+ if (declRef)
+ return declRef.GetName();
+ else
+ return "/* unknown FuncType */";
+ }
+
+ bool FuncType::EqualsImpl(ExpressionType * type)
+ {
+ if (auto funcType = type->As<FuncType>())
+ {
+ return declRef == funcType->declRef;
+ }
+ return false;
+ }
+
+ ExpressionType* FuncType::CreateCanonicalType()
+ {
+ return this;
+ }
+
+ int FuncType::GetHashCode()
+ {
+ return declRef.GetHashCode();
+ }
+
+ // TypeType
+
+ String TypeType::ToString()
+ {
+ StringBuilder sb;
+ sb << "typeof(" << type->ToString() << ")";
+ return sb.ProduceString();
+ }
+
+ bool TypeType::EqualsImpl(ExpressionType * t)
+ {
+ if (auto typeType = t->As<TypeType>())
+ {
+ return t->Equals(typeType->type);
+ }
+ return false;
+ }
+
+ ExpressionType* TypeType::CreateCanonicalType()
+ {
+ auto canType = new TypeType(type->GetCanonicalType());
+ sCanonicalTypes.Add(canType);
+ return canType;
+ }
+
+ int TypeType::GetHashCode()
+ {
+ assert(!"unreachable");
+ return 0;
+ }
+
+ // GenericDeclRefType
+
+ String GenericDeclRefType::ToString()
+ {
+ // TODO: what is appropriate here?
+ return "<GenericDeclRef>";
+ }
+
+ bool GenericDeclRefType::EqualsImpl(ExpressionType * type)
+ {
+ if (auto genericDeclRefType = type->As<GenericDeclRefType>())
+ {
+ return declRef.Equals(genericDeclRefType->declRef);
+ }
+ return false;
+ }
+
+ int GenericDeclRefType::GetHashCode()
+ {
+ return declRef.GetHashCode();
+ }
+
+ ExpressionType* GenericDeclRefType::CreateCanonicalType()
+ {
+ return this;
+ }
+
+ // ArithmeticExpressionType
+
+ // VectorExpressionType
+
+ String VectorExpressionType::ToString()
+ {
+ StringBuilder sb;
+ sb << "vector<" << elementType->ToString() << "," << elementCount->ToString() << ">";
+ return sb.ProduceString();
+ }
+
+ BasicExpressionType* VectorExpressionType::GetScalarType()
+ {
+ return elementType->AsBasicType();
+ }
+
+ // MatrixExpressionType
+
+ String MatrixExpressionType::ToString()
+ {
+ StringBuilder sb;
+ sb << "matrix<" << getElementType()->ToString() << "," << getRowCount()->ToString() << "," << getColumnCount()->ToString() << ">";
+ return sb.ProduceString();
+ }
+
+ BasicExpressionType* MatrixExpressionType::GetScalarType()
+ {
+ return getElementType()->AsBasicType();
+ }
+
+ ExpressionType* MatrixExpressionType::getElementType()
+ {
+ return this->declRef.substitutions->args[0].As<ExpressionType>().Ptr();
+ }
+
+ IntVal* MatrixExpressionType::getRowCount()
+ {
+ return this->declRef.substitutions->args[1].As<IntVal>().Ptr();
+ }
+
+ IntVal* MatrixExpressionType::getColumnCount()
+ {
+ return this->declRef.substitutions->args[2].As<IntVal>().Ptr();
+ }
+
+ //
+
+#if 0
+ String GetOperatorFunctionName(Operator op)
+ {
+ switch (op)
+ {
+ case Operator::Add:
+ case Operator::AddAssign:
+ return "+";
+ case Operator::Sub:
+ case Operator::SubAssign:
+ return "-";
+ case Operator::Neg:
+ return "-";
+ case Operator::Not:
+ return "!";
+ case Operator::BitNot:
+ return "~";
+ case Operator::PreInc:
+ case Operator::PostInc:
+ return "++";
+ case Operator::PreDec:
+ case Operator::PostDec:
+ return "--";
+ case Operator::Mul:
+ case Operator::MulAssign:
+ return "*";
+ case Operator::Div:
+ case Operator::DivAssign:
+ return "/";
+ case Operator::Mod:
+ case Operator::ModAssign:
+ return "%";
+ case Operator::Lsh:
+ case Operator::LshAssign:
+ return "<<";
+ case Operator::Rsh:
+ case Operator::RshAssign:
+ return ">>";
+ case Operator::Eql:
+ return "==";
+ case Operator::Neq:
+ return "!=";
+ case Operator::Greater:
+ return ">";
+ case Operator::Less:
+ return "<";
+ case Operator::Geq:
+ return ">=";
+ case Operator::Leq:
+ return "<=";
+ case Operator::BitAnd:
+ case Operator::AndAssign:
+ return "&";
+ case Operator::BitXor:
+ case Operator::XorAssign:
+ return "^";
+ case Operator::BitOr:
+ case Operator::OrAssign:
+ return "|";
+ case Operator::And:
+ return "&&";
+ case Operator::Or:
+ return "||";
+ case Operator::Sequence:
+ return ",";
+ case Operator::Select:
+ return "?:";
+ case Operator::Assign:
+ return "=";
+ default:
+ return "";
+ }
+ }
+#endif
+ String OperatorToString(Operator op)
+ {
+ switch (op)
+ {
+ case Slang::Compiler::Operator::Neg:
+ return "-";
+ case Slang::Compiler::Operator::Not:
+ return "!";
+ case Slang::Compiler::Operator::PreInc:
+ return "++";
+ case Slang::Compiler::Operator::PreDec:
+ return "--";
+ case Slang::Compiler::Operator::PostInc:
+ return "++";
+ case Slang::Compiler::Operator::PostDec:
+ return "--";
+ case Slang::Compiler::Operator::Mul:
+ case Slang::Compiler::Operator::MulAssign:
+ return "*";
+ case Slang::Compiler::Operator::Div:
+ case Slang::Compiler::Operator::DivAssign:
+ return "/";
+ case Slang::Compiler::Operator::Mod:
+ case Slang::Compiler::Operator::ModAssign:
+ return "%";
+ case Slang::Compiler::Operator::Add:
+ case Slang::Compiler::Operator::AddAssign:
+ return "+";
+ case Slang::Compiler::Operator::Sub:
+ case Slang::Compiler::Operator::SubAssign:
+ return "-";
+ case Slang::Compiler::Operator::Lsh:
+ case Slang::Compiler::Operator::LshAssign:
+ return "<<";
+ case Slang::Compiler::Operator::Rsh:
+ case Slang::Compiler::Operator::RshAssign:
+ return ">>";
+ case Slang::Compiler::Operator::Eql:
+ return "==";
+ case Slang::Compiler::Operator::Neq:
+ return "!=";
+ case Slang::Compiler::Operator::Greater:
+ return ">";
+ case Slang::Compiler::Operator::Less:
+ return "<";
+ case Slang::Compiler::Operator::Geq:
+ return ">=";
+ case Slang::Compiler::Operator::Leq:
+ return "<=";
+ case Slang::Compiler::Operator::BitAnd:
+ case Slang::Compiler::Operator::AndAssign:
+ return "&";
+ case Slang::Compiler::Operator::BitXor:
+ case Slang::Compiler::Operator::XorAssign:
+ return "^";
+ case Slang::Compiler::Operator::BitOr:
+ case Slang::Compiler::Operator::OrAssign:
+ return "|";
+ case Slang::Compiler::Operator::And:
+ return "&&";
+ case Slang::Compiler::Operator::Or:
+ return "||";
+ case Slang::Compiler::Operator::Assign:
+ return "=";
+ default:
+ return "ERROR";
+ }
+ }
+
+ // TypeExp
+
+ TypeExp TypeExp::Accept(SyntaxVisitor* visitor)
+ {
+ return visitor->VisitTypeExp(*this);
+ }
+
+ // BuiltinTypeModifier
+
+ // MagicTypeModifier
+
+ // GenericDecl
+
+ RefPtr<SyntaxNode> GenericDecl::Accept(SyntaxVisitor * visitor)
+ {
+ return visitor->VisitGenericDecl(this);
+ }
+
+ // GenericTypeParamDecl
+
+ RefPtr<SyntaxNode> GenericTypeParamDecl::Accept(SyntaxVisitor * /*visitor*/) {
+ //throw "unimplemented";
+ return this;
+ }
+
+ // GenericTypeConstraintDecl
+
+ RefPtr<SyntaxNode> GenericTypeConstraintDecl::Accept(SyntaxVisitor * visitor)
+ {
+ return this;
+ }
+
+ // GenericValueParamDecl
+
+ RefPtr<SyntaxNode> GenericValueParamDecl::Accept(SyntaxVisitor * /*visitor*/) {
+ //throw "unimplemented";
+ return this;
+ }
+
+ // GenericParamIntVal
+
+ bool GenericParamIntVal::EqualsVal(Val* val)
+ {
+ if (auto genericParamVal = dynamic_cast<GenericParamIntVal*>(val))
+ {
+ return declRef.Equals(genericParamVal->declRef);
+ }
+ return false;
+ }
+
+ String GenericParamIntVal::ToString()
+ {
+ return declRef.GetName();
+ }
+
+ int GenericParamIntVal::GetHashCode()
+ {
+ return declRef.GetHashCode() ^ 0xFFFF;
+ }
+
+ RefPtr<Val> GenericParamIntVal::SubstituteImpl(Substitutions* subst, int* ioDiff)
+ {
+ // search for a substitution that might apply to us
+ for (auto s = subst; s; s = s->outer.Ptr())
+ {
+ // the generic decl associated with the substitution list must be
+ // the generic decl that declared this parameter
+ auto genericDecl = s->genericDecl;
+ if (genericDecl != declRef.GetDecl()->ParentDecl)
+ continue;
+
+ int index = 0;
+ for (auto m : genericDecl->Members)
+ {
+ if (m.Ptr() == declRef.GetDecl())
+ {
+ // We've found it, so return the corresponding specialization argument
+ (*ioDiff)++;
+ return s->args[index];
+ }
+ else if(auto typeParam = m.As<GenericTypeParamDecl>())
+ {
+ index++;
+ }
+ else if(auto valParam = m.As<GenericValueParamDecl>())
+ {
+ index++;
+ }
+ else
+ {
+ }
+ }
+ }
+
+ // Nothing found: don't substittue.
+ return this;
+ }
+
+ // ExtensionDecl
+
+ RefPtr<SyntaxNode> ExtensionDecl::Accept(SyntaxVisitor * visitor)
+ {
+ visitor->VisitExtensionDecl(this);
+ return this;
+ }
+
+ // ConstructorDecl
+
+ RefPtr<SyntaxNode> ConstructorDecl::Accept(SyntaxVisitor * visitor)
+ {
+ visitor->VisitConstructorDecl(this);
+ return this;
+ }
+
+ // SubscriptDecl
+
+ RefPtr<SyntaxNode> SubscriptDecl::Accept(SyntaxVisitor * visitor)
+ {
+ visitor->visitSubscriptDecl(this);
+ return this;
+ }
+
+ // AccessorDecl
+
+ RefPtr<SyntaxNode> AccessorDecl::Accept(SyntaxVisitor * visitor)
+ {
+ visitor->visitAccessorDecl(this);
+ return this;
+ }
+
+ // Substitutions
+
+ RefPtr<Substitutions> Substitutions::SubstituteImpl(Substitutions* subst, int* ioDiff)
+ {
+ if (!this) return nullptr;
+
+ int diff = 0;
+ auto outerSubst = outer->SubstituteImpl(subst, &diff);
+
+ List<RefPtr<Val>> substArgs;
+ for (auto a : args)
+ {
+ substArgs.Add(a->SubstituteImpl(subst, &diff));
+ }
+
+ if (!diff) return this;
+
+ (*ioDiff)++;
+ auto substSubst = new Substitutions();
+ substSubst->genericDecl = genericDecl;
+ substSubst->args = substArgs;
+ return substSubst;
+ }
+
+ bool Substitutions::Equals(Substitutions* subst)
+ {
+ // both must be NULL, or non-NULL
+ if (!this || !subst)
+ return !this && !subst;
+
+ if (genericDecl != subst->genericDecl)
+ return false;
+
+ int argCount = args.Count();
+ assert(args.Count() == subst->args.Count());
+ for (int aa = 0; aa < argCount; ++aa)
+ {
+ if (!args[aa]->EqualsVal(subst->args[aa].Ptr()))
+ return false;
+ }
+
+ if (!outer->Equals(subst->outer.Ptr()))
+ return false;
+
+ return true;
+ }
+
+
+ // DeclRef
+
+ RefPtr<ExpressionType> DeclRef::Substitute(RefPtr<ExpressionType> type) const
+ {
+ // No substitutions? Easy.
+ if (!substitutions)
+ return type;
+
+ // Otherwise we need to recurse on the type structure
+ // and apply substitutions where it makes sense
+
+ return type->Substitute(substitutions.Ptr()).As<ExpressionType>();
+ }
+
+ DeclRef DeclRef::Substitute(DeclRef declRef) const
+ {
+ if(!substitutions)
+ return declRef;
+
+ int diff = 0;
+ return declRef.SubstituteImpl(substitutions.Ptr(), &diff);
+ }
+
+ RefPtr<ExpressionSyntaxNode> DeclRef::Substitute(RefPtr<ExpressionSyntaxNode> expr) const
+ {
+ // No substitutions? Easy.
+ if (!substitutions)
+ return expr;
+
+ assert(!"unimplemented");
+
+ return expr;
+ }
+
+
+ DeclRef DeclRef::SubstituteImpl(Substitutions* subst, int* ioDiff)
+ {
+ if (!substitutions) return *this;
+
+ int diff = 0;
+ RefPtr<Substitutions> substSubst = substitutions->SubstituteImpl(subst, &diff);
+
+ if (!diff)
+ return *this;
+
+ *ioDiff += diff;
+
+ DeclRef substDeclRef;
+ substDeclRef.decl = decl;
+ substDeclRef.substitutions = substSubst;
+ return substDeclRef;
+ }
+
+
+ // Check if this is an equivalent declaration reference to another
+ bool DeclRef::Equals(DeclRef const& declRef) const
+ {
+ if (decl != declRef.decl)
+ return false;
+
+ if (!substitutions->Equals(declRef.substitutions.Ptr()))
+ return false;
+
+ return true;
+ }
+
+ // Convenience accessors for common properties of declarations
+ String const& DeclRef::GetName() const
+ {
+ return decl->Name.Content;
+ }
+
+ DeclRef DeclRef::GetParent() const
+ {
+ auto parentDecl = decl->ParentDecl;
+ if (auto parentGeneric = dynamic_cast<GenericDecl*>(parentDecl))
+ {
+ // We need to strip away one layer of specialization
+ assert(substitutions);
+ return DeclRef(parentGeneric, substitutions->outer);
+ }
+ else
+ {
+ // If the parent isn't a generic, then it must
+ // use the same specializations as this declaration
+ return DeclRef(parentDecl, substitutions);
+ }
+
+ }
+
+ int DeclRef::GetHashCode() const
+ {
+ auto rs = PointerHash<1>::GetHashCode(decl);
+ if (substitutions)
+ {
+ rs *= 16777619;
+ rs ^= substitutions->GetHashCode();
+ }
+ return rs;
+ }
+
+ // Val
+
+ RefPtr<Val> Val::Substitute(Substitutions* subst)
+ {
+ if (!this) return nullptr;
+ if (!subst) return this;
+ int diff = 0;
+ return SubstituteImpl(subst, &diff);
+ }
+
+ RefPtr<Val> Val::SubstituteImpl(Substitutions* /*subst*/, int* /*ioDiff*/)
+ {
+ // Default behavior is to not substitute at all
+ return this;
+ }
+
+ // IntVal
+
+ int GetIntVal(RefPtr<IntVal> val)
+ {
+ if (auto constantVal = val.As<ConstantIntVal>())
+ {
+ return constantVal->value;
+ }
+ assert(!"unexpected");
+ return 0;
+ }
+
+ // ConstantIntVal
+
+ bool ConstantIntVal::EqualsVal(Val* val)
+ {
+ if (auto intVal = dynamic_cast<ConstantIntVal*>(val))
+ return value == intVal->value;
+ return false;
+ }
+
+ String ConstantIntVal::ToString()
+ {
+ return String(value);
+ }
+
+ int ConstantIntVal::GetHashCode()
+ {
+ return value;
+ }
+
+ // SwitchStmt
+
+ RefPtr<SyntaxNode> SwitchStmt::Accept(SyntaxVisitor * visitor)
+ {
+ return visitor->VisitSwitchStmt(this);
+ }
+
+ RefPtr<SyntaxNode> CaseStmt::Accept(SyntaxVisitor * visitor)
+ {
+ return visitor->VisitCaseStmt(this);
+ }
+
+ RefPtr<SyntaxNode> DefaultStmt::Accept(SyntaxVisitor * visitor)
+ {
+ return visitor->VisitDefaultStmt(this);
+ }
+
+ // TraitDecl
+
+ RefPtr<SyntaxNode> TraitDecl::Accept(SyntaxVisitor * visitor)
+ {
+ visitor->VisitTraitDecl(this);
+ return this;
+ }
+
+ // TraitConformanceDecl
+
+ RefPtr<SyntaxNode> TraitConformanceDecl::Accept(SyntaxVisitor * visitor)
+ {
+ visitor->VisitTraitConformanceDecl(this);
+ return this;
+ }
+
+ // SharedTypeExpr
+
+ RefPtr<SyntaxNode> SharedTypeExpr::Accept(SyntaxVisitor * visitor)
+ {
+ return visitor->VisitSharedTypeExpr(this);
+ }
+
+ // OperatorExpressionSyntaxNode
+
+#if 0
+ void OperatorExpressionSyntaxNode::SetOperator(RefPtr<Scope> scope, Slang::Compiler::Operator op)
+ {
+ this->Operator = op;
+ auto opExpr = new VarExpressionSyntaxNode();
+ opExpr->Variable = GetOperatorFunctionName(Operator);
+ opExpr->scope = scope;
+ opExpr->Position = this->Position;
+ this->FunctionExpr = opExpr;
+ }
+#endif
+
+ RefPtr<SyntaxNode> OperatorExpressionSyntaxNode::Accept(SyntaxVisitor * visitor)
+ {
+ return visitor->VisitOperatorExpression(this);
+ }
+
+ // DeclGroup
+
+ RefPtr<SyntaxNode> DeclGroup::Accept(SyntaxVisitor * visitor)
+ {
+ visitor->VisitDeclGroup(this);
+ return this;
+ }
+
+ //
+
+ void RegisterBuiltinDecl(
+ RefPtr<Decl> decl,
+ RefPtr<BuiltinTypeModifier> modifier)
+ {
+ auto type = DeclRefType::Create(DeclRef(decl.Ptr(), nullptr));
+ ExpressionType::sBuiltinTypes[(int)modifier->tag] = type;
+ }
+
+ void RegisterMagicDecl(
+ RefPtr<Decl> decl,
+ RefPtr<MagicTypeModifier> modifier)
+ {
+ ExpressionType::sMagicDecls[modifier->name] = decl.Ptr();
+ }
+
+ RefPtr<Decl> findMagicDecl(
+ String const& name)
+ {
+ return ExpressionType::sMagicDecls[name].GetValue();
+ }
+
+ ExpressionType* ExpressionType::GetBool()
+ {
+ return sBuiltinTypes[(int)BaseType::Bool].GetValue().Ptr();
+ }
+
+ ExpressionType* ExpressionType::GetFloat()
+ {
+ return sBuiltinTypes[(int)BaseType::Float].GetValue().Ptr();
+ }
+
+ ExpressionType* ExpressionType::GetInt()
+ {
+ return sBuiltinTypes[(int)BaseType::Int].GetValue().Ptr();
+ }
+
+ ExpressionType* ExpressionType::GetUInt()
+ {
+ return sBuiltinTypes[(int)BaseType::UInt].GetValue().Ptr();
+ }
+
+ ExpressionType* ExpressionType::GetVoid()
+ {
+ return sBuiltinTypes[(int)BaseType::Void].GetValue().Ptr();
+ }
+
+ ExpressionType* ExpressionType::getInitializerListType()
+ {
+ return initializerListType.Ptr();
+ }
+
+ ExpressionType* ExpressionType::GetError()
+ {
+ return ExpressionType::Error.Ptr();
+ }
+
+ //
+
+ RefPtr<SyntaxNode> UnparsedStmt::Accept(SyntaxVisitor * visitor)
+ {
+ return this;
+ }
+
+ //
+
+ RefPtr<SyntaxNode> InitializerListExpr::Accept(SyntaxVisitor * visitor)
+ {
+ return visitor->visitInitializerListExpr(this);
+ }
+
+ //
+
+ RefPtr<SyntaxNode> ModifierDecl::Accept(SyntaxVisitor * visitor)
+ {
+ return this;
+ }
+
+ //
+
+ RefPtr<SyntaxNode> EmptyDecl::Accept(SyntaxVisitor * visitor)
+ {
+ return this;
+ }
+
+ //
+
+ SyntaxNodeBase* createInstanceOfSyntaxClassByName(
+ String const& name)
+ {
+ if(0) {}
+ #define CASE(NAME) \
+ else if(name == #NAME) return new NAME()
+
+ CASE(GLSLBufferModifier);
+ CASE(GLSLWriteOnlyModifier);
+ CASE(GLSLReadOnlyModifier);
+ CASE(GLSLPatchModifier);
+ CASE(SimpleModifier);
+
+ #undef CASE
+ else
+ {
+ assert(!"unexpected");
+ return nullptr;
+ }
+ }
+
+ IntrinsicOp findIntrinsicOp(char const* name)
+ {
+ // TODO: need to make this faster by using a dictionary...
+
+ if (0) {}
+#define INTRINSIC(NAME) else if(strcmp(name, #NAME) == 0) return IntrinsicOp::NAME;
+#include "intrinsic-defs.h"
+
+ return IntrinsicOp::Unknown;
+ }
+
+ }
+} \ No newline at end of file
diff --git a/source/slang/syntax.h b/source/slang/syntax.h
new file mode 100644
index 000000000..9e4486d4e
--- /dev/null
+++ b/source/slang/syntax.h
@@ -0,0 +1,2771 @@
+#ifndef RASTER_RENDERER_SYNTAX_H
+#define RASTER_RENDERER_SYNTAX_H
+
+#include "../core/basic.h"
+#include "Lexer.h"
+#include "Profile.h"
+
+#include "../../Slang.h"
+
+#include <assert.h>
+
+namespace Slang
+{
+ namespace Compiler
+ {
+ using namespace CoreLib::Basic;
+ class SyntaxVisitor;
+ class FunctionSyntaxNode;
+
+ class SyntaxNodeBase : public RefObject
+ {
+ public:
+ CodePosition Position;
+ };
+
+
+
+
+ //
+ // Other modifiers may have more elaborate data, and so
+ // are represented as heap-allocated objects, in a linked
+ // list.
+ //
+ class Modifier : public SyntaxNodeBase
+ {
+ public:
+ // Next modifier in linked list of modifiers on same piece of syntax
+ RefPtr<Modifier> next;
+
+ // The token that was used to name this modifier.
+ Token nameToken;
+ };
+
+#define SIMPLE_MODIFIER(NAME) \
+ class NAME##Modifier : public Modifier {}
+
+ SIMPLE_MODIFIER(Uniform);
+ SIMPLE_MODIFIER(In);
+ SIMPLE_MODIFIER(Out);
+ SIMPLE_MODIFIER(Const);
+ SIMPLE_MODIFIER(Instance);
+ SIMPLE_MODIFIER(Builtin);
+ SIMPLE_MODIFIER(Inline);
+ SIMPLE_MODIFIER(Public);
+ SIMPLE_MODIFIER(Require);
+ SIMPLE_MODIFIER(Param);
+ SIMPLE_MODIFIER(Extern);
+ SIMPLE_MODIFIER(Input);
+ SIMPLE_MODIFIER(Transparent);
+ SIMPLE_MODIFIER(FromStdLib);
+ SIMPLE_MODIFIER(Prefix);
+ SIMPLE_MODIFIER(Postfix);
+
+#undef SIMPLE_MODIFIER
+
+ enum class IntrinsicOp
+ {
+ Unknown = 0,
+#define INTRINSIC(NAME) NAME,
+#include "intrinsic-defs.h"
+ };
+
+ IntrinsicOp findIntrinsicOp(char const* name);
+
+ class IntrinsicModifier : public Modifier
+ {
+ public:
+ // token that names the intrinsic op
+ Token opToken;
+
+ // The opcode for the intrinsic operation
+ IntrinsicOp op = IntrinsicOp::Unknown;
+ };
+
+
+ class InOutModifier : public OutModifier {};
+
+ // This is a special sentinel modifier that gets added
+ // to the list when we have multiple variable declarations
+ // all sharing the same modifiers:
+ //
+ // static uniform int a : FOO, *b : register(x0);
+ //
+ // In this case both `a` and `b` share the syntax
+ // for part of their modifier list, but then have
+ // their own modifiers as well:
+ //
+ // a: SemanticModifier("FOO") --> SharedModifiers --> StaticModifier --> UniformModifier
+ // /
+ // b: RegisterModifier("x0") /
+ //
+ class SharedModifiers : public Modifier {};
+
+ // A GLSL `layout` modifier
+ //
+ // We use a distinct modifier for each key that
+ // appears within the `layout(...)` construct,
+ // and each key might have an optional value token.
+ //
+ // TODO: We probably want a notion of "modifier groups"
+ // so that we can recover good source location info
+ // for modifiers that were part of the same vs.
+ // different constructs.
+ class GLSLLayoutModifier : public Modifier
+ {
+ public:
+ // THe token used to introduce the modifier is stored
+ // as the `nameToken` field.
+
+ // TODO: may want to accept a full expression here
+ Token valToken;
+ };
+
+ // We divide GLSL `layout` modifiers into those we have parsed
+ // (in the sense of having some notion of their semantics), and
+ // those we have not.
+ class GLSLParsedLayoutModifier : public GLSLLayoutModifier {};
+ class GLSLUnparsedLayoutModifier : public GLSLLayoutModifier {};
+
+ // Specific cases for known GLSL `layout` modifiers that we need to work with
+ class GLSLConstantIDLayoutModifier : public GLSLParsedLayoutModifier {};
+ class GLSLBindingLayoutModifier : public GLSLParsedLayoutModifier {};
+ class GLSLSetLayoutModifier : public GLSLParsedLayoutModifier {};
+ class GLSLLocationLayoutModifier : public GLSLParsedLayoutModifier {};
+
+ // A catch-all for single-keyword modifiers
+ class SimpleModifier : public Modifier {};
+
+ // Some GLSL-specific modifiers
+ class GLSLBufferModifier : public SimpleModifier {};
+ class GLSLWriteOnlyModifier : public SimpleModifier {};
+ class GLSLReadOnlyModifier : public SimpleModifier {};
+ class GLSLPatchModifier : public SimpleModifier {};
+
+ // Indicates that this is a variable declaration that corresponds to
+ // a parameter block declaration in the source program.
+ class ImplicitParameterBlockVariableModifier : public Modifier {};
+
+ // Indicates that this is a type that corresponds to the element
+ // type of a parameter block declaration in the source program.
+ class ImplicitParameterBlockElementTypeModifier : public Modifier {};
+
+ // An HLSL semantic
+ class HLSLSemantic : public Modifier
+ {
+ public:
+ Token name;
+ };
+
+
+ // An HLSL semantic that affects layout
+ class HLSLLayoutSemantic : public HLSLSemantic
+ {
+ public:
+ Token registerName;
+ Token componentMask;
+ };
+
+ // An HLSL `register` semantic
+ class HLSLRegisterSemantic : public HLSLLayoutSemantic
+ {
+ };
+
+ // TODO(tfoley): `packoffset`
+ class HLSLPackOffsetSemantic : public HLSLLayoutSemantic
+ {
+ };
+
+ // An HLSL semantic that just associated a declaration with a semantic name
+ class HLSLSimpleSemantic : public HLSLSemantic
+ {
+ };
+
+ // GLSL
+
+ // Directives that came in via the preprocessor, but
+ // that we need to keep around for later steps
+ class GLSLPreprocessorDirective : public Modifier
+ {
+ };
+
+ // A GLSL `#version` directive
+ class GLSLVersionDirective : public GLSLPreprocessorDirective
+ {
+ public:
+ // Token giving the version number to use
+ Token versionNumberToken;
+
+ // Optional token giving the sub-profile to be used
+ Token glslProfileToken;
+ };
+
+ // A GLSL `#extension` directive
+ class GLSLExtensionDirective : public GLSLPreprocessorDirective
+ {
+ public:
+ // Token giving the version number to use
+ Token extensionNameToken;
+
+ // Optional token giving the sub-profile to be used
+ Token dispositionToken;
+ };
+
+ class ParameterBlockReflectionName : public Modifier
+ {
+ public:
+ Token nameToken;
+ };
+
+ // Helper class for iterating over a list of heap-allocated modifiers
+ struct ModifierList
+ {
+ struct Iterator
+ {
+ Modifier* current;
+
+ Modifier* operator*()
+ {
+ return current;
+ }
+
+ void operator++()
+ {
+ current = current->next.Ptr();
+ }
+
+ bool operator!=(Iterator other)
+ {
+ return current != other.current;
+ };
+
+ Iterator()
+ : current(nullptr)
+ {}
+
+ Iterator(Modifier* modifier)
+ : current(modifier)
+ {}
+ };
+
+ ModifierList()
+ : modifiers(nullptr)
+ {}
+
+ ModifierList(Modifier* modifiers)
+ : modifiers(modifiers)
+ {}
+
+ Iterator begin() { return Iterator(modifiers); }
+ Iterator end() { return Iterator(nullptr); }
+
+ Modifier* modifiers;
+ };
+
+ // Helper class for iterating over heap-allocated modifiers
+ // of a specific type.
+ template<typename T>
+ struct FilteredModifierList
+ {
+ struct Iterator
+ {
+ Modifier* current;
+
+ T* operator*()
+ {
+ return (T*)current;
+ }
+
+ void operator++()
+ {
+ current = Adjust(current->next.Ptr());
+ }
+
+ bool operator!=(Iterator other)
+ {
+ return current != other.current;
+ };
+
+ Iterator()
+ : current(nullptr)
+ {}
+
+ Iterator(Modifier* modifier)
+ : current(modifier)
+ {}
+ };
+
+ FilteredModifierList()
+ : modifiers(nullptr)
+ {}
+
+ FilteredModifierList(Modifier* modifiers)
+ : modifiers(Adjust(modifiers))
+ {}
+
+ Iterator begin() { return Iterator(modifiers); }
+ Iterator end() { return Iterator(nullptr); }
+
+ static Modifier* Adjust(Modifier* modifier)
+ {
+ Modifier* m = modifier;
+ for (;;)
+ {
+ if (!m) return m;
+ if (dynamic_cast<T*>(m)) return m;
+ m = m->next.Ptr();
+ }
+ }
+
+ Modifier* modifiers;
+ };
+
+ // A set of modifiers attached to a syntax node
+ struct Modifiers
+ {
+ // The first modifier in the linked list of heap-allocated modifiers
+ RefPtr<Modifier> first;
+
+ template<typename T>
+ FilteredModifierList<T> getModifiersOfType() { return FilteredModifierList<T>(first.Ptr()); }
+
+ // Find the first modifier of a given type, or return `nullptr` if none is found.
+ template<typename T>
+ T* findModifier()
+ {
+ return *getModifiersOfType<T>().begin();
+ }
+
+ template<typename T>
+ bool hasModifier() { return findModifier<T>() != nullptr; }
+
+ FilteredModifierList<Modifier>::Iterator begin() { return FilteredModifierList<Modifier>::Iterator(first.Ptr()); }
+ FilteredModifierList<Modifier>::Iterator end() { return FilteredModifierList<Modifier>::Iterator(nullptr); }
+ };
+
+
+ enum class BaseType
+ {
+ // Note(tfoley): These are ordered in terms of promotion rank, so be vareful when messing with this
+
+ Void = 0,
+ Bool,
+ Int,
+ UInt,
+ UInt64,
+ Float,
+#if 0
+ Texture2D = 48,
+ TextureCube = 49,
+ Texture2DArray = 50,
+ Texture2DShadow = 51,
+ TextureCubeShadow = 52,
+ Texture2DArrayShadow = 53,
+ Texture3D = 54,
+ SamplerState = 4096, SamplerComparisonState = 4097,
+ Error = 16384,
+#endif
+ };
+
+ class Decl;
+ class StructSyntaxNode;
+ class BasicExpressionType;
+ class ArrayExpressionType;
+ class TypeDefDecl;
+ class DeclRefType;
+ class NamedExpressionType;
+ class TypeType;
+ class GenericDeclRefType;
+ class VectorExpressionType;
+ class MatrixExpressionType;
+ class ArithmeticExpressionType;
+ class GenericDecl;
+ class Substitutions;
+ class TextureType;
+ class SamplerStateType;
+
+ // A compile-time constant value (usually a type)
+ class Val : public RefObject
+ {
+ public:
+ // construct a new value by applying a set of parameter
+ // substitutions to this one
+ RefPtr<Val> Substitute(Substitutions* subst);
+
+ // Lower-level interface for substition. Like the basic
+ // `Substitute` above, but also takes a by-reference
+ // integer parameter that should be incremented when
+ // returning a modified value (this can help the caller
+ // decide whether they need to do anything).
+ virtual RefPtr<Val> SubstituteImpl(Substitutions* subst, int* ioDiff);
+
+ virtual bool EqualsVal(Val* val) = 0;
+ virtual String ToString() = 0;
+ virtual int GetHashCode() = 0;
+ bool operator == (const Val & v)
+ {
+ return EqualsVal(const_cast<Val*>(&v));
+ }
+ };
+
+ // A compile-time integer (may not have a specific concrete value)
+ class IntVal : public Val
+ {
+ };
+
+ // Try to extract a simple integer value from an `IntVal`.
+ // This fill assert-fail if the object doesn't represent a literal value.
+ int GetIntVal(RefPtr<IntVal> val);
+
+ // Trivial case of a value that is just a constant integer
+ class ConstantIntVal : public IntVal
+ {
+ public:
+ int value;
+
+ ConstantIntVal(int value)
+ : value(value)
+ {}
+
+ virtual bool EqualsVal(Val* val) override;
+ virtual String ToString() override;
+ virtual int GetHashCode() override;
+ };
+
+ // TODO(tfoley): classes for more general compile-time integers,
+ // including references to template parameters
+
+ // A type, representing a classifier for some term in the AST.
+ //
+ // Types can include "sugar" in that they may refer to a
+ // `typedef` which gives them a good name when printed as
+ // part of diagnostic messages.
+ //
+ // In order to operation on types, though, we often want
+ // to look past any sugar, and operate on an underlying
+ // "canonical" type. The reprsentation caches a pointer to
+ // a canonical type on every type, so we can easily
+ // operate on the raw representation when needed.
+ class ExpressionType : public Val
+ {
+ public:
+ static RefPtr<ExpressionType> Error;
+ static RefPtr<ExpressionType> initializerListType;
+ static RefPtr<ExpressionType> Overloaded;
+
+ static Dictionary<int, RefPtr<ExpressionType>> sBuiltinTypes;
+ static Dictionary<String, Decl*> sMagicDecls;
+
+ // Note: just exists to make sure we can clean up
+ // canonical types we create along the way
+ static List<RefPtr<ExpressionType>> sCanonicalTypes;
+
+
+
+ static ExpressionType* GetBool();
+ static ExpressionType* GetFloat();
+ static ExpressionType* GetInt();
+ static ExpressionType* GetUInt();
+ static ExpressionType* GetVoid();
+ static ExpressionType* getInitializerListType();
+ static ExpressionType* GetError();
+
+ public:
+ virtual String ToString() = 0;
+
+ bool Equals(ExpressionType * type);
+ bool Equals(RefPtr<ExpressionType> type);
+
+ bool IsVectorType() { return As<VectorExpressionType>() != nullptr; }
+ bool IsArray() { return As<ArrayExpressionType>() != nullptr; }
+
+ template<typename T>
+ T* As()
+ {
+ return dynamic_cast<T*>(GetCanonicalType());
+ }
+
+ // Convenience/legacy wrappers for `As<>`
+ ArithmeticExpressionType * AsArithmeticType() { return As<ArithmeticExpressionType>(); }
+ BasicExpressionType * AsBasicType() { return As<BasicExpressionType>(); }
+ VectorExpressionType * AsVectorType() { return As<VectorExpressionType>(); }
+ MatrixExpressionType * AsMatrixType() { return As<MatrixExpressionType>(); }
+ ArrayExpressionType * AsArrayType() { return As<ArrayExpressionType>(); }
+
+ DeclRefType* AsDeclRefType() { return As<DeclRefType>(); }
+
+ NamedExpressionType* AsNamedType();
+
+ bool IsTextureOrSampler();
+ bool IsTexture() { return As<TextureType>() != nullptr; }
+ bool IsSampler() { return As<SamplerStateType>() != nullptr; }
+ bool IsStruct();
+ bool IsClass();
+ static void Init();
+ static void Finalize();
+ ExpressionType* GetCanonicalType();
+
+ virtual RefPtr<Val> SubstituteImpl(Substitutions* subst, int* ioDiff) override;
+
+ virtual bool EqualsVal(Val* val) override;
+ protected:
+ virtual bool EqualsImpl(ExpressionType * type) = 0;
+
+ virtual ExpressionType* CreateCanonicalType() = 0;
+ ExpressionType* canonicalType = nullptr;
+ };
+
+ // A substitution represents a binding of certain
+ // type-level variables to concrete argument values
+ class Substitutions : public RefObject
+ {
+ public:
+ // The generic declaration that defines the
+ // parametesr we are binding to arguments
+ GenericDecl* genericDecl;
+
+ // The actual values of the arguments
+ List<RefPtr<Val>> args;
+
+ // Any further substitutions, relating to outer generic declarations
+ RefPtr<Substitutions> outer;
+
+ // Apply a set of substitutions to the bindings in this substitution
+ RefPtr<Substitutions> SubstituteImpl(Substitutions* subst, int* ioDiff);
+
+ // Check if these are equivalent substitutiosn to another set
+ bool Equals(Substitutions* subst);
+ bool operator == (const Substitutions & subst)
+ {
+ return Equals(const_cast<Substitutions*>(&subst));
+ }
+ int GetHashCode() const
+ {
+ int rs = 0;
+ for (auto && v : args)
+ {
+ rs ^= v->GetHashCode();
+ rs *= 16777619;
+ }
+ return rs;
+ }
+ };
+
+ class SyntaxNode : public SyntaxNodeBase
+ {
+ public:
+ virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) = 0;
+ };
+
+ class ContainerDecl;
+ class SpecializeModifier;
+
+ // Represents how much checking has been applied to a declaration.
+ enum class DeclCheckState : uint8_t
+ {
+ // The declaration has been parsed, but not checked
+ Unchecked,
+
+ // We are in the process of checking the declaration "header"
+ // (those parts of the declaration needed in order to
+ // reference it)
+ CheckingHeader,
+
+ // We are done checking the declaration header.
+ CheckedHeader,
+
+ // We have checked the declaration fully.
+ Checked,
+ };
+
+ // A syntax node which can have modifiers appled
+ class ModifiableSyntaxNode : public SyntaxNode
+ {
+ public:
+ Modifiers modifiers;
+
+ template<typename T>
+ FilteredModifierList<T> GetModifiersOfType() { return FilteredModifierList<T>(modifiers.first.Ptr()); }
+
+ // Find the first modifier of a given type, or return `nullptr` if none is found.
+ template<typename T>
+ T* FindModifier()
+ {
+ return *GetModifiersOfType<T>().begin();
+ }
+
+ template<typename T>
+ bool HasModifier() { return FindModifier<T>() != nullptr; }
+ };
+
+ void addModifier(
+ RefPtr<ModifiableSyntaxNode> syntax,
+ RefPtr<Modifier> modifier);
+
+
+ // An intermediate type to represent either a single declaration, or a group of declarations
+ class DeclBase : public ModifiableSyntaxNode
+ {
+ public:
+ };
+
+ class Decl : public DeclBase
+ {
+ public:
+ ContainerDecl* ParentDecl;
+
+ Token Name;
+ String const& getName() { return Name.Content; }
+ Token const& getNameToken() { return Name; }
+
+
+ DeclCheckState checkState = DeclCheckState::Unchecked;
+
+ // The next declaration defined in the same container with the same name
+ Decl* nextInContainerWithSameName = nullptr;
+
+ bool IsChecked(DeclCheckState state) { return checkState >= state; }
+ void SetCheckState(DeclCheckState state)
+ {
+ assert(state >= checkState);
+ checkState = state;
+ }
+ };
+
+ struct QualType
+ {
+ RefPtr<ExpressionType> type;
+ bool IsLeftValue;
+
+ QualType()
+ : IsLeftValue(false)
+ {}
+
+ QualType(RefPtr<ExpressionType> type)
+ : type(type)
+ , IsLeftValue(false)
+ {}
+
+ QualType(ExpressionType* type)
+ : type(type)
+ , IsLeftValue(false)
+ {}
+
+ void operator=(RefPtr<ExpressionType> t)
+ {
+ *this = QualType(t);
+ }
+
+ void operator=(ExpressionType* t)
+ {
+ *this = QualType(t);
+ }
+
+ ExpressionType* Ptr() { return type.Ptr(); }
+
+ operator RefPtr<ExpressionType>() { return type; }
+ RefPtr<ExpressionType> operator->() { return type; }
+ };
+
+ class ExpressionSyntaxNode : public SyntaxNode
+ {
+ public:
+ QualType Type;
+ ExpressionSyntaxNode()
+ {}
+ };
+
+
+
+
+ // A reference to a declaration, which may include
+ // substitutions for generic parameters.
+ struct DeclRef
+ {
+ typedef Decl DeclType;
+
+ // The underlying declaration
+ Decl* decl = nullptr;
+ Decl* GetDecl() const { return decl; }
+
+ // Optionally, a chain of substititions to perform
+ RefPtr<Substitutions> substitutions;
+
+ DeclRef()
+ {}
+
+ DeclRef(Decl* decl, RefPtr<Substitutions> substitutions)
+ : decl(decl)
+ , substitutions(substitutions)
+ {}
+
+ // Apply substitutions to a type or ddeclaration
+ RefPtr<ExpressionType> Substitute(RefPtr<ExpressionType> type) const;
+ DeclRef Substitute(DeclRef declRef) const;
+
+ // Apply substitutions to an expression
+ RefPtr<ExpressionSyntaxNode> Substitute(RefPtr<ExpressionSyntaxNode> expr) const;
+
+ // Apply substitutions to this declaration reference
+ DeclRef SubstituteImpl(Substitutions* subst, int* ioDiff);
+
+ // Check if this is an equivalent declaration reference to another
+ bool Equals(DeclRef const& declRef) const;
+ bool operator == (const DeclRef& other) const
+ {
+ return Equals(other);
+ }
+
+ // Convenience accessors for common properties of declarations
+ String const& GetName() const;
+ DeclRef GetParent() const;
+
+ // "dynamic cast" to a more specific declaration reference type
+ template<typename T>
+ T As() const
+ {
+ T result;
+ result.decl = dynamic_cast<T::DeclType*>(decl);
+ result.substitutions = substitutions;
+ return result;
+ }
+
+ // Implicit conversion mostly so we can use a `DeclRef`
+ // in a conditional context
+ operator Decl*() const
+ {
+ return decl;
+ }
+
+ int GetHashCode() const;
+ };
+
+ // Helper macro for defining `DeclRef` subtypes
+ #define SLANG_DECLARE_DECL_REF(D) \
+ typedef D DeclType; \
+ D* GetDecl() const { return (D*) decl; } \
+ /* */
+
+
+
+ // The type of a reference to an overloaded name
+ class OverloadGroupType : public ExpressionType
+ {
+ public:
+ virtual String ToString() override;
+
+ protected:
+ virtual bool EqualsImpl(ExpressionType * type) override;
+ virtual ExpressionType* CreateCanonicalType() override;
+ virtual int GetHashCode() override;
+ };
+
+ // The type of an initializer-list expression (before it has
+ // been coerced to some other type)
+ class InitializerListType : public ExpressionType
+ {
+ public:
+ virtual String ToString() override;
+
+ protected:
+ virtual bool EqualsImpl(ExpressionType * type) override;
+ virtual ExpressionType* CreateCanonicalType() override;
+ virtual int GetHashCode() override;
+ };
+
+ // The type of an expression that was erroneous
+ class ErrorType : public ExpressionType
+ {
+ public:
+ virtual String ToString() override;
+
+ protected:
+ virtual bool EqualsImpl(ExpressionType * type) override;
+ virtual ExpressionType* CreateCanonicalType() override;
+ virtual int GetHashCode() override;
+ };
+
+ // A type that takes the form of a reference to some declaration
+ class DeclRefType : public ExpressionType
+ {
+ public:
+ DeclRef declRef;
+
+ virtual String ToString() override;
+ virtual RefPtr<Val> SubstituteImpl(Substitutions* subst, int* ioDiff) override;
+
+ static DeclRefType* Create(DeclRef declRef);
+
+ protected:
+ DeclRefType()
+ {}
+ DeclRefType(DeclRef declRef)
+ : declRef(declRef)
+ {}
+ virtual int GetHashCode() override;
+ virtual bool EqualsImpl(ExpressionType * type) override;
+ virtual ExpressionType* CreateCanonicalType() override;
+ };
+
+ // Base class for types that can be used in arithmetic expressions
+ class ArithmeticExpressionType : public DeclRefType
+ {
+ public:
+ virtual BasicExpressionType* GetScalarType() = 0;
+ };
+
+ class FunctionDeclBase;
+
+ class BasicExpressionType : public ArithmeticExpressionType
+ {
+ public:
+ BaseType BaseType;
+
+ BasicExpressionType()
+ {
+ BaseType = Compiler::BaseType::Int;
+ }
+ BasicExpressionType(Compiler::BaseType baseType)
+ {
+ BaseType = baseType;
+ }
+ virtual CoreLib::Basic::String ToString() override;
+ protected:
+ virtual BasicExpressionType* GetScalarType() override;
+ virtual bool EqualsImpl(ExpressionType * type) override;
+ virtual ExpressionType* CreateCanonicalType() override;
+ };
+
+
+ class TextureTypeBase : public DeclRefType
+ {
+ public:
+ // The type that results from fetching an element from this texture
+ RefPtr<ExpressionType> elementType;
+
+ // Bits representing the kind of texture type we are looking at
+ // (e.g., `Texture2DMS` vs. `TextureCubeArray`)
+ typedef uint16_t Flavor;
+ Flavor flavor;
+
+ enum
+ {
+ // Mask for the overall "shape" of the texture
+ ShapeMask = SLANG_RESOURCE_BASE_SHAPE_MASK,
+
+ // Flag for whether the shape has "array-ness"
+ ArrayFlag = SLANG_TEXTURE_ARRAY_FLAG,
+
+ // Whether or not the texture stores multiple samples per pixel
+ MultisampleFlag = SLANG_TEXTURE_MULTISAMPLE_FLAG,
+
+ // Whether or not this is a shadow texture
+ //
+ // TODO(tfoley): is this even meaningful/used?
+ // ShadowFlag = 0x80,
+ };
+
+ enum Shape : uint8_t
+ {
+ Shape1D = SLANG_TEXTURE_1D,
+ Shape2D = SLANG_TEXTURE_2D,
+ Shape3D = SLANG_TEXTURE_3D,
+ ShapeCube = SLANG_TEXTURE_CUBE,
+
+ Shape1DArray = Shape1D | ArrayFlag,
+ Shape2DArray = Shape2D | ArrayFlag,
+ // No Shape3DArray
+ ShapeCubeArray = ShapeCube | ArrayFlag,
+ };
+
+
+ Shape GetBaseShape() const { return Shape(flavor & ShapeMask); }
+ bool isArray() const { return (flavor & ArrayFlag) != 0; }
+ bool isMultisample() const { return (flavor & MultisampleFlag) != 0; }
+// bool isShadow() const { return (flavor & ShadowFlag) != 0; }
+
+ SlangResourceShape getShape() const { return flavor & 0xFF; }
+ SlangResourceAccess getAccess() const { return (flavor >> 8) & 0xFF; }
+
+ TextureTypeBase(
+ Flavor flavor,
+ RefPtr<ExpressionType> elementType)
+ : elementType(elementType)
+ , flavor(flavor)
+ {}
+ };
+
+ class TextureType : public TextureTypeBase
+ {
+ public:
+ TextureType(
+ Flavor flavor,
+ RefPtr<ExpressionType> elementType)
+ : TextureTypeBase(flavor, elementType)
+ {}
+ };
+
+ // This is a base type for texture/sampler pairs,
+ // as they exist in, e.g., GLSL
+ class TextureSamplerType : public TextureTypeBase
+ {
+ public:
+ TextureSamplerType(
+ Flavor flavor,
+ RefPtr<ExpressionType> elementType)
+ : TextureTypeBase(flavor, elementType)
+ {}
+ };
+
+ // This is a base type for `image*` types, as they exist in GLSL
+ class GLSLImageType : public TextureTypeBase
+ {
+ public:
+ GLSLImageType(
+ Flavor flavor,
+ RefPtr<ExpressionType> elementType)
+ : TextureTypeBase(flavor, elementType)
+ {}
+ };
+
+ class SamplerStateType : public DeclRefType
+ {
+ public:
+ // What flavor of sampler state is this
+ enum class Flavor : uint8_t
+ {
+ SamplerState,
+ SamplerComparisonState,
+ };
+ Flavor flavor;
+ };
+
+ // Other cases of generic types known to the compiler
+ class BuiltinGenericType : public DeclRefType
+ {
+ public:
+ RefPtr<ExpressionType> elementType;
+ };
+
+ // Types that behave like pointers, in that they can be
+ // dereferenced (implicitly) to access members defined
+ // in the element type.
+ class PointerLikeType : public BuiltinGenericType
+ {};
+
+ // Generic types used in existing Slang code
+ // TODO(tfoley): check that these are actually working right...
+ class PatchType : public PointerLikeType {};
+ class StorageBufferType : public BuiltinGenericType {};
+ class UniformBufferType : public PointerLikeType {};
+ class PackedBufferType : public BuiltinGenericType {};
+
+ // HLSL buffer-type resources
+
+ class HLSLBufferType : public BuiltinGenericType {};
+ class HLSLRWBufferType : public BuiltinGenericType {};
+ class HLSLStructuredBufferType : public BuiltinGenericType {};
+ class HLSLRWStructuredBufferType : public BuiltinGenericType {};
+
+ class UntypedBufferResourceType : public DeclRefType {};
+ class HLSLByteAddressBufferType : public UntypedBufferResourceType {};
+ class HLSLRWByteAddressBufferType : public UntypedBufferResourceType {};
+
+ class HLSLAppendStructuredBufferType : public BuiltinGenericType {};
+ class HLSLConsumeStructuredBufferType : public BuiltinGenericType {};
+
+ class HLSLInputPatchType : public BuiltinGenericType {};
+ class HLSLOutputPatchType : public BuiltinGenericType {};
+
+ // HLSL geometry shader output stream types
+
+ class HLSLStreamOutputType : public BuiltinGenericType {};
+ class HLSLPointStreamType : public HLSLStreamOutputType {};
+ class HLSLLineStreamType : public HLSLStreamOutputType {};
+ class HLSLTriangleStreamType : public HLSLStreamOutputType {};
+
+ //
+ class GLSLInputAttachmentType : public DeclRefType {};
+
+ // Base class for types used when desugaring parameter block
+ // declarations, includeing HLSL `cbuffer` or GLSL `uniform` blocks.
+ class ParameterBlockType : public PointerLikeType {};
+
+ class UniformParameterBlockType : public ParameterBlockType {};
+ class VaryingParameterBlockType : public ParameterBlockType {};
+
+ // Type for HLSL `cbuffer` declarations, and `ConstantBuffer<T>`
+ // ALso used for GLSL `uniform` blocks.
+ class ConstantBufferType : public UniformParameterBlockType {};
+
+ // Type for HLSL `tbuffer` declarations, and `TextureBuffer<T>`
+ class TextureBufferType : public UniformParameterBlockType {};
+
+ // Type for GLSL `in` and `out` blocks
+ class GLSLInputParameterBlockType : public VaryingParameterBlockType {};
+ class GLSLOutputParameterBlockType : public VaryingParameterBlockType {};
+
+ // Type for GLLSL `buffer` blocks
+ class GLSLShaderStorageBufferType : public UniformParameterBlockType {};
+
+ class ArrayExpressionType : public ExpressionType
+ {
+ public:
+ RefPtr<ExpressionType> BaseType;
+ RefPtr<IntVal> ArrayLength;
+ virtual CoreLib::Basic::String ToString() override;
+ protected:
+ virtual bool EqualsImpl(ExpressionType * type) override;
+ virtual ExpressionType* CreateCanonicalType() override;
+ virtual int GetHashCode() override;
+ };
+
+ // The "type" of an expression that resolves to a type.
+ // For example, in the expression `float(2)` the sub-expression,
+ // `float` would have the type `TypeType(float)`.
+ class TypeType : public ExpressionType
+ {
+ public:
+ TypeType(RefPtr<ExpressionType> type)
+ : type(type)
+ {}
+
+ // The type that this is the type of...
+ RefPtr<ExpressionType> type;
+
+
+ virtual String ToString() override;
+
+ protected:
+ virtual bool EqualsImpl(ExpressionType * type) override;
+ virtual ExpressionType* CreateCanonicalType() override;
+ virtual int GetHashCode() override;
+ };
+
+ class GenericDecl;
+
+ // A vector type, e.g., `vector<T,N>`
+ class VectorExpressionType : public ArithmeticExpressionType
+ {
+ public:
+#if 0
+ VectorExpressionType(
+ RefPtr<ExpressionType> elementType,
+ RefPtr<IntVal> elementCount)
+ : elementType(elementType)
+ , elementCount(elementCount)
+ {}
+#endif
+
+ // The type of vector elements.
+ // As an invariant, this should be a basic type or an alias.
+ RefPtr<ExpressionType> elementType;
+
+ // The number of elements
+ RefPtr<IntVal> elementCount;
+
+ virtual String ToString() override;
+
+ protected:
+ virtual BasicExpressionType* GetScalarType() override;
+ };
+
+ // A matrix type, e.g., `matrix<T,R,C>`
+ class MatrixExpressionType : public ArithmeticExpressionType
+ {
+ public:
+ // TODO: consider adding these back for convenience,
+ // with a way to initialize them on-demand from the
+ // real storage (which is in the `DeclRefType`
+#if 0
+ // The type of vector elements.
+ // As an invariant, this should be a basic type or an alias.
+ RefPtr<ExpressionType> elementType;
+
+ // The type of the matrix rows
+ RefPtr<VectorExpressionType> rowType;
+
+ // The number of rows and columns
+ RefPtr<IntVal> rowCount;
+ RefPtr<IntVal> colCount;
+#endif
+ ExpressionType* getElementType();
+ IntVal* getRowCount();
+ IntVal* getColumnCount();
+
+
+ virtual String ToString() override;
+
+ protected:
+ virtual BasicExpressionType* GetScalarType() override;
+ };
+
+ inline BaseType GetVectorBaseType(VectorExpressionType* vecType) {
+ return vecType->elementType->AsBasicType()->BaseType;
+ }
+
+ inline int GetVectorSize(VectorExpressionType* vecType)
+ {
+ auto constantVal = vecType->elementCount.As<ConstantIntVal>();
+ if (constantVal)
+ return constantVal->value;
+ // TODO: what to do in this case?
+ return 0;
+ }
+
+ class Type
+ {
+ public:
+ RefPtr<ExpressionType> DataType;
+ // ContrainedWorlds: Implementation must be defined at at least one of of these worlds in order to satisfy global dependency
+ // FeasibleWorlds: The component can be computed at any of these worlds
+ EnumerableHashSet<String> ConstrainedWorlds, FeasibleWorlds;
+ EnumerableHashSet<String> PinnedWorlds;
+ };
+
+ class ContainerDecl;
+
+
+ // A group of declarations that should be treated as a unit
+ class DeclGroup : public DeclBase
+ {
+ public:
+ List<RefPtr<Decl>> decls;
+
+ virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override;
+ };
+
+ template<typename T>
+ struct FilteredMemberList
+ {
+ typedef RefPtr<Decl> Element;
+
+ FilteredMemberList()
+ : mBegin(NULL)
+ , mEnd(NULL)
+ {}
+
+ explicit FilteredMemberList(
+ List<Element> const& list)
+ : mBegin(Adjust(list.begin(), list.end()))
+ , mEnd(list.end())
+ {}
+
+ struct Iterator
+ {
+ Element* mCursor;
+ Element* mEnd;
+
+ bool operator!=(Iterator const& other)
+ {
+ return mCursor != other.mCursor;
+ }
+
+ void operator++()
+ {
+ mCursor = Adjust(mCursor + 1, mEnd);
+ }
+
+ RefPtr<T>& operator*()
+ {
+ return *(RefPtr<T>*)mCursor;
+ }
+ };
+
+ Iterator begin()
+ {
+ Iterator iter = { mBegin, mEnd };
+ return iter;
+ }
+
+ Iterator end()
+ {
+ Iterator iter = { mEnd, mEnd };
+ return iter;
+ }
+
+ static Element* Adjust(Element* cursor, Element* end)
+ {
+ while (cursor != end)
+ {
+ if ((*cursor).As<T>())
+ return cursor;
+ cursor++;
+ }
+ return cursor;
+ }
+
+ // TODO(tfoley): It is ugly to have these.
+ // We should probably fix the call sites instead.
+ RefPtr<T>& First() { return *begin(); }
+ int Count()
+ {
+ int count = 0;
+ for (auto iter : (*this))
+ {
+ (void)iter;
+ count++;
+ }
+ return count;
+ }
+
+ List<RefPtr<T>> ToArray()
+ {
+ List<RefPtr<T>> result;
+ for (auto element : (*this))
+ {
+ result.Add(element);
+ }
+ return result;
+ }
+
+ Element* mBegin;
+ Element* mEnd;
+ };
+
+ struct TransparentMemberInfo
+ {
+ // The declaration of the transparent member
+ Decl* decl;
+ };
+
+ // A "container" decl is a parent to other declarations
+ class ContainerDecl : public Decl
+ {
+ public:
+ List<RefPtr<Decl>> Members;
+
+ template<typename T>
+ FilteredMemberList<T> GetMembersOfType()
+ {
+ return FilteredMemberList<T>(Members);
+ }
+
+
+ // Dictionary for looking up members by name.
+ // This is built on demand before performing lookup.
+ Dictionary<String, Decl*> memberDictionary;
+
+ // Whether the `memberDictionary` is valid.
+ // Should be set to `false` if any members get added/remoed.
+ bool memberDictionaryIsValid = false;
+
+ // A list of transparent members, to be used in lookup
+ // Note: this is only valid if `memberDictionaryIsValid` is true
+ List<TransparentMemberInfo> transparentMembers;
+ };
+
+ template<typename T>
+ struct FilteredMemberRefList
+ {
+ List<RefPtr<Decl>> const& decls;
+ RefPtr<Substitutions> substitutions;
+
+ FilteredMemberRefList(
+ List<RefPtr<Decl>> const& decls,
+ RefPtr<Substitutions> substitutions)
+ : decls(decls)
+ , substitutions(substitutions)
+ {}
+
+ int Count() const
+ {
+ int count = 0;
+ for (auto d : *this)
+ count++;
+ return count;
+ }
+
+ List<T> ToArray() const
+ {
+ List<T> result;
+ for (auto d : *this)
+ result.Add(d);
+ return result;
+ }
+
+ struct Iterator
+ {
+ FilteredMemberRefList const* list;
+ RefPtr<Decl>* ptr;
+ RefPtr<Decl>* end;
+
+ Iterator() : list(nullptr), ptr(nullptr) {}
+ Iterator(
+ FilteredMemberRefList const* list,
+ RefPtr<Decl>* ptr,
+ RefPtr<Decl>* end)
+ : list(list)
+ , ptr(ptr)
+ , end(end)
+ {}
+
+ bool operator!=(Iterator other)
+ {
+ return ptr != other.ptr;
+ }
+
+ void operator++()
+ {
+ ptr = list->Adjust(ptr + 1, end);
+ }
+
+ T operator*()
+ {
+ return DeclRef(ptr->Ptr(), list->substitutions).As<T>();
+ }
+ };
+
+ Iterator begin() const { return Iterator(this, Adjust(decls.begin(), decls.end()), decls.end()); }
+ Iterator end() const { return Iterator(this, decls.end(), decls.end()); }
+
+ RefPtr<Decl>* Adjust(RefPtr<Decl>* ptr, RefPtr<Decl>* end) const
+ {
+ while (ptr != end)
+ {
+ DeclRef declRef(ptr->Ptr(), substitutions);
+ if (declRef.As<T>())
+ return ptr;
+ ptr++;
+ }
+ return end;
+ }
+ };
+
+ struct ContainerDeclRef : DeclRef
+ {
+ SLANG_DECLARE_DECL_REF(ContainerDecl);
+
+ FilteredMemberRefList<DeclRef> GetMembers() const
+ {
+ return FilteredMemberRefList<DeclRef>(GetDecl()->Members, substitutions);
+ }
+
+ template<typename T>
+ FilteredMemberRefList<T> GetMembersOfType() const
+ {
+ return FilteredMemberRefList<T>(GetDecl()->Members, substitutions);
+ }
+
+ };
+
+ //
+ // Type Expressions
+ //
+
+ // A "type expression" is a term that we expect to resolve to a type during checking.
+ // We store both the original syntax and the resolved type here.
+ struct TypeExp
+ {
+ TypeExp() {}
+ TypeExp(TypeExp const& other)
+ : exp(other.exp)
+ , type(other.type)
+ {}
+ explicit TypeExp(RefPtr<ExpressionSyntaxNode> exp)
+ : exp(exp)
+ {}
+ TypeExp(RefPtr<ExpressionSyntaxNode> exp, RefPtr<ExpressionType> type)
+ : exp(exp)
+ , type(type)
+ {}
+
+ RefPtr<ExpressionSyntaxNode> exp;
+ RefPtr<ExpressionType> type;
+
+ bool Equals(ExpressionType* other) {
+ return type->Equals(other);
+ }
+ bool Equals(RefPtr<ExpressionType> other) {
+ return type->Equals(other.Ptr());
+ }
+ ExpressionType* Ptr() { return type.Ptr(); }
+ operator RefPtr<ExpressionType>()
+ {
+ return type;
+ }
+ ExpressionType* operator->() { return Ptr(); }
+
+ TypeExp Accept(SyntaxVisitor* visitor);
+ };
+
+
+ //
+ // Declarations
+ //
+
+ // Base class for all variable-like declarations
+ class VarDeclBase : public Decl
+ {
+ public:
+ // Type of the variable
+ TypeExp Type;
+
+ ExpressionType* getType() { return Type.type.Ptr(); }
+
+ // Initializer expression (optional)
+ RefPtr<ExpressionSyntaxNode> Expr;
+ };
+
+ struct VarDeclBaseRef : DeclRef
+ {
+ SLANG_DECLARE_DECL_REF(VarDeclBase);
+
+ RefPtr<ExpressionType> GetType() const { return Substitute(GetDecl()->Type); }
+
+ RefPtr<ExpressionSyntaxNode> getInitExpr() const { return Substitute(GetDecl()->Expr); }
+ };
+
+ // A field of a `struct` type
+ class StructField : public VarDeclBase
+ {
+ public:
+ StructField()
+ {}
+ virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override;
+ };
+
+ struct FieldDeclRef : VarDeclBaseRef
+ {
+ SLANG_DECLARE_DECL_REF(StructField)
+ };
+
+ // An extension to apply to an existing type
+ class ExtensionDecl : public ContainerDecl
+ {
+ public:
+ TypeExp targetType;
+
+ // next extension attached to the same nominal type
+ ExtensionDecl* nextCandidateExtension = nullptr;
+
+
+ virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override;
+ };
+
+ struct ExtensionDeclRef : ContainerDeclRef
+ {
+ SLANG_DECLARE_DECL_REF(ExtensionDecl);
+
+ RefPtr<ExpressionType> GetTargetType() const { return Substitute(GetDecl()->targetType); }
+ };
+
+ // Declaration of a type that represents some sort of aggregate
+ class AggTypeDecl : public ContainerDecl
+ {
+ public:
+ // extensions that might apply to this declaration
+ ExtensionDecl* candidateExtensions = nullptr;
+ FilteredMemberList<StructField> GetFields()
+ {
+ return GetMembersOfType<StructField>();
+ }
+ StructField* FindField(String name)
+ {
+ for (auto field : GetFields())
+ {
+ if (field->Name.Content == name)
+ return field.Ptr();
+ }
+ return nullptr;
+ }
+ int FindFieldIndex(String name)
+ {
+ int index = 0;
+ for (auto field : GetFields())
+ {
+ if (field->Name.Content == name)
+ return index;
+ index++;
+ }
+ return -1;
+ }
+ };
+
+ struct AggTypeDeclRef : public ContainerDeclRef
+ {
+ SLANG_DECLARE_DECL_REF(AggTypeDecl);
+
+ ExtensionDecl* GetCandidateExtensions() const { return GetDecl()->candidateExtensions; }
+ };
+
+ class StructSyntaxNode : public AggTypeDecl
+ {
+ public:
+ virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override;
+ };
+
+ struct StructDeclRef : public AggTypeDeclRef
+ {
+ SLANG_DECLARE_DECL_REF(StructSyntaxNode);
+
+ FilteredMemberRefList<FieldDeclRef> GetFields() const { return GetMembersOfType<FieldDeclRef>(); }
+ };
+
+ class ClassSyntaxNode : public AggTypeDecl
+ {
+ public:
+ virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override;
+ };
+
+ struct ClassDeclRef : public AggTypeDeclRef
+ {
+ SLANG_DECLARE_DECL_REF(ClassSyntaxNode);
+
+ FilteredMemberRefList<FieldDeclRef> GetFields() const { return GetMembersOfType<FieldDeclRef>(); }
+ };
+
+ // A trait which other types can conform to
+ class TraitDecl : public AggTypeDecl
+ {
+ public:
+ List<TypeExp> bases;
+
+ virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override;
+ };
+
+ struct TraitDeclRef : public AggTypeDeclRef
+ {
+ SLANG_DECLARE_DECL_REF(TraitDecl);
+ };
+
+
+ // A declaration that states that the enclosing type supports a given trait
+ //
+ // TODO: this same construct might be used for represent other inheritance-like cases
+ class TraitConformanceDecl : public Decl
+ {
+ public:
+ // The type expression as written
+ TypeExp base;
+
+ // The trait that we found we conform to...
+ TraitDeclRef traitDeclRef;
+
+ virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override;
+ };
+
+ struct TraitConformanceDeclRef : public DeclRef
+ {
+ SLANG_DECLARE_DECL_REF(TraitConformanceDecl);
+
+ TraitDeclRef GetTraitDeclRef() { return Substitute(GetDecl()->traitDeclRef).As<TraitDeclRef>(); }
+ };
+
+ // A declaration that represents a simple (non-aggregate) type
+ class SimpleTypeDecl : public Decl
+ {
+ };
+
+ struct SimpleTypeDeclRef : DeclRef
+ {
+ SLANG_DECLARE_DECL_REF(SimpleTypeDecl)
+ };
+
+ // A `typedef` declaration
+ class TypeDefDecl : public SimpleTypeDecl
+ {
+ public:
+ TypeExp Type;
+
+ virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override;
+ };
+
+ struct TypeDefDeclRef : SimpleTypeDeclRef
+ {
+ SLANG_DECLARE_DECL_REF(TypeDefDecl);
+
+ RefPtr<ExpressionType> GetType() const { return Substitute(GetDecl()->Type); }
+ };
+
+ // A type alias of some kind (e.g., via `typedef`)
+ class NamedExpressionType : public ExpressionType
+ {
+ public:
+ NamedExpressionType(TypeDefDeclRef declRef)
+ : declRef(declRef)
+ {}
+
+ TypeDefDeclRef declRef;
+
+ virtual String ToString() override;
+
+ protected:
+ virtual bool EqualsImpl(ExpressionType * type) override;
+ virtual ExpressionType* CreateCanonicalType() override;
+ virtual int GetHashCode() override;
+ };
+
+
+ class StatementSyntaxNode : public ModifiableSyntaxNode
+ {
+ public:
+ };
+
+ // A scope for local declarations (e.g., as part of a statement)
+ class ScopeDecl : public ContainerDecl
+ {
+ public:
+ virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override;
+ };
+
+ class ScopeStmt : public StatementSyntaxNode
+ {
+ public:
+ RefPtr<ScopeDecl> scopeDecl;
+ };
+
+ class BlockStatementSyntaxNode : public ScopeStmt
+ {
+ public:
+ List<RefPtr<StatementSyntaxNode>> Statements;
+ virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override;
+ };
+
+ class UnparsedStmt : public StatementSyntaxNode
+ {
+ public:
+ // The tokens that were contained between `{` and `}`
+ List<Token> tokens;
+
+ virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override;
+ };
+
+ class ParameterSyntaxNode : public VarDeclBase
+ {
+ public:
+ virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override;
+ };
+
+ struct ParamDeclRef : VarDeclBaseRef
+ {
+ SLANG_DECLARE_DECL_REF(ParameterSyntaxNode);
+ };
+
+ // Base class for things that have parameter lists and can thus be applied to arguments ("called")
+ class CallableDecl : public ContainerDecl
+ {
+ public:
+ FilteredMemberList<ParameterSyntaxNode> GetParameters()
+ {
+ return GetMembersOfType<ParameterSyntaxNode>();
+ }
+ TypeExp ReturnType;
+ };
+
+ struct CallableDeclRef : ContainerDeclRef
+ {
+ SLANG_DECLARE_DECL_REF(CallableDecl);
+
+ RefPtr<ExpressionType> GetResultType() const
+ {
+ return Substitute(GetDecl()->ReturnType.type.Ptr());
+ }
+
+ FilteredMemberRefList<ParamDeclRef> GetParameters()
+ {
+ return GetMembersOfType<ParamDeclRef>();
+ }
+ };
+
+ // Base class for callable things that may also have a body that is evaluated to produce their result
+ class FunctionDeclBase : public CallableDecl
+ {
+ public:
+ RefPtr<StatementSyntaxNode> Body;
+ };
+
+ struct FuncDeclBaseRef : CallableDeclRef
+ {
+ SLANG_DECLARE_DECL_REF(FunctionDeclBase);
+ };
+
+ // Function types are currently used for references to symbols that name
+ // either ordinary functions, or "component functions."
+ // We do not directly store a representation of the type, and instead
+ // use a reference to the symbol to stand in for its logical type
+ class FuncType : public ExpressionType
+ {
+ public:
+ CallableDeclRef declRef;
+
+ virtual String ToString() override;
+ protected:
+ virtual bool EqualsImpl(ExpressionType * type) override;
+ virtual ExpressionType* CreateCanonicalType() override;
+ virtual int GetHashCode() override;
+ };
+
+ // A constructor/initializer to create instances of a type
+ class ConstructorDecl : public FunctionDeclBase
+ {
+ public:
+ virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override;
+ };
+
+ struct ConstructorDeclRef : FuncDeclBaseRef
+ {
+ SLANG_DECLARE_DECL_REF(ConstructorDecl);
+ };
+
+ // A subscript operation used to index instances of a type
+ class SubscriptDecl : public CallableDecl
+ {
+ public:
+ virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override;
+ };
+
+ struct SubscriptDeclRef : CallableDeclRef
+ {
+ SLANG_DECLARE_DECL_REF(SubscriptDecl);
+ };
+
+ // An "accessor" for a subscript or property
+ class AccessorDecl : public FunctionDeclBase
+ {
+ public:
+ virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override;
+ };
+
+ class GetterDecl : public AccessorDecl
+ {
+ };
+
+ class SetterDecl : public AccessorDecl
+ {
+ };
+
+ //
+
+ class FunctionSyntaxNode : public FunctionDeclBase
+ {
+ public:
+ String InternalName;
+ bool IsInline() { return HasModifier<InlineModifier>(); }
+ bool IsExtern() { return HasModifier<ExternModifier>(); }
+ virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override;
+ FunctionSyntaxNode()
+ {
+ }
+ };
+
+ struct FuncDeclRef : FuncDeclBaseRef
+ {
+ SLANG_DECLARE_DECL_REF(FunctionSyntaxNode);
+ };
+
+
+ struct Scope : public RefObject
+ {
+ // The parent of this scope (where lookup should go if nothing is found locally)
+ RefPtr<Scope> parent;
+
+ // The next sibling of this scope (a peer for lookup)
+ RefPtr<Scope> nextSibling;
+
+ // The container to use for lookup
+ //
+ // Note(tfoley): This is kept as an unowned pointer
+ // so that a scope can't keep parts of the AST alive,
+ // but the opposite it allowed.
+ ContainerDecl* containerDecl;
+ };
+
+ // Base class for expressions that will reference declarations
+ class DeclRefExpr : public ExpressionSyntaxNode
+ {
+ public:
+ // The scope in which to perform lookup
+ RefPtr<Scope> scope;
+
+ // The declaration of the symbol being referenced
+ DeclRef declRef;
+ };
+
+ class VarExpressionSyntaxNode : public DeclRefExpr
+ {
+ public:
+ // The name of the symbol being referenced
+ String Variable;
+
+ virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override;
+ };
+
+ // Masks to be applied when lookup up declarations
+ enum class LookupMask : uint8_t
+ {
+ Type = 0x1,
+ Function = 0x2,
+ Value = 0x4,
+
+ All = Type | Function | Value,
+ };
+
+ // Represents one item found during lookup
+ struct LookupResultItem
+ {
+ // Sometimes lookup finds an item, but there were additional
+ // "hops" taken to reach it. We need to remember these steps
+ // so that if/when we consturct a full expression we generate
+ // appropriate AST nodes for all the steps.
+ //
+ // We build up a list of these "breadcrumbs" while doing
+ // lookup, and store them alongside each item found.
+ class Breadcrumb : public RefObject
+ {
+ public:
+ enum class Kind
+ {
+ Member, // A member was references
+ Deref, // A value with pointer(-like) type was dereferenced
+ };
+
+ Kind kind;
+ DeclRef declRef;
+ RefPtr<Breadcrumb> next;
+
+ Breadcrumb(Kind kind, DeclRef declRef, RefPtr<Breadcrumb> next)
+ : kind(kind)
+ , declRef(declRef)
+ , next(next)
+ {}
+ };
+
+ // A properly-specialized reference to the declaration that was found.
+ DeclRef declRef;
+
+ // Any breadcrumbs needed in order to turn that declaration
+ // reference into a well-formed expression.
+ //
+ // This is unused in the simple case where a declaration
+ // is being referenced directly (rather than through
+ // transparent members).
+ RefPtr<Breadcrumb> breadcrumbs;
+
+ LookupResultItem() = default;
+ explicit LookupResultItem(DeclRef declRef)
+ : declRef(declRef)
+ {}
+ LookupResultItem(DeclRef declRef, RefPtr<Breadcrumb> breadcrumbs)
+ : declRef(declRef)
+ , breadcrumbs(breadcrumbs)
+ {}
+ };
+
+
+ // Result of looking up a name in some lexical/semantic environment.
+ // Can be used to enumerate all the declarations matching that name,
+ // in the case where the result is overloaded.
+ struct LookupResult
+ {
+ // The one item that was found, in the smple case
+ LookupResultItem item;
+
+ // All of the items that were found, in the complex case.
+ // Note: if there was no overloading, then this list isn't
+ // used at all, to avoid allocation.
+ List<LookupResultItem> items;
+
+ // Was at least one result found?
+ bool isValid() const { return item.declRef.GetDecl() != nullptr; }
+
+ bool isOverloaded() const { return items.Count() > 1; }
+ };
+
+ struct LookupRequest
+ {
+ RefPtr<Scope> scope = nullptr;
+ RefPtr<Scope> endScope = nullptr;
+
+ LookupMask mask = LookupMask::All;
+ };
+
+ // An expression that references an overloaded set of declarations
+ // having the same name.
+ class OverloadedExpr : public ExpressionSyntaxNode
+ {
+ public:
+ // Optional: the base expression is this overloaded result
+ // arose from a member-reference expression.
+ RefPtr<ExpressionSyntaxNode> base;
+
+ // The lookup result that was ambiguous
+ LookupResult lookupResult2;
+
+ virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override;
+ };
+
+ typedef double FloatingPointLiteralValue;
+
+ class ConstantExpressionSyntaxNode : public ExpressionSyntaxNode
+ {
+ public:
+ enum class ConstantType
+ {
+ Int, Bool, Float
+ };
+ ConstantType ConstType;
+ union
+ {
+ int IntValue;
+ FloatingPointLiteralValue FloatValue;
+ };
+ virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override;
+ };
+
+ enum class Operator
+ {
+ Neg, Not, BitNot, PreInc, PreDec, PostInc, PostDec,
+ Mul, Div, Mod,
+ Add, Sub,
+ Lsh, Rsh,
+ Eql, Neq, Greater, Less, Geq, Leq,
+ BitAnd, BitXor, BitOr,
+ And,
+ Or,
+ Sequence,
+ Select,
+ Assign = 200, AddAssign, SubAssign, MulAssign, DivAssign, ModAssign,
+ LshAssign, RshAssign, OrAssign, AndAssign, XorAssign,
+ };
+ String GetOperatorFunctionName(Operator op);
+ String OperatorToString(Operator op);
+
+ // An initializer list, e.g. `{ 1, 2, 3 }`
+ class InitializerListExpr : public ExpressionSyntaxNode
+ {
+ public:
+ List<RefPtr<ExpressionSyntaxNode>> args;
+
+ virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override;
+ };
+
+ // A base expression being applied to arguments: covers
+ // both ordinary `()` function calls and `<>` generic application
+ class AppExprBase : public ExpressionSyntaxNode
+ {
+ public:
+ RefPtr<ExpressionSyntaxNode> FunctionExpr;
+ List<RefPtr<ExpressionSyntaxNode>> Arguments;
+ };
+
+
+ class InvokeExpressionSyntaxNode : public AppExprBase
+ {
+ public:
+ virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override;
+ };
+
+ class OperatorExpressionSyntaxNode : public InvokeExpressionSyntaxNode
+ {
+ public:
+// Operator Operator;
+// void SetOperator(RefPtr<Scope> scope, Slang::Compiler::Operator op);
+ virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override;
+ };
+
+ class InfixExpr : public OperatorExpressionSyntaxNode {};
+ class PrefixExpr : public OperatorExpressionSyntaxNode {};
+ class PostfixExpr : public OperatorExpressionSyntaxNode {};
+
+ class IndexExpressionSyntaxNode : public ExpressionSyntaxNode
+ {
+ public:
+ RefPtr<ExpressionSyntaxNode> BaseExpression;
+ RefPtr<ExpressionSyntaxNode> IndexExpression;
+ virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override;
+ };
+
+ class MemberExpressionSyntaxNode : public DeclRefExpr
+ {
+ public:
+ RefPtr<ExpressionSyntaxNode> BaseExpression;
+ String MemberName;
+ virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override;
+ };
+
+ class SwizzleExpr : public ExpressionSyntaxNode
+ {
+ public:
+ RefPtr<ExpressionSyntaxNode> base;
+ int elementCount;
+ int elementIndices[4];
+
+ virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override;
+ };
+
+ // A dereference of a pointer or pointer-like type
+ class DerefExpr : public ExpressionSyntaxNode
+ {
+ public:
+ RefPtr<ExpressionSyntaxNode> base;
+
+ virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override;
+ };
+
+ class TypeCastExpressionSyntaxNode : public ExpressionSyntaxNode
+ {
+ public:
+ TypeExp TargetType;
+ RefPtr<ExpressionSyntaxNode> Expression;
+ virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override;
+ };
+
+ class SelectExpressionSyntaxNode : public OperatorExpressionSyntaxNode
+ {
+ public:
+ };
+
+
+ class EmptyStatementSyntaxNode : public StatementSyntaxNode
+ {
+ public:
+ virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override;
+ };
+
+ class DiscardStatementSyntaxNode : public StatementSyntaxNode
+ {
+ public:
+ virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override;
+ };
+
+ struct Variable : public VarDeclBase
+ {
+ virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override;
+ };
+
+ class VarDeclrStatementSyntaxNode : public StatementSyntaxNode
+ {
+ public:
+ RefPtr<DeclBase> decl;
+ virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override;
+ };
+
+ class UsingFileDecl : public Decl
+ {
+ public:
+ Token fileName;
+
+ virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override;
+ };
+
+ class ProgramSyntaxNode : public ContainerDecl
+ {
+ public:
+ // Access members of specific types
+ FilteredMemberList<UsingFileDecl> GetUsings()
+ {
+ return GetMembersOfType<UsingFileDecl>();
+ }
+ FilteredMemberList<FunctionSyntaxNode> GetFunctions()
+ {
+ return GetMembersOfType<FunctionSyntaxNode>();
+ }
+
+ FilteredMemberList<ClassSyntaxNode> GetClasses()
+ {
+ return GetMembersOfType<ClassSyntaxNode>();
+ }
+ FilteredMemberList<StructSyntaxNode> GetStructs()
+ {
+ return GetMembersOfType<StructSyntaxNode>();
+ }
+ FilteredMemberList<TypeDefDecl> GetTypeDefs()
+ {
+ return GetMembersOfType<TypeDefDecl>();
+ }
+#if 0
+ void Include(ProgramSyntaxNode * other)
+ {
+ Members.AddRange(other->Members);
+ }
+#endif
+ virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override;
+ };
+
+ class IfStatementSyntaxNode : public StatementSyntaxNode
+ {
+ public:
+ RefPtr<ExpressionSyntaxNode> Predicate;
+ RefPtr<StatementSyntaxNode> PositiveStatement;
+ RefPtr<StatementSyntaxNode> NegativeStatement;
+ virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override;
+ };
+
+ // A statement that can be escaped with a `break`
+ class BreakableStmt : public ScopeStmt
+ {};
+
+ class SwitchStmt : public BreakableStmt
+ {
+ public:
+ RefPtr<ExpressionSyntaxNode> condition;
+ RefPtr<StatementSyntaxNode> body;
+
+ virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override;
+ };
+
+ // A statement that is expected to appear lexically nested inside
+ // some other construct, and thus needs to keep track of the
+ // outer statement that it is associated with...
+ class ChildStmt : public StatementSyntaxNode
+ {
+ public:
+ StatementSyntaxNode* parentStmt = nullptr;
+ };
+
+ // a `case` or `default` statement inside a `switch`
+ //
+ // Note(tfoley): A correct AST for a C-like language would treat
+ // these as a labelled statement, and so they would contain a
+ // sub-statement. I'm leaving that out for now for simplicity.
+ class CaseStmtBase : public ChildStmt
+ {
+ public:
+ };
+
+ // a `case` statement inside a `switch`
+ class CaseStmt : public CaseStmtBase
+ {
+ public:
+ RefPtr<ExpressionSyntaxNode> expr;
+
+ virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override;
+ };
+
+ // a `default` statement inside a `switch`
+ class DefaultStmt : public CaseStmtBase
+ {
+ public:
+ virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override;
+ };
+
+ // A statement that represents a loop, and can thus be escaped with a `continue`
+ class LoopStmt : public BreakableStmt
+ {};
+
+ class ForStatementSyntaxNode : public LoopStmt
+ {
+ public:
+ RefPtr<StatementSyntaxNode> InitialStatement;
+ RefPtr<ExpressionSyntaxNode> SideEffectExpression, PredicateExpression;
+ RefPtr<StatementSyntaxNode> Statement;
+ virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override;
+ };
+
+ class WhileStatementSyntaxNode : public LoopStmt
+ {
+ public:
+ RefPtr<ExpressionSyntaxNode> Predicate;
+ RefPtr<StatementSyntaxNode> Statement;
+ virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override;
+ };
+
+ class DoWhileStatementSyntaxNode : public LoopStmt
+ {
+ public:
+ RefPtr<StatementSyntaxNode> Statement;
+ RefPtr<ExpressionSyntaxNode> Predicate;
+ virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override;
+ };
+
+ // The case of child statements that do control flow relative
+ // to their parent statement.
+ class JumpStmt : public ChildStmt
+ {
+ public:
+ StatementSyntaxNode* parentStmt = nullptr;
+ };
+
+ class BreakStatementSyntaxNode : public JumpStmt
+ {
+ public:
+ virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override;
+ };
+
+ class ContinueStatementSyntaxNode : public JumpStmt
+ {
+ public:
+ virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override;
+ };
+
+ class ReturnStatementSyntaxNode : public StatementSyntaxNode
+ {
+ public:
+ RefPtr<ExpressionSyntaxNode> Expression;
+ virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override;
+ };
+
+ class ExpressionStatementSyntaxNode : public StatementSyntaxNode
+ {
+ public:
+ RefPtr<ExpressionSyntaxNode> Expression;
+ virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override;
+ };
+
+ // Note(tfoley): Moved this further down in the file because it depends on
+ // `ExpressionSyntaxNode` and a forward reference just isn't good enough
+ // for `RefPtr`.
+ //
+ class GenericAppExpr : public AppExprBase
+ {
+ public:
+ virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override;
+ };
+
+ // An expression representing re-use of the syntax for a type in more
+ // than once conceptually-distinct declaration
+ class SharedTypeExpr : public ExpressionSyntaxNode
+ {
+ public:
+ // The underlying type expression that we want to share
+ TypeExp base;
+
+ virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override;
+ };
+
+
+ // A modifier that indicates a built-in base type (e.g., `float`)
+ class BuiltinTypeModifier : public Modifier
+ {
+ public:
+ BaseType tag;
+ };
+
+ // A modifier that indicates a built-in type that isn't a base type (e.g., `vector`)
+ //
+ // TODO(tfoley): This deserves a better name than "magic"
+ class MagicTypeModifier : public Modifier
+ {
+ public:
+ String name;
+ uint32_t tag;
+ };
+
+ // Modifiers that affect the storage layout for matrices
+ class MatrixLayoutModifier : public Modifier {};
+
+ // Modifiers that specify row- and column-major layout, respectively
+ class RowMajorLayoutModifier : public MatrixLayoutModifier {};
+ class ColumnMajorLayoutModifier : public MatrixLayoutModifier {};
+
+ // The HLSL flavor of those modifiers
+ class HLSLRowMajorLayoutModifier : public RowMajorLayoutModifier {};
+ class HLSLColumnMajorLayoutModifier : public ColumnMajorLayoutModifier {};
+
+ // The GLSL flavor of those modifiers
+ //
+ // Note(tfoley): The GLSL versions of these modifiers are "backwards"
+ // in the sense that when a GLSL programmer requests row-major layout,
+ // we actually interpret that as requesting column-major. This makes
+ // sense because we interpret matrix conventions backwards from how
+ // GLSL specifies them.
+ class GLSLRowMajorLayoutModifier : public ColumnMajorLayoutModifier {};
+ class GLSLColumnMajorLayoutModifier : public RowMajorLayoutModifier {};
+
+ // More HLSL Keyword
+
+ // HLSL `nointerpolation` modifier
+ class HLSLNoInterpolationModifier : public Modifier {};
+
+ // HLSL `linear` modifier
+ class HLSLLinearModifier : public Modifier {};
+
+ // HLSL `sample` modifier
+ class HLSLSampleModifier : public Modifier {};
+
+ // HLSL `centroid` modifier
+ class HLSLCentroidModifier : public Modifier {};
+
+ // HLSL `precise` modifier
+ class HLSLPreciseModifier : public Modifier {};
+
+ // HLSL `shared` modifier (which is used by the effect system,
+ // and shouldn't be confused with `groupshared`)
+ class HLSLEffectSharedModifier : public Modifier {};
+
+ // HLSL `groupshared` modifier
+ class HLSLGroupSharedModifier : public Modifier {};
+
+ // HLSL `static` modifier (probably doesn't need to be
+ // treated as HLSL-specific)
+ class HLSLStaticModifier : public Modifier {};
+
+ // HLSL `uniform` modifier (distinct meaning from GLSL
+ // use of the keyword)
+ class HLSLUniformModifier : public Modifier {};
+
+ // HLSL `volatile` modifier (ignored)
+ class HLSLVolatileModifier : public Modifier {};
+
+ // An HLSL `[name(arg0, ...)]` style attribute.
+ class HLSLAttribute : public Modifier
+ {
+ public:
+ Token nameToken;
+ List<RefPtr<ExpressionSyntaxNode>> args;
+ };
+
+ // An HLSL `[name(...)]` attribute that hasn't undergone
+ // any semantic analysis.
+ // After analysis, this might be transformed into a more specific case.
+ class HLSLUncheckedAttribute : public HLSLAttribute
+ {
+ public:
+ };
+
+ // An HLSL `[numthreads(x,y,z)]` attribute
+ class HLSLNumThreadsAttribute : public HLSLAttribute
+ {
+ public:
+ // The number of threads to use along each axis
+ int32_t x;
+ int32_t y;
+ int32_t z;
+ };
+
+ // HLSL modifiers for geometry shader input topology
+ class HLSLGeometryShaderInputPrimitiveTypeModifier : public Modifier {};
+ class HLSLPointModifier : public HLSLGeometryShaderInputPrimitiveTypeModifier {};
+ class HLSLLineModifier : public HLSLGeometryShaderInputPrimitiveTypeModifier {};
+ class HLSLTriangleModifier : public HLSLGeometryShaderInputPrimitiveTypeModifier {};
+ class HLSLLineAdjModifier : public HLSLGeometryShaderInputPrimitiveTypeModifier {};
+ class HLSLTriangleAdjModifier : public HLSLGeometryShaderInputPrimitiveTypeModifier {};
+
+ //
+
+ // A generic declaration, parameterized on types/values
+ class GenericDecl : public ContainerDecl
+ {
+ public:
+ // The decl that is genericized...
+ RefPtr<Decl> inner;
+
+ virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override;
+ };
+
+ struct GenericDeclRef : ContainerDeclRef
+ {
+ SLANG_DECLARE_DECL_REF(GenericDecl);
+
+ Decl* GetInner() const { return GetDecl()->inner.Ptr(); }
+ };
+
+ // The "type" of an expression that names a generic declaration.
+ class GenericDeclRefType : public ExpressionType
+ {
+ public:
+ GenericDeclRefType(GenericDeclRef declRef)
+ : declRef(declRef)
+ {}
+
+ GenericDeclRef declRef;
+ GenericDeclRef const& GetDeclRef() const { return declRef; }
+
+ virtual String ToString() override;
+
+ protected:
+ virtual bool EqualsImpl(ExpressionType * type) override;
+ virtual int GetHashCode() override;
+ virtual ExpressionType* CreateCanonicalType() override;
+ };
+
+
+
+ class GenericTypeParamDecl : public SimpleTypeDecl
+ {
+ public:
+ // The bound for the type parameter represents a trait that any
+ // type used as this parameter must conform to
+// TypeExp bound;
+
+ // The "initializer" for the parameter represents a default value
+ TypeExp initType;
+
+ virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override;
+ };
+
+ struct GenericTypeParamDeclRef : SimpleTypeDeclRef
+ {
+ SLANG_DECLARE_DECL_REF(GenericTypeParamDecl);
+ };
+
+ // A constraint placed as part of a generic declaration
+ class GenericTypeConstraintDecl : public Decl
+ {
+ public:
+ // A type constraint like `T : U` is constraining `T` to be "below" `U`
+ // on a lattice of types. This may not be a subtyping relationship
+ // per se, but it makes sense to use that terminology here, so we
+ // think of these fields as the sub-type and sup-ertype, respectively.
+ TypeExp sub;
+ TypeExp sup;
+
+ virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override;
+ };
+
+ struct GenericTypeConstraintDeclRef : DeclRef
+ {
+ SLANG_DECLARE_DECL_REF(GenericTypeConstraintDecl);
+
+ RefPtr<ExpressionType> GetSub() { return Substitute(GetDecl()->sub); }
+ RefPtr<ExpressionType> GetSup() { return Substitute(GetDecl()->sup); }
+ };
+
+
+ class GenericValueParamDecl : public VarDeclBase
+ {
+ public:
+ virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override;
+ };
+
+ struct GenericValueParamDeclRef : VarDeclBaseRef
+ {
+ SLANG_DECLARE_DECL_REF(GenericValueParamDecl);
+ };
+
+ // The logical "value" of a rererence to a generic value parameter
+ class GenericParamIntVal : public IntVal
+ {
+ public:
+ VarDeclBaseRef declRef;
+
+ GenericParamIntVal(VarDeclBaseRef declRef)
+ : declRef(declRef)
+ {}
+
+ virtual bool EqualsVal(Val* val) override;
+ virtual String ToString() override;
+ virtual int GetHashCode() override;
+ virtual RefPtr<Val> SubstituteImpl(Substitutions* subst, int* ioDiff) override;
+ };
+
+ // Declaration of a user-defined modifier
+ class ModifierDecl : public Decl
+ {
+ public:
+ // The name of the C++ class to instantiate
+ // (this is a reference to a class in the compiler source code,
+ // and not the user's source code)
+ Token classNameToken;
+
+ virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override;
+ };
+
+ // An empty declaration (which might still have modifiers attached).
+ //
+ // An empty declaration is uncommon in HLSL, but
+ // in GLSL it is often used at the global scope
+ // to declare metadata that logically belongs
+ // to the entry point, e.g.:
+ //
+ // layout(local_size_x = 16) in;
+ //
+ class EmptyDecl : public Decl
+ {
+ public:
+ virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override;
+ };
+
+ //
+
+ class SyntaxVisitor : public Object
+ {
+ protected:
+ DiagnosticSink * sink = nullptr;
+ DiagnosticSink* getSink() { return sink; }
+
+ SourceLanguage sourceLanguage = SourceLanguage::Unknown;
+ public:
+ void setSourceLanguage(SourceLanguage language)
+ {
+ sourceLanguage = language;
+ }
+
+ SyntaxVisitor(DiagnosticSink * sink)
+ : sink(sink)
+ {}
+ virtual RefPtr<ProgramSyntaxNode> VisitProgram(ProgramSyntaxNode* program)
+ {
+ for (auto & m : program->Members)
+ m = m->Accept(this).As<Decl>();
+ return program;
+ }
+
+ virtual RefPtr<UsingFileDecl> VisitUsingFileDecl(UsingFileDecl * decl)
+ {
+ return decl;
+ }
+
+ virtual RefPtr<FunctionSyntaxNode> VisitFunction(FunctionSyntaxNode* func)
+ {
+ func->ReturnType = func->ReturnType.Accept(this);
+ for (auto & member : func->Members)
+ member = member->Accept(this).As<Decl>();
+ if (func->Body)
+ func->Body = func->Body->Accept(this).As<BlockStatementSyntaxNode>();
+ return func;
+ }
+ virtual RefPtr<ScopeDecl> VisitScopeDecl(ScopeDecl* decl)
+ {
+ // By default don't visit children, because they will always
+ // be encountered in the ordinary flow of the corresponding statement.
+ return decl;
+ }
+ virtual RefPtr<StructSyntaxNode> VisitStruct(StructSyntaxNode * s)
+ {
+ for (auto & f : s->Members)
+ f = f->Accept(this).As<Decl>();
+ return s;
+ }
+ virtual RefPtr<ClassSyntaxNode> VisitClass(ClassSyntaxNode * s)
+ {
+ for (auto & f : s->Members)
+ f = f->Accept(this).As<Decl>();
+ return s;
+ }
+ virtual RefPtr<GenericDecl> VisitGenericDecl(GenericDecl * decl)
+ {
+ for (auto & m : decl->Members)
+ m = m->Accept(this).As<Decl>();
+ decl->inner = decl->inner->Accept(this).As<Decl>();
+ return decl;
+ }
+ virtual RefPtr<TypeDefDecl> VisitTypeDefDecl(TypeDefDecl* decl)
+ {
+ decl->Type = decl->Type.Accept(this);
+ return decl;
+ }
+ virtual RefPtr<StatementSyntaxNode> VisitDiscardStatement(DiscardStatementSyntaxNode * stmt)
+ {
+ return stmt;
+ }
+ virtual RefPtr<StructField> VisitStructField(StructField * f)
+ {
+ f->Type = f->Type.Accept(this);
+ return f;
+ }
+ virtual RefPtr<StatementSyntaxNode> VisitBlockStatement(BlockStatementSyntaxNode* stmt)
+ {
+ for (auto & s : stmt->Statements)
+ s = s->Accept(this).As<StatementSyntaxNode>();
+ return stmt;
+ }
+ virtual RefPtr<StatementSyntaxNode> VisitBreakStatement(BreakStatementSyntaxNode* stmt)
+ {
+ return stmt;
+ }
+ virtual RefPtr<StatementSyntaxNode> VisitContinueStatement(ContinueStatementSyntaxNode* stmt)
+ {
+ return stmt;
+ }
+
+ virtual RefPtr<StatementSyntaxNode> VisitDoWhileStatement(DoWhileStatementSyntaxNode* stmt)
+ {
+ if (stmt->Predicate)
+ stmt->Predicate = stmt->Predicate->Accept(this).As<ExpressionSyntaxNode>();
+ if (stmt->Statement)
+ stmt->Statement = stmt->Statement->Accept(this).As<StatementSyntaxNode>();
+ return stmt;
+ }
+ virtual RefPtr<StatementSyntaxNode> VisitEmptyStatement(EmptyStatementSyntaxNode* stmt)
+ {
+ return stmt;
+ }
+ virtual RefPtr<StatementSyntaxNode> VisitForStatement(ForStatementSyntaxNode* stmt)
+ {
+ if (stmt->InitialStatement)
+ stmt->InitialStatement = stmt->InitialStatement->Accept(this).As<StatementSyntaxNode>();
+ if (stmt->PredicateExpression)
+ stmt->PredicateExpression = stmt->PredicateExpression->Accept(this).As<ExpressionSyntaxNode>();
+ if (stmt->SideEffectExpression)
+ stmt->SideEffectExpression = stmt->SideEffectExpression->Accept(this).As<ExpressionSyntaxNode>();
+ if (stmt->Statement)
+ stmt->Statement = stmt->Statement->Accept(this).As<StatementSyntaxNode>();
+ return stmt;
+ }
+ virtual RefPtr<StatementSyntaxNode> VisitIfStatement(IfStatementSyntaxNode* stmt)
+ {
+ if (stmt->Predicate)
+ stmt->Predicate = stmt->Predicate->Accept(this).As<ExpressionSyntaxNode>();
+ if (stmt->PositiveStatement)
+ stmt->PositiveStatement = stmt->PositiveStatement->Accept(this).As<StatementSyntaxNode>();
+ if (stmt->NegativeStatement)
+ stmt->NegativeStatement = stmt->NegativeStatement->Accept(this).As<StatementSyntaxNode>();
+ return stmt;
+ }
+ virtual RefPtr<SwitchStmt> VisitSwitchStmt(SwitchStmt* stmt)
+ {
+ if (stmt->condition)
+ stmt->condition = stmt->condition->Accept(this).As<ExpressionSyntaxNode>();
+ if (stmt->body)
+ stmt->body = stmt->body->Accept(this).As<BlockStatementSyntaxNode>();
+ return stmt;
+ }
+ virtual RefPtr<CaseStmt> VisitCaseStmt(CaseStmt* stmt)
+ {
+ if (stmt->expr)
+ stmt->expr = stmt->expr->Accept(this).As<ExpressionSyntaxNode>();
+ return stmt;
+ }
+ virtual RefPtr<DefaultStmt> VisitDefaultStmt(DefaultStmt* stmt)
+ {
+ return stmt;
+ }
+ virtual RefPtr<StatementSyntaxNode> VisitReturnStatement(ReturnStatementSyntaxNode* stmt)
+ {
+ if (stmt->Expression)
+ stmt->Expression = stmt->Expression->Accept(this).As<ExpressionSyntaxNode>();
+ return stmt;
+ }
+ virtual RefPtr<StatementSyntaxNode> VisitVarDeclrStatement(VarDeclrStatementSyntaxNode* stmt)
+ {
+ stmt->decl = stmt->decl->Accept(this).As<DeclBase>();
+ return stmt;
+ }
+ virtual RefPtr<StatementSyntaxNode> VisitWhileStatement(WhileStatementSyntaxNode* stmt)
+ {
+ if (stmt->Predicate)
+ stmt->Predicate = stmt->Predicate->Accept(this).As<ExpressionSyntaxNode>();
+ if (stmt->Statement)
+ stmt->Statement = stmt->Statement->Accept(this).As<StatementSyntaxNode>();
+ return stmt;
+ }
+ virtual RefPtr<StatementSyntaxNode> VisitExpressionStatement(ExpressionStatementSyntaxNode* stmt)
+ {
+ if (stmt->Expression)
+ stmt->Expression = stmt->Expression->Accept(this).As<ExpressionSyntaxNode>();
+ return stmt;
+ }
+
+ virtual RefPtr<ExpressionSyntaxNode> VisitOperatorExpression(OperatorExpressionSyntaxNode* expr)
+ {
+ for (auto && child : expr->Arguments)
+ child->Accept(this);
+ return expr;
+ }
+ virtual RefPtr<ExpressionSyntaxNode> VisitConstantExpression(ConstantExpressionSyntaxNode* expr)
+ {
+ return expr;
+ }
+ virtual RefPtr<ExpressionSyntaxNode> VisitIndexExpression(IndexExpressionSyntaxNode* expr)
+ {
+ if (expr->BaseExpression)
+ expr->BaseExpression = expr->BaseExpression->Accept(this).As<ExpressionSyntaxNode>();
+ if (expr->IndexExpression)
+ expr->IndexExpression = expr->IndexExpression->Accept(this).As<ExpressionSyntaxNode>();
+ return expr;
+ }
+ virtual RefPtr<ExpressionSyntaxNode> VisitMemberExpression(MemberExpressionSyntaxNode * stmt)
+ {
+ if (stmt->BaseExpression)
+ stmt->BaseExpression = stmt->BaseExpression->Accept(this).As<ExpressionSyntaxNode>();
+ return stmt;
+ }
+ virtual RefPtr<ExpressionSyntaxNode> VisitSwizzleExpression(SwizzleExpr * expr)
+ {
+ if (expr->base)
+ expr->base->Accept(this);
+ return expr;
+ }
+ virtual RefPtr<ExpressionSyntaxNode> VisitInvokeExpression(InvokeExpressionSyntaxNode* stmt)
+ {
+ stmt->FunctionExpr->Accept(this);
+ for (auto & arg : stmt->Arguments)
+ arg = arg->Accept(this).As<ExpressionSyntaxNode>();
+ return stmt;
+ }
+ virtual RefPtr<ExpressionSyntaxNode> VisitTypeCastExpression(TypeCastExpressionSyntaxNode * stmt)
+ {
+ if (stmt->Expression)
+ stmt->Expression = stmt->Expression->Accept(this).As<ExpressionSyntaxNode>();
+ return stmt->Expression;
+ }
+ virtual RefPtr<ExpressionSyntaxNode> VisitVarExpression(VarExpressionSyntaxNode* expr)
+ {
+ return expr;
+ }
+
+ virtual RefPtr<ParameterSyntaxNode> VisitParameter(ParameterSyntaxNode* param)
+ {
+ return param;
+ }
+ virtual RefPtr<ExpressionSyntaxNode> VisitGenericApp(GenericAppExpr* type)
+ {
+ return type;
+ }
+
+ virtual RefPtr<Variable> VisitDeclrVariable(Variable* dclr)
+ {
+ if (dclr->Expr)
+ dclr->Expr = dclr->Expr->Accept(this).As<ExpressionSyntaxNode>();
+ return dclr;
+ }
+
+ virtual TypeExp VisitTypeExp(TypeExp const& typeExp)
+ {
+ TypeExp result = typeExp;
+ result.exp = typeExp.exp->Accept(this).As<ExpressionSyntaxNode>();
+ if (auto typeType = result.exp->Type.type.As<TypeType>())
+ {
+ result.type = typeType->type;
+ }
+ return result;
+ }
+
+ virtual void VisitExtensionDecl(ExtensionDecl* /*decl*/)
+ {}
+
+ virtual void VisitConstructorDecl(ConstructorDecl* /*decl*/)
+ {}
+
+ virtual void visitSubscriptDecl(SubscriptDecl* decl) = 0;
+
+ virtual void visitAccessorDecl(AccessorDecl* decl) = 0;
+
+ virtual void VisitTraitDecl(TraitDecl* /*decl*/)
+ {}
+
+ virtual void VisitTraitConformanceDecl(TraitConformanceDecl* /*decl*/)
+ {}
+
+ virtual RefPtr<ExpressionSyntaxNode> VisitSharedTypeExpr(SharedTypeExpr* typeExpr)
+ {
+ return typeExpr;
+ }
+
+ virtual void VisitDeclGroup(DeclGroup* declGroup)
+ {
+ for (auto decl : declGroup->decls)
+ {
+ decl->Accept(this);
+ }
+ }
+
+ virtual RefPtr<ExpressionSyntaxNode> visitInitializerListExpr(InitializerListExpr* expr) = 0;
+ };
+
+ // Note(tfoley): These logically belong to `ExpressionType`,
+ // but order-of-declaration stuff makes that tricky
+ //
+ // TODO(tfoley): These should really belong to the compilation context!
+ //
+ void RegisterBuiltinDecl(
+ RefPtr<Decl> decl,
+ RefPtr<BuiltinTypeModifier> modifier);
+ void RegisterMagicDecl(
+ RefPtr<Decl> decl,
+ RefPtr<MagicTypeModifier> modifier);
+
+ // Look up a magic declaration by its name
+ RefPtr<Decl> findMagicDecl(
+ String const& name);
+
+ // Create an instance of a syntax class by name
+ SyntaxNodeBase* createInstanceOfSyntaxClassByName(
+ String const& name);
+
+ }
+}
+
+#endif \ No newline at end of file
diff --git a/source/slang/token-defs.h b/source/slang/token-defs.h
new file mode 100644
index 000000000..f29574bbb
--- /dev/null
+++ b/source/slang/token-defs.h
@@ -0,0 +1,93 @@
+// token-defs.h
+
+// This file is meant to be included multiple times, to produce different
+// pieces of code related to tokens
+//
+// Each token is declared here with:
+//
+// TOKEN(id, desc)
+//
+// where `id` is the identifier that will be used for the token in
+// ordinary code, while `desc` is name we should print when
+// referring to this token in diagnostic messages.
+
+
+#ifndef TOKEN
+#error Need to define TOKEN(ID, DESC) before including "token-defs.h"
+#endif
+
+TOKEN(Unknown, "<unknown>")
+TOKEN(EndOfFile, "end of file")
+TOKEN(EndOfDirective, "end of line")
+TOKEN(Invalid, "invalid character")
+TOKEN(Identifier, "identifier")
+TOKEN(IntLiterial, "integer literal")
+TOKEN(DoubleLiterial, "floating-point literal")
+TOKEN(StringLiterial, "string literal")
+TOKEN(CharLiterial, "character literal")
+TOKEN(WhiteSpace, "whitespace")
+TOKEN(NewLine, "newline")
+TOKEN(LineComment, "line comment")
+TOKEN(BlockComment, "block comment")
+
+#define PUNCTUATION(id, text) \
+ TOKEN(id, "'" text "'")
+
+PUNCTUATION(Semicolon, ";")
+PUNCTUATION(Comma, ",")
+PUNCTUATION(Dot, ".")
+
+PUNCTUATION(LBrace, "{")
+PUNCTUATION(RBrace, "}")
+PUNCTUATION(LBracket, "[")
+PUNCTUATION(RBracket, "]")
+PUNCTUATION(LParent, "(")
+PUNCTUATION(RParent, ")")
+
+PUNCTUATION(OpAssign, "=")
+PUNCTUATION(OpAdd, "+")
+PUNCTUATION(OpSub, "-")
+PUNCTUATION(OpMul, "*")
+PUNCTUATION(OpDiv, "/")
+PUNCTUATION(OpMod, "%")
+PUNCTUATION(OpNot, "!")
+PUNCTUATION(OpBitNot, "~")
+PUNCTUATION(OpLsh, "<<")
+PUNCTUATION(OpRsh, ">>")
+PUNCTUATION(OpEql, "==")
+PUNCTUATION(OpNeq, "!=")
+PUNCTUATION(OpGreater, ">")
+PUNCTUATION(OpLess, "<")
+PUNCTUATION(OpGeq, ">=")
+PUNCTUATION(OpLeq, "<=")
+PUNCTUATION(OpAnd, "&&")
+PUNCTUATION(OpOr, "||")
+PUNCTUATION(OpBitAnd, "&")
+PUNCTUATION(OpBitOr, "|")
+PUNCTUATION(OpBitXor, "^")
+PUNCTUATION(OpInc, "++")
+PUNCTUATION(OpDec, "--")
+
+PUNCTUATION(OpAddAssign, "+=")
+PUNCTUATION(OpSubAssign, "-=")
+PUNCTUATION(OpMulAssign, "*=")
+PUNCTUATION(OpDivAssign, "/=")
+PUNCTUATION(OpModAssign, "%=")
+PUNCTUATION(OpShlAssign, "<<=")
+PUNCTUATION(OpShrAssign, ">>=")
+PUNCTUATION(OpAndAssign, "&=")
+PUNCTUATION(OpOrAssign, "|=")
+PUNCTUATION(OpXorAssign, "^=")
+
+PUNCTUATION(QuestionMark, "?")
+PUNCTUATION(Colon, ":")
+PUNCTUATION(RightArrow, "->")
+PUNCTUATION(At, "@")
+PUNCTUATION(Dollar, "$")
+PUNCTUATION(Pound, "#")
+PUNCTUATION(PoundPound, "##")
+
+#undef PUNCTUATION
+
+// Un-define the `TOKEN` macro so that client doesn't have to
+#undef TOKEN
diff --git a/source/slang/token.cpp b/source/slang/token.cpp
new file mode 100644
index 000000000..436a3a740
--- /dev/null
+++ b/source/slang/token.cpp
@@ -0,0 +1,22 @@
+// token.cpp
+#include "token.h"
+
+#include <assert.h>
+
+namespace Slang {
+namespace Compiler {
+
+char const* TokenTypeToString(TokenType type)
+{
+ switch( type )
+ {
+ default:
+ assert(!"unexpected");
+ return "<uknown>";
+
+#define TOKEN(NAME, DESC) case TokenType::NAME: return DESC;
+#include "token-defs.h"
+ }
+}
+
+}}
diff --git a/source/slang/token.h b/source/slang/token.h
new file mode 100644
index 000000000..00a55feb1
--- /dev/null
+++ b/source/slang/token.h
@@ -0,0 +1,50 @@
+// token.h
+#ifndef SLANG_TOKEN_H_INCLUDED
+#define SLANG_TOKEN_H_INCLUDED
+
+#include "../core/basic.h"
+
+#include "source-loc.h"
+
+namespace Slang {
+namespace Compiler {
+
+using namespace CoreLib::Basic;
+
+enum class TokenType
+{
+#define TOKEN(NAME, DESC) NAME,
+#include "token-defs.h"
+};
+
+char const* TokenTypeToString(TokenType type);
+
+enum TokenFlag : unsigned int
+{
+ AtStartOfLine = 1 << 0,
+ AfterWhitespace = 1 << 1,
+};
+typedef unsigned int TokenFlags;
+
+class Token
+{
+public:
+ TokenType Type = TokenType::Unknown;
+ String Content;
+ CodePosition Position;
+ TokenFlags flags = 0;
+ Token() = default;
+ Token(TokenType type, const String & content, int line, int col, int pos, String fileName, TokenFlags flags = 0)
+ : flags(flags)
+ {
+ Type = type;
+ Content = content;
+ Position = CodePosition(line, col, pos, fileName);
+ }
+};
+
+
+
+}}
+
+#endif
diff --git a/source/slang/type-layout.cpp b/source/slang/type-layout.cpp
new file mode 100644
index 000000000..4e5c98ed5
--- /dev/null
+++ b/source/slang/type-layout.cpp
@@ -0,0 +1,1125 @@
+// TypeLayout.cpp
+#include "type-layout.h"
+
+#include "syntax.h"
+
+#include <assert.h>
+
+namespace Slang {
+namespace Compiler {
+
+size_t RoundToAlignment(size_t offset, size_t alignment)
+{
+ size_t remainder = offset % alignment;
+ if (remainder == 0)
+ return offset;
+ else
+ return offset + (alignment - remainder);
+}
+
+static size_t RoundUpToPowerOfTwo( size_t value )
+{
+ // TODO(tfoley): I know this isn't a fast approach
+ size_t result = 1;
+ while (result < value)
+ result *= 2;
+ return result;
+}
+
+struct DefaultLayoutRulesImpl : SimpleLayoutRulesImpl
+{
+ // Get size and alignment for a single value of base type.
+ SimpleLayoutInfo GetScalarLayout(BaseType baseType) override
+ {
+ switch (baseType)
+ {
+ case BaseType::Int:
+ case BaseType::UInt:
+ case BaseType::Float:
+ case BaseType::Bool:
+ return SimpleLayoutInfo( LayoutResourceKind::Uniform, 4, 4 );
+
+ default:
+ assert(!"unimplemented");
+ return SimpleLayoutInfo( LayoutResourceKind::Uniform, 0, 1 );
+ }
+ }
+
+ virtual SimpleLayoutInfo GetScalarLayout(slang::TypeReflection::ScalarType scalarType)
+ {
+ switch( scalarType )
+ {
+ case slang::TypeReflection::ScalarType::Void: return SimpleLayoutInfo();
+ case slang::TypeReflection::ScalarType::None: return SimpleLayoutInfo();
+
+ // TODO(tfoley): At some point we don't want to lay out `bool` as 4 bytes by default...
+ case slang::TypeReflection::ScalarType::Bool: return SimpleLayoutInfo( LayoutResourceKind::Uniform, 4,4);
+ case slang::TypeReflection::ScalarType::Int32: return SimpleLayoutInfo( LayoutResourceKind::Uniform, 4,4);
+ case slang::TypeReflection::ScalarType::UInt32: return SimpleLayoutInfo( LayoutResourceKind::Uniform, 4,4);
+ case slang::TypeReflection::ScalarType::Int64: return SimpleLayoutInfo( LayoutResourceKind::Uniform, 8,8);
+ case slang::TypeReflection::ScalarType::UInt64: return SimpleLayoutInfo( LayoutResourceKind::Uniform, 8,8);
+
+ // TODO(tfoley): What actually happens if you use `half` in a constant buffer?
+ case slang::TypeReflection::ScalarType::Float16: return SimpleLayoutInfo( LayoutResourceKind::Uniform, 2,2);
+ case slang::TypeReflection::ScalarType::Float32: return SimpleLayoutInfo( LayoutResourceKind::Uniform, 4,4);
+ case slang::TypeReflection::ScalarType::Float64: return SimpleLayoutInfo( LayoutResourceKind::Uniform, 8,8);
+
+ default:
+ assert(!"unimplemented");
+ return SimpleLayoutInfo();
+ }
+ }
+
+ SimpleArrayLayoutInfo GetArrayLayout( SimpleLayoutInfo elementInfo, size_t elementCount) override
+ {
+ size_t stride = elementInfo.size;
+
+ SimpleArrayLayoutInfo arrayInfo;
+ arrayInfo.kind = elementInfo.kind;
+ arrayInfo.size = stride * elementCount;
+ arrayInfo.alignment = elementInfo.alignment;
+ arrayInfo.elementStride = stride;
+ return arrayInfo;
+ }
+
+ SimpleLayoutInfo GetVectorLayout(SimpleLayoutInfo elementInfo, size_t elementCount) override
+ {
+ SimpleLayoutInfo vectorInfo;
+ vectorInfo.kind = elementInfo.kind;
+ vectorInfo.size = elementInfo.size * elementCount;
+ vectorInfo.alignment = elementInfo.alignment;
+ return vectorInfo;
+ }
+
+ SimpleLayoutInfo GetMatrixLayout(SimpleLayoutInfo elementInfo, size_t rowCount, size_t columnCount) override
+ {
+ return GetArrayLayout(
+ GetVectorLayout(elementInfo, columnCount),
+ rowCount);
+ }
+
+ UniformLayoutInfo BeginStructLayout() override
+ {
+ UniformLayoutInfo structInfo(0, 1);
+ return structInfo;
+ }
+
+ size_t AddStructField(UniformLayoutInfo* ioStructInfo, UniformLayoutInfo fieldInfo) override
+ {
+ // Skip zero-size fields
+ if(fieldInfo.size == 0)
+ return ioStructInfo->size;
+
+ ioStructInfo->alignment = std::max(ioStructInfo->alignment, fieldInfo.alignment);
+ ioStructInfo->size = RoundToAlignment(ioStructInfo->size, fieldInfo.alignment);
+ size_t fieldOffset = ioStructInfo->size;
+ ioStructInfo->size += fieldInfo.size;
+ return fieldOffset;
+ }
+
+
+ void EndStructLayout(UniformLayoutInfo* ioStructInfo) override
+ {
+ ioStructInfo->size = RoundToAlignment(ioStructInfo->size, ioStructInfo->alignment);
+ }
+};
+
+// Capture common behavior betwen HLSL and GLSL (`std140`) constnat buffer rules
+struct DefaultConstantBufferLayoutRulesImpl : DefaultLayoutRulesImpl
+{
+ // The `std140` rules require that all array elements
+ // be a multiple of 16 bytes.
+ //
+ // HLSL agrees.
+ SimpleArrayLayoutInfo GetArrayLayout(SimpleLayoutInfo elementInfo, size_t elementCount) override
+ {
+ if(elementInfo.kind == LayoutResourceKind::Uniform)
+ {
+ if (elementInfo.alignment < 16)
+ elementInfo.alignment = 16;
+ elementInfo.size = RoundToAlignment(elementInfo.size, elementInfo.alignment);
+ }
+ return DefaultLayoutRulesImpl::GetArrayLayout(elementInfo, elementCount);
+ }
+
+ // The `std140` rules require that a `struct` type be
+ // aligned to at least 16.
+ //
+ // HLSL agrees.
+ UniformLayoutInfo BeginStructLayout() override
+ {
+ return UniformLayoutInfo(0, 16);
+ }
+};
+
+struct GLSLConstantBufferLayoutRulesImpl : DefaultConstantBufferLayoutRulesImpl
+{
+};
+
+struct Std140LayoutRulesImpl : GLSLConstantBufferLayoutRulesImpl
+{
+ // The `std140` rules require vectors to be aligned to the next power of two
+ // up from their size (so a `float2` is 8-byte aligned, and a `float3` is
+ // 16-byte aligned).
+ SimpleLayoutInfo GetVectorLayout(SimpleLayoutInfo elementInfo, size_t elementCount) override
+ {
+ assert(elementInfo.kind == LayoutResourceKind::Uniform);
+ SimpleLayoutInfo vectorInfo(
+ LayoutResourceKind::Uniform,
+ elementInfo.size * elementCount,
+ RoundUpToPowerOfTwo(elementInfo.size * elementInfo.alignment));
+ return vectorInfo;
+ }
+};
+
+struct HLSLConstantBufferLayoutRulesImpl : DefaultConstantBufferLayoutRulesImpl
+{
+ // Can't let a `struct` field straddle a register (16-byte) boundary
+ size_t AddStructField(UniformLayoutInfo* ioStructInfo, UniformLayoutInfo fieldInfo) override
+ {
+ // Skip zero-size fields
+ if(fieldInfo.size == 0)
+ return ioStructInfo->size;
+
+ ioStructInfo->alignment = std::max(ioStructInfo->alignment, fieldInfo.alignment);
+ ioStructInfo->size = RoundToAlignment(ioStructInfo->size, fieldInfo.alignment);
+
+ size_t fieldOffset = ioStructInfo->size;
+ size_t fieldSize = fieldInfo.size;
+
+ // Would this field cross a 16-byte boundary?
+ auto registerSize = 16;
+ auto startRegister = fieldOffset / registerSize;
+ auto endRegister = (fieldOffset + fieldSize - 1) / registerSize;
+ if (startRegister != endRegister)
+ {
+ ioStructInfo->size = RoundToAlignment(ioStructInfo->size, size_t(registerSize));
+ fieldOffset = ioStructInfo->size;
+ }
+
+ ioStructInfo->size += fieldInfo.size;
+ return fieldOffset;
+ }
+};
+
+struct HLSLStructuredBufferLayoutRulesImpl : DefaultLayoutRulesImpl
+{
+ // TODO: customize these to be correct...
+};
+
+struct Std430LayoutRulesImpl : GLSLConstantBufferLayoutRulesImpl
+{
+};
+
+struct DefaultVaryingLayoutRulesImpl : DefaultLayoutRulesImpl
+{
+ LayoutResourceKind kind;
+
+ DefaultVaryingLayoutRulesImpl(LayoutResourceKind kind)
+ : kind(kind)
+ {}
+
+
+ // hook to allow differentiating for input/output
+ virtual LayoutResourceKind getKind()
+ {
+ return kind;
+ }
+
+ SimpleLayoutInfo GetScalarLayout(BaseType baseType) override
+ {
+ // Assume that all scalars take up one "slot"
+ return SimpleLayoutInfo(
+ getKind(),
+ 1);
+ }
+
+ virtual SimpleLayoutInfo GetScalarLayout(slang::TypeReflection::ScalarType scalarType)
+ {
+ // Assume that all scalars take up one "slot"
+ return SimpleLayoutInfo(
+ getKind(),
+ 1);
+ }
+
+ SimpleLayoutInfo GetVectorLayout(SimpleLayoutInfo elementInfo, size_t elementCount) override
+ {
+ // Vectors take up one slot by default
+ //
+ // TODO: some platforms may decide that vectors of `double` need
+ // special handling
+ return SimpleLayoutInfo(
+ getKind(),
+ 1);
+ }
+};
+
+struct GLSLVaryingLayoutRulesImpl : DefaultVaryingLayoutRulesImpl
+{
+ GLSLVaryingLayoutRulesImpl(LayoutResourceKind kind)
+ : DefaultVaryingLayoutRulesImpl(kind)
+ {}
+};
+
+struct HLSLVaryingLayoutRulesImpl : DefaultVaryingLayoutRulesImpl
+{
+ HLSLVaryingLayoutRulesImpl(LayoutResourceKind kind)
+ : DefaultVaryingLayoutRulesImpl(kind)
+ {}
+};
+
+//
+
+struct GLSLSpecializationConstantLayoutRulesImpl : DefaultLayoutRulesImpl
+{
+ LayoutResourceKind getKind()
+ {
+ return LayoutResourceKind::SpecializationConstant;
+ }
+
+ SimpleLayoutInfo GetScalarLayout(BaseType baseType) override
+ {
+ // Assume that all scalars take up one "slot"
+ return SimpleLayoutInfo(
+ getKind(),
+ 1);
+ }
+
+ virtual SimpleLayoutInfo GetScalarLayout(slang::TypeReflection::ScalarType scalarType)
+ {
+ // Assume that all scalars take up one "slot"
+ return SimpleLayoutInfo(
+ getKind(),
+ 1);
+ }
+
+ SimpleLayoutInfo GetVectorLayout(SimpleLayoutInfo elementInfo, size_t elementCount) override
+ {
+ // GLSL doesn't support vectors of specialization constants,
+ // but we will assume that, if supported, they would use one slot per element.
+ return SimpleLayoutInfo(
+ getKind(),
+ elementCount);
+ }
+};
+
+GLSLSpecializationConstantLayoutRulesImpl kGLSLSpecializationConstantLayoutRulesImpl;
+
+//
+
+struct GLSLObjectLayoutRulesImpl : ObjectLayoutRulesImpl
+{
+ virtual SimpleLayoutInfo GetObjectLayout(ShaderParameterKind kind) override
+ {
+ // In Vulkan GLSL, pretty much every object is just a descriptor-table slot.
+ // We can refine this method once we support a case where this isn't true.
+ return SimpleLayoutInfo(LayoutResourceKind::DescriptorTableSlot, 1);
+ }
+};
+GLSLObjectLayoutRulesImpl kGLSLObjectLayoutRulesImpl;
+
+struct HLSLObjectLayoutRulesImpl : ObjectLayoutRulesImpl
+{
+ virtual SimpleLayoutInfo GetObjectLayout(ShaderParameterKind kind) override
+ {
+ switch( kind )
+ {
+ case ShaderParameterKind::ConstantBuffer:
+ return SimpleLayoutInfo(LayoutResourceKind::ConstantBuffer, 1);
+
+ case ShaderParameterKind::TextureUniformBuffer:
+ case ShaderParameterKind::StructuredBuffer:
+ case ShaderParameterKind::SampledBuffer:
+ case ShaderParameterKind::RawBuffer:
+ case ShaderParameterKind::Buffer:
+ case ShaderParameterKind::Texture:
+ return SimpleLayoutInfo(LayoutResourceKind::ShaderResource, 1);
+
+ case ShaderParameterKind::MutableStructuredBuffer:
+ case ShaderParameterKind::MutableSampledBuffer:
+ case ShaderParameterKind::MutableRawBuffer:
+ case ShaderParameterKind::MutableBuffer:
+ case ShaderParameterKind::MutableTexture:
+ return SimpleLayoutInfo(LayoutResourceKind::UnorderedAccess, 1);
+
+ case ShaderParameterKind::SamplerState:
+ return SimpleLayoutInfo(LayoutResourceKind::SamplerState, 1);
+
+ case ShaderParameterKind::TextureSampler:
+ case ShaderParameterKind::MutableTextureSampler:
+ case ShaderParameterKind::InputRenderTarget:
+ // TODO: how to handle these?
+ default:
+ assert(!"unimplemented");
+ return SimpleLayoutInfo();
+ }
+ }
+};
+HLSLObjectLayoutRulesImpl kHLSLObjectLayoutRulesImpl;
+
+Std140LayoutRulesImpl kStd140LayoutRulesImpl;
+Std430LayoutRulesImpl kStd430LayoutRulesImpl;
+HLSLConstantBufferLayoutRulesImpl kHLSLConstantBufferLayoutRulesImpl;
+HLSLStructuredBufferLayoutRulesImpl kHLSLStructuredBufferLayoutRulesImpl;
+
+GLSLVaryingLayoutRulesImpl kGLSLVaryingInputLayoutRulesImpl(LayoutResourceKind::VertexInput);
+GLSLVaryingLayoutRulesImpl kGLSLVaryingOutputLayoutRulesImpl(LayoutResourceKind::FragmentOutput);
+
+HLSLVaryingLayoutRulesImpl kHLSLVaryingInputLayoutRulesImpl(LayoutResourceKind::VertexInput);
+HLSLVaryingLayoutRulesImpl kHLSLVaryingOutputLayoutRulesImpl(LayoutResourceKind::FragmentOutput);
+
+//
+
+struct GLSLLayoutRulesFamilyImpl : LayoutRulesFamilyImpl
+{
+ virtual LayoutRulesImpl* getConstantBufferRules() override;
+ virtual LayoutRulesImpl* getTextureBufferRules() override;
+ virtual LayoutRulesImpl* getVaryingInputRules() override;
+ virtual LayoutRulesImpl* getVaryingOutputRules() override;
+ virtual LayoutRulesImpl* getSpecializationConstantRules() override;
+ virtual LayoutRulesImpl* getShaderStorageBufferRules() override;
+};
+
+struct HLSLLayoutRulesFamilyImpl : LayoutRulesFamilyImpl
+{
+ virtual LayoutRulesImpl* getConstantBufferRules() override;
+ virtual LayoutRulesImpl* getTextureBufferRules() override;
+ virtual LayoutRulesImpl* getVaryingInputRules() override;
+ virtual LayoutRulesImpl* getVaryingOutputRules() override;
+ virtual LayoutRulesImpl* getSpecializationConstantRules() override;
+ virtual LayoutRulesImpl* getShaderStorageBufferRules() override;
+};
+
+GLSLLayoutRulesFamilyImpl kGLSLLayoutRulesFamilyImpl;
+HLSLLayoutRulesFamilyImpl kHLSLLayoutRulesFamilyImpl;
+
+
+// GLSL cases
+
+LayoutRulesImpl kStd140LayoutRulesImpl_ = {
+ &kGLSLLayoutRulesFamilyImpl, &kStd140LayoutRulesImpl, &kGLSLObjectLayoutRulesImpl,
+};
+
+LayoutRulesImpl kStd430LayoutRulesImpl_ = {
+ &kGLSLLayoutRulesFamilyImpl, &kStd430LayoutRulesImpl, &kGLSLObjectLayoutRulesImpl,
+};
+
+LayoutRulesImpl kGLSLVaryingInputLayoutRulesImpl_ = {
+ &kGLSLLayoutRulesFamilyImpl, &kGLSLVaryingInputLayoutRulesImpl, &kGLSLObjectLayoutRulesImpl,
+};
+
+LayoutRulesImpl kGLSLVaryingOutputLayoutRulesImpl_ = {
+ &kGLSLLayoutRulesFamilyImpl, &kGLSLVaryingOutputLayoutRulesImpl, &kGLSLObjectLayoutRulesImpl,
+};
+
+LayoutRulesImpl kGLSLSpecializationConstantLayoutRulesImpl_ = {
+ &kGLSLLayoutRulesFamilyImpl, &kGLSLSpecializationConstantLayoutRulesImpl, &kGLSLObjectLayoutRulesImpl,
+};
+
+// HLSL cases
+
+LayoutRulesImpl kHLSLConstantBufferLayoutRulesImpl_ = {
+ &kHLSLLayoutRulesFamilyImpl, &kHLSLConstantBufferLayoutRulesImpl, &kHLSLObjectLayoutRulesImpl,
+};
+
+LayoutRulesImpl kHLSLStructuredBufferLayoutRulesImpl_ = {
+ &kHLSLLayoutRulesFamilyImpl, &kHLSLStructuredBufferLayoutRulesImpl, &kHLSLObjectLayoutRulesImpl,
+};
+
+LayoutRulesImpl kHLSLVaryingInputLayoutRulesImpl_ = {
+ &kHLSLLayoutRulesFamilyImpl, &kHLSLVaryingInputLayoutRulesImpl, &kHLSLObjectLayoutRulesImpl,
+};
+
+LayoutRulesImpl kHLSLVaryingOutputLayoutRulesImpl_ = {
+ &kHLSLLayoutRulesFamilyImpl, &kHLSLVaryingOutputLayoutRulesImpl, &kHLSLObjectLayoutRulesImpl,
+};
+
+//
+
+LayoutRulesImpl* GLSLLayoutRulesFamilyImpl::getConstantBufferRules()
+{
+ return &kStd140LayoutRulesImpl_;
+}
+
+LayoutRulesImpl* GLSLLayoutRulesFamilyImpl::getTextureBufferRules()
+{
+ return nullptr;
+}
+
+LayoutRulesImpl* GLSLLayoutRulesFamilyImpl::getVaryingInputRules()
+{
+ return &kGLSLVaryingInputLayoutRulesImpl_;
+}
+
+LayoutRulesImpl* GLSLLayoutRulesFamilyImpl::getVaryingOutputRules()
+{
+ return &kGLSLVaryingOutputLayoutRulesImpl_;
+}
+
+LayoutRulesImpl* GLSLLayoutRulesFamilyImpl::getSpecializationConstantRules()
+{
+ return &kGLSLSpecializationConstantLayoutRulesImpl_;
+}
+
+LayoutRulesImpl* GLSLLayoutRulesFamilyImpl::getShaderStorageBufferRules()
+{
+ return &kStd430LayoutRulesImpl_;
+}
+
+//
+
+LayoutRulesImpl* HLSLLayoutRulesFamilyImpl::getConstantBufferRules()
+{
+ return &kHLSLConstantBufferLayoutRulesImpl_;
+}
+
+LayoutRulesImpl* HLSLLayoutRulesFamilyImpl::getTextureBufferRules()
+{
+ return nullptr;
+}
+
+LayoutRulesImpl* HLSLLayoutRulesFamilyImpl::getVaryingInputRules()
+{
+ return &kHLSLVaryingInputLayoutRulesImpl_;
+}
+
+LayoutRulesImpl* HLSLLayoutRulesFamilyImpl::getVaryingOutputRules()
+{
+ return &kHLSLVaryingOutputLayoutRulesImpl_;
+}
+
+LayoutRulesImpl* HLSLLayoutRulesFamilyImpl::getSpecializationConstantRules()
+{
+ return nullptr;
+}
+
+LayoutRulesImpl* HLSLLayoutRulesFamilyImpl::getShaderStorageBufferRules()
+{
+ return nullptr;
+}
+
+//
+
+LayoutRulesImpl* GetLayoutRulesImpl(LayoutRule rule)
+{
+ switch (rule)
+ {
+ case LayoutRule::Std140: return &kStd140LayoutRulesImpl_;
+ case LayoutRule::Std430: return &kStd430LayoutRulesImpl_;
+ case LayoutRule::HLSLConstantBuffer: return &kHLSLConstantBufferLayoutRulesImpl_;
+ case LayoutRule::HLSLStructuredBuffer: return &kHLSLStructuredBufferLayoutRulesImpl_;
+ default:
+ return nullptr;
+ }
+}
+
+LayoutRulesFamilyImpl* GetLayoutRulesFamilyImpl(LayoutRulesFamily rule)
+{
+ switch (rule)
+ {
+ case LayoutRulesFamily::HLSL: return &kHLSLLayoutRulesFamilyImpl;
+ case LayoutRulesFamily::GLSL: return &kGLSLLayoutRulesFamilyImpl;
+ default:
+ return nullptr;
+ }
+}
+
+LayoutRulesFamilyImpl* GetLayoutRulesFamilyImpl(SourceLanguage language)
+{
+ switch (language)
+ {
+ case SourceLanguage::Slang:
+ case SourceLanguage::HLSL:
+ return &kHLSLLayoutRulesFamilyImpl;
+
+ case SourceLanguage::GLSL:
+ return &kGLSLLayoutRulesFamilyImpl;
+
+ default:
+ return nullptr;
+ }
+}
+
+
+static int GetElementCount(RefPtr<IntVal> val)
+{
+ if (auto constantVal = val.As<ConstantIntVal>())
+ {
+ return constantVal->value;
+ }
+ else if( auto varRefVal = val.As<GenericParamIntVal>() )
+ {
+ // TODO(tfoley): do something sensible in this case
+ return 0;
+ }
+ assert(!"unexpected");
+ return 0;
+}
+
+bool IsResourceKind(LayoutResourceKind kind)
+{
+ switch (kind)
+ {
+ case LayoutResourceKind::None:
+ case LayoutResourceKind::Uniform:
+ return false;
+
+ default:
+ return true;
+ }
+
+}
+
+SimpleLayoutInfo GetSimpleLayoutImpl(
+ SimpleLayoutInfo info,
+ RefPtr<ExpressionType> type,
+ LayoutRulesImpl* rules,
+ RefPtr<TypeLayout>* outTypeLayout)
+{
+ if (outTypeLayout)
+ {
+ RefPtr<TypeLayout> typeLayout = new TypeLayout();
+ *outTypeLayout = typeLayout;
+
+ typeLayout->type = type;
+ typeLayout->rules = rules;
+
+ typeLayout->uniformAlignment = info.alignment;
+
+ typeLayout->addResourceUsage(info.kind, info.size);
+ }
+
+ return info;
+}
+
+static SimpleLayoutInfo getParameterBlockLayoutInfo(
+ RefPtr<ParameterBlockType> type,
+ LayoutRulesImpl* rules)
+{
+ if( type->As<ConstantBufferType>() )
+ {
+ return rules->GetObjectLayout(ShaderParameterKind::ConstantBuffer);
+ }
+ else if( type->As<TextureBufferType>() )
+ {
+ return rules->GetObjectLayout(ShaderParameterKind::TextureUniformBuffer);
+ }
+ else if( type->As<GLSLShaderStorageBufferType>() )
+ {
+ return rules->GetObjectLayout(ShaderParameterKind::ShaderStorageBuffer);
+ }
+ // TODO: the vertex-input and fragment-output cases should
+ // only actually apply when we are at the appropriate stage in
+ // the pipeline...
+ else if( type->As<GLSLInputParameterBlockType>() )
+ {
+ return SimpleLayoutInfo(LayoutResourceKind::VertexInput, 0);
+ }
+ else if( type->As<GLSLOutputParameterBlockType>() )
+ {
+ return SimpleLayoutInfo(LayoutResourceKind::FragmentOutput, 0);
+ }
+ else
+ {
+ assert(!"unexpected");
+ return SimpleLayoutInfo();
+ }
+}
+
+
+RefPtr<ParameterBlockTypeLayout>
+createParameterBlockTypeLayout(
+ RefPtr<ParameterBlockType> parameterBlockType,
+ RefPtr<TypeLayout> elementTypeLayout,
+ LayoutRulesImpl* rules)
+{
+ auto info = getParameterBlockLayoutInfo(
+ parameterBlockType,
+ rules);
+
+ auto typeLayout = new ParameterBlockTypeLayout();
+
+ typeLayout->type = parameterBlockType;
+ typeLayout->rules = rules;
+
+ typeLayout->elementTypeLayout = elementTypeLayout;
+
+ // The layout of the constant buffer if it gets stored
+ // in another constant buffer is just what we computed
+ // originally (which should be a single binding "slot"
+ // and hence no uniform data).
+ //
+ typeLayout->uniformAlignment = info.alignment;
+ assert(!typeLayout->FindResourceInfo(LayoutResourceKind::Uniform));
+ assert(typeLayout->uniformAlignment == 1);
+
+ // TODO(tfoley): There is a subtle question here of whether
+ // a constant buffer declaration that then contains zero
+ // bytes of uniform data should actually allocate a CB
+ // binding slot. For now I'm going to try to ignore it,
+ // but handling this robustly could let other code
+ // simply handle the "global scope" as a giant outer
+ // CB declaration...
+
+ // Make sure that we allocate resource usage for the
+ // parameter block itself.
+ if( info.size )
+ {
+ typeLayout->addResourceUsage(
+ info.kind,
+ info.size);
+ }
+
+ // Now, if the element type itself had any resources, then
+ // we need to make these part of the layout for our block
+ //
+ // TODO: re-consider this decision, since it creates
+ // complications...
+ for( auto elementResourceInfo : elementTypeLayout->resourceInfos )
+ {
+ // Skip uniform data, since that is encapsualted behind the constant buffer
+ if(elementResourceInfo.kind == LayoutResourceKind::Uniform)
+ break;
+
+ typeLayout->addResourceUsage(elementResourceInfo);
+ }
+
+ return typeLayout;
+}
+
+LayoutRulesImpl* getParameterBufferElementTypeLayoutRules(
+ RefPtr<ParameterBlockType> parameterBlockType,
+ LayoutRulesImpl* rules)
+{
+ if( parameterBlockType->As<ConstantBufferType>() )
+ {
+ return rules->getLayoutRulesFamily()->getConstantBufferRules();
+ }
+ else if( parameterBlockType->As<TextureBufferType>() )
+ {
+ return rules->getLayoutRulesFamily()->getTextureBufferRules();
+ }
+ else if( parameterBlockType->As<GLSLInputParameterBlockType>() )
+ {
+ return rules->getLayoutRulesFamily()->getVaryingInputRules();
+ }
+ else if( parameterBlockType->As<GLSLOutputParameterBlockType>() )
+ {
+ return rules->getLayoutRulesFamily()->getVaryingOutputRules();
+ }
+ else if( parameterBlockType->As<GLSLShaderStorageBufferType>() )
+ {
+ return rules->getLayoutRulesFamily()->getShaderStorageBufferRules();
+ }
+ else
+ {
+ assert(!"unexpected");
+ return nullptr;
+ }
+}
+
+RefPtr<ParameterBlockTypeLayout>
+createParameterBlockTypeLayout(
+ RefPtr<ParameterBlockType> parameterBlockType,
+ LayoutRulesImpl* rules)
+{
+ // Determine the layout rules to use for the contents of the block
+ auto parameterBlockLayoutRules = getParameterBufferElementTypeLayoutRules(
+ parameterBlockType,
+ rules);
+
+ // Create and save type layout for the buffer contents.
+ auto elementTypeLayout = CreateTypeLayout(
+ parameterBlockType->elementType.Ptr(),
+ parameterBlockLayoutRules);
+
+ return createParameterBlockTypeLayout(
+ parameterBlockType,
+ elementTypeLayout,
+ rules);
+}
+
+// Create a type layout for a structured buffer type.
+RefPtr<StructuredBufferTypeLayout>
+createStructuredBufferTypeLayout(
+ ShaderParameterKind kind,
+ RefPtr<ExpressionType> structuredBufferType,
+ RefPtr<TypeLayout> elementTypeLayout,
+ LayoutRulesImpl* rules)
+{
+ auto info = rules->GetObjectLayout(kind);
+
+ auto typeLayout = new StructuredBufferTypeLayout();
+
+ typeLayout->type = structuredBufferType;
+ typeLayout->rules = rules;
+
+ typeLayout->elementTypeLayout = elementTypeLayout;
+
+ typeLayout->uniformAlignment = info.alignment;
+ assert(!typeLayout->FindResourceInfo(LayoutResourceKind::Uniform));
+ assert(typeLayout->uniformAlignment == 1);
+
+ if( info.size != 0 )
+ {
+ typeLayout->addResourceUsage(info.kind, info.size);
+ }
+
+ // Note: for now we don't deal with the case of a structured
+ // buffer that might contain anything other than "uniform" data,
+ // because there really isn't a way to implement that.
+
+ return typeLayout;
+}
+
+// Create a type layout for a structured buffer type.
+RefPtr<StructuredBufferTypeLayout>
+createStructuredBufferTypeLayout(
+ ShaderParameterKind kind,
+ RefPtr<ExpressionType> structuredBufferType,
+ RefPtr<ExpressionType> elementType,
+ LayoutRulesImpl* rules)
+{
+ // TODO(tfoley): need to compute the layout for the constant
+ // buffer's contents...
+ auto structuredBufferLayoutRules = GetLayoutRulesImpl(
+ LayoutRule::HLSLStructuredBuffer);
+
+ // Create and save type layout for the buffer contents.
+ auto elementTypeLayout = CreateTypeLayout(
+ elementType.Ptr(),
+ structuredBufferLayoutRules);
+
+ return createStructuredBufferTypeLayout(
+ kind,
+ structuredBufferType,
+ elementTypeLayout,
+ rules);
+
+}
+
+SimpleLayoutInfo GetLayoutImpl(
+ ExpressionType* type,
+ LayoutRulesImpl* rules,
+ RefPtr<TypeLayout>* outTypeLayout)
+{
+ if (auto parameterBlockType = type->As<ParameterBlockType>())
+ {
+ // If the user is just interested in uniform layout info,
+ // then this is easy: a `ConstantBuffer<T>` is really no
+ // different from a `Texture2D<U>` in terms of how it
+ // should be handled as a member of a container.
+ //
+ auto info = getParameterBlockLayoutInfo(parameterBlockType, rules);
+
+ // The more interesting case, though, is when the user
+ // is requesting us to actually create a `TypeLayout`,
+ // since in that case we need to:
+ //
+ // 1. Compute a layout for the data inside the constant
+ // buffer, including offsets, etc.
+ //
+ // 2. Compute information about any object types inside
+ // the constant buffer, which need to be surfaces out
+ // to the top level.
+ //
+ if (outTypeLayout)
+ {
+ *outTypeLayout = createParameterBlockTypeLayout(
+ parameterBlockType,
+ rules);
+ }
+
+ return info;
+ }
+ else if (auto samplerStateType = type->As<SamplerStateType>())
+ {
+ return GetSimpleLayoutImpl(
+ rules->GetObjectLayout(ShaderParameterKind::SamplerState),
+ type,
+ rules,
+ outTypeLayout);
+ }
+ else if (auto textureType = type->As<TextureType>())
+ {
+ // TODO: the logic here should really be defined by the rules,
+ // and not at this top level...
+ ShaderParameterKind kind;
+ switch( textureType->getAccess() )
+ {
+ default:
+ kind = ShaderParameterKind::MutableTexture;
+ break;
+
+ case SLANG_RESOURCE_ACCESS_READ:
+ kind = ShaderParameterKind::Texture;
+ break;
+ }
+
+ return GetSimpleLayoutImpl(
+ rules->GetObjectLayout(kind),
+ type,
+ rules,
+ outTypeLayout);
+ }
+ else if (auto textureSamplerType = type->As<TextureSamplerType>())
+ {
+ // TODO: the logic here should really be defined by the rules,
+ // and not at this top level...
+ ShaderParameterKind kind;
+ switch( textureSamplerType->getAccess() )
+ {
+ default:
+ kind = ShaderParameterKind::MutableTextureSampler;
+ break;
+
+ case SLANG_RESOURCE_ACCESS_READ:
+ kind = ShaderParameterKind::TextureSampler;
+ break;
+ }
+
+ return GetSimpleLayoutImpl(
+ rules->GetObjectLayout(kind),
+ type,
+ rules,
+ outTypeLayout);
+ }
+
+ // TODO: need a better way to handle this stuff...
+#define CASE(TYPE, KIND) \
+ else if(auto type_##TYPE = type->As<TYPE>()) do { \
+ auto info = rules->GetObjectLayout(ShaderParameterKind::KIND); \
+ if (outTypeLayout) \
+ { \
+ *outTypeLayout = createStructuredBufferTypeLayout( \
+ ShaderParameterKind::KIND, \
+ type_##TYPE, \
+ type_##TYPE->elementType.Ptr(), \
+ rules); \
+ } \
+ return info; \
+ } while(0)
+
+ CASE(HLSLStructuredBufferType, StructuredBuffer);
+ CASE(HLSLRWStructuredBufferType, MutableStructuredBuffer);
+ CASE(HLSLAppendStructuredBufferType, MutableStructuredBuffer);
+ CASE(HLSLConsumeStructuredBufferType, MutableStructuredBuffer);
+
+#undef CASE
+
+
+ // TODO: need a better way to handle this stuff...
+#define CASE(TYPE, KIND) \
+ else if(type->As<TYPE>()) do { \
+ return GetSimpleLayoutImpl( \
+ rules->GetObjectLayout(ShaderParameterKind::KIND), \
+ type, rules, outTypeLayout); \
+ } while(0)
+
+ CASE(HLSLBufferType, SampledBuffer);
+ CASE(HLSLRWBufferType, MutableSampledBuffer);
+ CASE(HLSLByteAddressBufferType, RawBuffer);
+ CASE(HLSLRWByteAddressBufferType, MutableRawBuffer);
+
+ CASE(GLSLInputAttachmentType, InputRenderTarget);
+
+ // This case is mostly to allow users to add new resource types...
+ CASE(UntypedBufferResourceType, RawBuffer);
+
+#undef CASE
+
+ //
+ // TODO(tfoley): Need to recognize any UAV types here
+ //
+ else if(auto basicType = type->As<BasicExpressionType>())
+ {
+ return GetSimpleLayoutImpl(
+ rules->GetScalarLayout(basicType->BaseType),
+ type,
+ rules,
+ outTypeLayout);
+ }
+ else if(auto vecType = type->As<VectorExpressionType>())
+ {
+ return GetSimpleLayoutImpl(
+ rules->GetVectorLayout(
+ GetLayout(vecType->elementType.Ptr(), rules),
+ GetIntVal(vecType->elementCount)),
+ type,
+ rules,
+ outTypeLayout);
+ }
+ else if(auto matType = type->As<MatrixExpressionType>())
+ {
+ return GetSimpleLayoutImpl(
+ rules->GetMatrixLayout(
+ GetLayout(matType->getElementType(), rules),
+ GetIntVal(matType->getRowCount()),
+ GetIntVal(matType->getColumnCount())),
+ type,
+ rules,
+ outTypeLayout);
+ }
+ else if (auto arrayType = type->As<ArrayExpressionType>())
+ {
+ RefPtr<TypeLayout> elementTypeLayout;
+ auto elementInfo = GetLayoutImpl(
+ arrayType->BaseType.Ptr(),
+ rules,
+ outTypeLayout ? &elementTypeLayout : nullptr);
+
+ // For layout purposes, we treat an unsized array as an array of zero elements.
+ //
+ // TODO: Longer term we are going to need to be careful to include some indication
+ // that a type has logically "infinite" size in some resource kind. In particular
+ // this affects how we would allocate space for parameter binding purposes.
+ auto elementCount = arrayType->ArrayLength ? GetElementCount(arrayType->ArrayLength) : 0;
+ auto arrayUniformInfo = rules->GetArrayLayout(
+ elementInfo,
+ elementCount).getUniformLayout();
+
+ if (outTypeLayout)
+ {
+ RefPtr<ArrayTypeLayout> typeLayout = new ArrayTypeLayout();
+ *outTypeLayout = typeLayout;
+
+ typeLayout->type = type;
+ typeLayout->elementTypeLayout = elementTypeLayout;
+ typeLayout->rules = rules;
+
+ typeLayout->uniformAlignment = arrayUniformInfo.alignment;
+ typeLayout->uniformStride = arrayUniformInfo.elementStride;
+
+ typeLayout->addResourceUsage(LayoutResourceKind::Uniform, arrayUniformInfo.size);
+
+ // translate element-type resources into array-type resources
+ for( auto elementResourceInfo : elementTypeLayout->resourceInfos )
+ {
+ // The uniform case was already handled above
+ if( elementResourceInfo.kind == LayoutResourceKind::Uniform )
+ continue;
+
+ typeLayout->addResourceUsage(
+ elementResourceInfo.kind,
+ elementResourceInfo.count * elementCount);
+ }
+ }
+ return arrayUniformInfo;
+ }
+ else if (auto declRefType = type->As<DeclRefType>())
+ {
+ auto declRef = declRefType->declRef;
+
+ if (auto structDeclRef = declRef.As<StructDeclRef>())
+ {
+ RefPtr<StructTypeLayout> typeLayout;
+ if (outTypeLayout)
+ {
+ typeLayout = new StructTypeLayout();
+ typeLayout->type = type;
+ typeLayout->rules = rules;
+ *outTypeLayout = typeLayout;
+ }
+
+ UniformLayoutInfo info = rules->BeginStructLayout();
+
+ for (auto field : structDeclRef.GetFields())
+ {
+ RefPtr<TypeLayout> fieldTypeLayout;
+ UniformLayoutInfo fieldInfo = GetLayoutImpl(
+ field.GetType().Ptr(),
+ rules,
+ outTypeLayout ? &fieldTypeLayout : nullptr).getUniformLayout();
+
+ // Note: we don't add any zero-size fields
+ // when computing structure layout, just
+ // to avoid having a resource type impact
+ // the final layout.
+ //
+ // This means that the code to generate final
+ // declarations needs to *also* eliminate zero-size
+ // fields to be safe...
+ size_t uniformOffset = info.size;
+ if(fieldInfo.size != 0)
+ {
+ uniformOffset = rules->AddStructField(&info, fieldInfo);
+ }
+
+ if (outTypeLayout)
+ {
+ // If we are computing a complete layout,
+ // then we need to create variable layouts
+ // for each field of the structure.
+ RefPtr<VarLayout> fieldLayout = new VarLayout();
+ fieldLayout->varDecl = field;
+ fieldLayout->typeLayout = fieldTypeLayout;
+ typeLayout->fields.Add(fieldLayout);
+ typeLayout->mapVarToLayout.Add(field.GetDecl(), fieldLayout);
+
+ // Set up uniform offset information, if there is any uniform data in the field
+ if( fieldTypeLayout->FindResourceInfo(LayoutResourceKind::Uniform) )
+ {
+ fieldLayout->AddResourceInfo(LayoutResourceKind::Uniform)->index = uniformOffset;
+ }
+
+ // Add offset information for any other resource kinds
+ for( auto fieldTypeResourceInfo : fieldTypeLayout->resourceInfos )
+ {
+ // Uniforms were dealt with above
+ if(fieldTypeResourceInfo.kind == LayoutResourceKind::Uniform)
+ continue;
+
+ // We should not have already processed this resource type
+ assert(!fieldLayout->FindResourceInfo(fieldTypeResourceInfo.kind));
+
+ // The field will need offset information for this kind
+ auto fieldResourceInfo = fieldLayout->AddResourceInfo(fieldTypeResourceInfo.kind);
+
+ // Check how many slots of the given kind have already been added to the type
+ auto structTypeResourceInfo = typeLayout->findOrAddResourceInfo(fieldTypeResourceInfo.kind);
+ fieldResourceInfo->index = structTypeResourceInfo->count;
+ structTypeResourceInfo->count += fieldTypeResourceInfo.count;
+ }
+ }
+ }
+
+ rules->EndStructLayout(&info);
+ if (outTypeLayout)
+ {
+ typeLayout->uniformAlignment = info.alignment;
+ typeLayout->addResourceUsage(LayoutResourceKind::Uniform, info.size);
+ }
+
+ return info;
+ }
+ }
+
+ // catch-all case in case nothing matched
+ assert(!"unimplemented");
+ SimpleLayoutInfo info;
+ return GetSimpleLayoutImpl(
+ info,
+ type,
+ rules,
+ outTypeLayout);
+}
+
+SimpleLayoutInfo GetLayout(ExpressionType* inType, LayoutRulesImpl* rules)
+{
+ return GetLayoutImpl(inType, rules, nullptr);
+}
+
+RefPtr<TypeLayout> CreateTypeLayout(ExpressionType* type, LayoutRulesImpl* rules)
+{
+ RefPtr<TypeLayout> typeLayout;
+ GetLayoutImpl(type, rules, &typeLayout);
+ return typeLayout;
+}
+
+SimpleLayoutInfo GetLayout(ExpressionType* type, LayoutRule rule)
+{
+ LayoutRulesImpl* rulesImpl = GetLayoutRulesImpl(rule);
+ return GetLayout(type, rulesImpl);
+}
+
+}}
diff --git a/source/slang/type-layout.h b/source/slang/type-layout.h
new file mode 100644
index 000000000..be54bbf53
--- /dev/null
+++ b/source/slang/type-layout.h
@@ -0,0 +1,550 @@
+#ifndef SLANG_TYPE_LAYOUT_H
+#define SLANG_TYPE_LAYOUT_H
+
+#include "../core/basic.h"
+#include "profile.h"
+#include "syntax.h"
+
+#include "../../slang.h"
+
+namespace Slang {
+
+typedef intptr_t Int;
+typedef uintptr_t UInt;
+
+namespace Compiler {
+
+// Forward declarations
+
+enum class BaseType;
+class ExpressionType;
+
+//
+
+enum class LayoutRule
+{
+ Std140,
+ Std430,
+ HLSLConstantBuffer,
+ HLSLStructuredBuffer,
+};
+
+enum class LayoutRulesFamily
+{
+ HLSL,
+ GLSL,
+};
+
+// Layout appropriate to "just memory" scenarios,
+// such as laying out the members of a constant buffer.
+struct UniformLayoutInfo
+{
+ size_t size;
+ size_t alignment;
+
+ UniformLayoutInfo()
+ : size(0)
+ , alignment(1)
+ {}
+
+ UniformLayoutInfo(
+ size_t size,
+ size_t alignment)
+ : size(size)
+ , alignment(alignment)
+ {}
+};
+
+// Extended information required for an array of uniform data,
+// including the "stride" of the array (the space between
+// consecutive elements).
+struct UniformArrayLayoutInfo : UniformLayoutInfo
+{
+ size_t elementStride;
+
+ UniformArrayLayoutInfo()
+ : elementStride(0)
+ {}
+
+ UniformArrayLayoutInfo(
+ size_t size,
+ size_t alignment,
+ size_t elementStride)
+ : UniformLayoutInfo(size, alignment)
+ , elementStride(elementStride)
+ {}
+};
+
+typedef slang::ParameterCategory LayoutResourceKind;
+
+// Layout information for a value that only consumes
+// a single reosurce kind.
+struct SimpleLayoutInfo
+{
+ // What kind of resource should we consume?
+ LayoutResourceKind kind;
+
+ // How many resources of that kind?
+ size_t size;
+
+ // only useful in the uniform case
+ size_t alignment;
+
+ SimpleLayoutInfo()
+ : kind(LayoutResourceKind::None)
+ , size(0)
+ , alignment(1)
+ {}
+
+ SimpleLayoutInfo(
+ UniformLayoutInfo uniformInfo)
+ : kind(LayoutResourceKind::Uniform)
+ , size(uniformInfo.size)
+ , alignment(uniformInfo.alignment)
+ {}
+
+ SimpleLayoutInfo(LayoutResourceKind kind, size_t size, size_t alignment=1)
+ : kind(kind)
+ , size(size)
+ , alignment(alignment)
+ {}
+
+ // Convert to layout for uniform data
+ UniformLayoutInfo getUniformLayout()
+ {
+ if(kind == LayoutResourceKind::Uniform)
+ {
+ return UniformLayoutInfo(size, alignment);
+ }
+ else
+ {
+ return UniformLayoutInfo(0, 1);
+ }
+ }
+};
+
+// Only useful in the case of a homogeneous array
+struct SimpleArrayLayoutInfo : SimpleLayoutInfo
+{
+ // This field is only useful in the uniform case
+ size_t elementStride;
+
+ // Convert to layout for uniform data
+ UniformArrayLayoutInfo getUniformLayout()
+ {
+ if(kind == LayoutResourceKind::Uniform)
+ {
+ return UniformArrayLayoutInfo(size, alignment, elementStride);
+ }
+ else
+ {
+ return UniformArrayLayoutInfo(0, 1, 0);
+ }
+ }
+};
+
+struct LayoutRulesImpl;
+
+// A reified reprsentation of a particular laid-out type
+class TypeLayout : public RefObject
+{
+public:
+ // The type that was laid out
+ RefPtr<ExpressionType> type;
+ ExpressionType* getType() { return type.Ptr(); }
+
+ // The layout rules that were used to produce this type
+ LayoutRulesImpl* rules;
+
+ struct ResourceInfo
+ {
+ // What kind of register was it?
+ LayoutResourceKind kind = LayoutResourceKind::None;
+
+ // How many registers of the above kind did we use?
+ UInt count;
+ };
+
+ List<ResourceInfo> resourceInfos;
+
+ // For uniform data, alignment matters, but not for
+ // any other resource category, so we don't waste
+ // the space storing it in the above array
+ UInt uniformAlignment = 1;
+
+ ResourceInfo* FindResourceInfo(LayoutResourceKind kind)
+ {
+ for(auto& rr : resourceInfos)
+ {
+ if(rr.kind == kind)
+ return &rr;
+ }
+ return nullptr;
+ }
+
+ ResourceInfo* findOrAddResourceInfo(LayoutResourceKind kind)
+ {
+ auto existing = FindResourceInfo(kind);
+ if(existing) return existing;
+
+ ResourceInfo info;
+ info.kind = kind;
+ info.count = 0;
+ resourceInfos.Add(info);
+ return &resourceInfos.Last();
+ }
+
+ void addResourceUsage(ResourceInfo info)
+ {
+ if(info.count == 0) return;
+
+ findOrAddResourceInfo(info.kind)->count += info.count;
+ }
+
+ void addResourceUsage(LayoutResourceKind kind, UInt count)
+ {
+ ResourceInfo info;
+ info.kind = kind;
+ info.count = count;
+ addResourceUsage(info);
+ }
+};
+
+typedef unsigned int VarLayoutFlags;
+enum VarLayoutFlag : VarLayoutFlags
+{
+ IsRedeclaration = 1 << 0, ///< This is a redeclaration of some shader parameter
+};
+
+// A reified layout for a particular variable, field, etc.
+class VarLayout : public RefObject
+{
+public:
+ // The variable we are laying out
+ VarDeclBaseRef varDecl;
+ VarDeclBase* getVariable() { return varDecl.GetDecl(); }
+
+ String const& getName() { return getVariable()->getName(); }
+
+ // The result of laying out the variable's type
+ RefPtr<TypeLayout> typeLayout;
+ TypeLayout* getTypeLayout() { return typeLayout.Ptr(); }
+
+ // Additional flags
+ VarLayoutFlags flags = 0;
+
+ // The start register(s) for any resources
+ struct ResourceInfo
+ {
+ // What kind of register was it?
+ LayoutResourceKind kind = LayoutResourceKind::None;
+
+ // What binding space (HLSL) or set (Vulkan) are we placed in?
+ UInt space;
+
+ // What is our starting register in that space?
+ //
+ // (In the case of uniform data, this is a byte offset)
+ UInt index;
+ };
+ List<ResourceInfo> resourceInfos;
+
+ ResourceInfo* FindResourceInfo(LayoutResourceKind kind)
+ {
+ for(auto& rr : resourceInfos)
+ {
+ if(rr.kind == kind)
+ return &rr;
+ }
+ return nullptr;
+ }
+
+ ResourceInfo* AddResourceInfo(LayoutResourceKind kind)
+ {
+ ResourceInfo info;
+ info.kind = kind;
+ info.space = 0;
+ info.index = 0;
+
+ resourceInfos.Add(info);
+ return &resourceInfos.Last();
+ }
+
+ ResourceInfo* findOrAddResourceInfo(LayoutResourceKind kind)
+ {
+ auto existing = FindResourceInfo(kind);
+ if(existing) return existing;
+
+ return AddResourceInfo(kind);
+ }
+};
+
+// Type layout for a variable that has a constant-buffer type
+class ParameterBlockTypeLayout : public TypeLayout
+{
+public:
+ RefPtr<TypeLayout> elementTypeLayout;
+};
+
+// Type layout for a variable that has a constant-buffer type
+class StructuredBufferTypeLayout : public TypeLayout
+{
+public:
+ RefPtr<TypeLayout> elementTypeLayout;
+};
+
+// Specific case of type layout for an array
+class ArrayTypeLayout : public TypeLayout
+{
+public:
+ // The layout used for the element type
+ RefPtr<TypeLayout> elementTypeLayout;
+
+ // the stride between elements when used in
+ // a uniform buffer
+ size_t uniformStride;
+};
+
+// Specific case of type layout for a struct
+class StructTypeLayout : public TypeLayout
+{
+public:
+ // An ordered list of layouts for the known fields
+ List<RefPtr<VarLayout>> fields;
+
+ // Map a variable to its layout directly.
+ //
+ // Note that in the general case, there may be entries
+ // in the `fields` array that came from multiple
+ // translation units, and in cases where there are
+ // multiple declarations of the same parameter, only
+ // one will appear in `fields`, while all of
+ // them will be reflected in `mapVarToLayout`.
+ //
+ Dictionary<Decl*, RefPtr<VarLayout>> mapVarToLayout;
+};
+
+// Layout information for a single shader entry point
+// within a program
+class EntryPointLayout : public RefObject
+{
+public:
+ // The corresponding function declaration
+ RefPtr<FunctionSyntaxNode> entryPoint;
+
+ // The shader profile that was used to compile the entry point
+ Profile profile;
+};
+
+// Layout information for the global scope of a program
+class ProgramLayout : public RefObject
+{
+public:
+ // We store a layout for the declarations at the global
+ // scope. Note that this will *either* be a single
+ // `StructTypeLayout` with the fields stored directly,
+ // or it will be a single `ParameterBlockTypeLayout`,
+ // where the global-scope fields are the members of
+ // that constant buffer.
+ //
+ // The `struct` case will be used if there are no
+ // "naked" global-scope uniform variables, and the
+ // constant-buffer case will be used if there are
+ // (since a constant buffer will have to be allocated
+ // to store them).
+ //
+ RefPtr<TypeLayout> globalScopeLayout;
+
+ // We catalog the requested entry points here,
+ // and any entry-point-specific parameter data
+ // will (eventually) belong there...
+ List<RefPtr<EntryPointLayout>> entryPoints;
+};
+
+// A modifier to be attached to syntax after we've computed layout
+class ComputedLayoutModifier : public Modifier
+{
+public:
+ RefPtr<TypeLayout> typeLayout;
+};
+
+
+struct LayoutRulesFamilyImpl;
+
+// A delineation of shader parameter types into fine-grained
+// categories that can then be mapped down to actual resources
+// by a given set of rules.
+//
+// TODO(tfoley): `SlangParameterCategory` and `slang::ParameterCategory`
+// are badly named, and need to be revised so they can't be confused
+// with this concept.
+enum class ShaderParameterKind
+{
+ ConstantBuffer,
+ TextureUniformBuffer,
+ ShaderStorageBuffer,
+
+ StructuredBuffer,
+ MutableStructuredBuffer,
+
+ SampledBuffer,
+ MutableSampledBuffer,
+
+ RawBuffer,
+ MutableRawBuffer,
+
+ Buffer,
+ MutableBuffer,
+
+ Texture,
+ MutableTexture,
+
+ TextureSampler,
+ MutableTextureSampler,
+
+ InputRenderTarget,
+
+ SamplerState,
+};
+
+struct SimpleLayoutRulesImpl
+{
+ // Get size and alignment for a single value of base type.
+ virtual SimpleLayoutInfo GetScalarLayout(BaseType baseType) = 0;
+ virtual SimpleLayoutInfo GetScalarLayout(slang::TypeReflection::ScalarType scalarType) = 0;
+
+ // Get size and alignment for an array of elements
+ virtual SimpleArrayLayoutInfo GetArrayLayout(SimpleLayoutInfo elementInfo, size_t elementCount) = 0;
+
+ // Get layout for a vector or matrix type
+ virtual SimpleLayoutInfo GetVectorLayout(SimpleLayoutInfo elementInfo, size_t elementCount) = 0;
+ virtual SimpleLayoutInfo GetMatrixLayout(SimpleLayoutInfo elementInfo, size_t rowCount, size_t columnCount) = 0;
+
+ // Begin doing layout on a `struct` type
+ virtual UniformLayoutInfo BeginStructLayout() = 0;
+
+ // Add a field to a `struct` type, and return the offset for the field
+ virtual size_t AddStructField(UniformLayoutInfo* ioStructInfo, UniformLayoutInfo fieldInfo) = 0;
+
+ // End layout for a struct, and finalize its size/alignment.
+ virtual void EndStructLayout(UniformLayoutInfo* ioStructInfo) = 0;
+};
+
+struct ObjectLayoutRulesImpl
+{
+ // Compute layout info for an object type
+ virtual SimpleLayoutInfo GetObjectLayout(ShaderParameterKind kind) = 0;
+};
+
+struct LayoutRulesImpl
+{
+ LayoutRulesFamilyImpl* family;
+ SimpleLayoutRulesImpl* simpleRules;
+ ObjectLayoutRulesImpl* objectRules;
+
+ // Forward `SimpleLayoutRulesImpl` interface
+
+ SimpleLayoutInfo GetScalarLayout(BaseType baseType)
+ {
+ return simpleRules->GetScalarLayout(baseType);
+ }
+
+ SimpleLayoutInfo GetScalarLayout(slang::TypeReflection::ScalarType scalarType)
+ {
+ return simpleRules->GetScalarLayout(scalarType);
+ }
+
+ SimpleArrayLayoutInfo GetArrayLayout(SimpleLayoutInfo elementInfo, size_t elementCount)
+ {
+ return simpleRules->GetArrayLayout(elementInfo, elementCount);
+ }
+
+ SimpleLayoutInfo GetVectorLayout(SimpleLayoutInfo elementInfo, size_t elementCount)
+ {
+ return simpleRules->GetVectorLayout(elementInfo, elementCount);
+ }
+
+ SimpleLayoutInfo GetMatrixLayout(SimpleLayoutInfo elementInfo, size_t rowCount, size_t columnCount)
+ {
+ return simpleRules->GetMatrixLayout(elementInfo, rowCount, columnCount);
+ }
+
+ UniformLayoutInfo BeginStructLayout()
+ {
+ return simpleRules->BeginStructLayout();
+ }
+
+ size_t AddStructField(UniformLayoutInfo* ioStructInfo, UniformLayoutInfo fieldInfo)
+ {
+ return simpleRules->AddStructField(ioStructInfo, fieldInfo);
+ }
+
+ void EndStructLayout(UniformLayoutInfo* ioStructInfo)
+ {
+ return simpleRules->EndStructLayout(ioStructInfo);
+ }
+
+ // Forward `ObjectLayoutRulesImpl` interface
+
+ SimpleLayoutInfo GetObjectLayout(ShaderParameterKind kind)
+ {
+ return objectRules->GetObjectLayout(kind);
+ }
+
+ //
+
+ LayoutRulesFamilyImpl* getLayoutRulesFamily() { return family; }
+};
+
+struct LayoutRulesFamilyImpl
+{
+ virtual LayoutRulesImpl* getConstantBufferRules() = 0;
+ virtual LayoutRulesImpl* getTextureBufferRules() = 0;
+ virtual LayoutRulesImpl* getVaryingInputRules() = 0;
+ virtual LayoutRulesImpl* getVaryingOutputRules() = 0;
+ virtual LayoutRulesImpl* getSpecializationConstantRules() = 0;
+ virtual LayoutRulesImpl* getShaderStorageBufferRules() = 0;
+};
+
+LayoutRulesImpl* GetLayoutRulesImpl(LayoutRule rule);
+LayoutRulesFamilyImpl* GetLayoutRulesFamilyImpl(LayoutRulesFamily rule);
+LayoutRulesFamilyImpl* GetLayoutRulesFamilyImpl(SourceLanguage language);
+
+SimpleLayoutInfo GetLayout(ExpressionType* type, LayoutRulesImpl* rules);
+
+SimpleLayoutInfo GetLayout(ExpressionType* type, LayoutRule rule = LayoutRule::Std430);
+
+RefPtr<TypeLayout> CreateTypeLayout(ExpressionType* type, LayoutRulesImpl* rules);
+
+//
+
+// Create a type layout for a parameter block type.
+RefPtr<ParameterBlockTypeLayout>
+createParameterBlockTypeLayout(
+ RefPtr<ParameterBlockType> parameterBlockType,
+ LayoutRulesImpl* rules);
+
+// Create a type layout for a constant buffer type,
+// in the case where we already know the layout
+// for the element type.
+RefPtr<ParameterBlockTypeLayout>
+createParameterBlockTypeLayout(
+ RefPtr<ParameterBlockType> parameterBlockType,
+ RefPtr<TypeLayout> elementTypeLayout,
+ LayoutRulesImpl* rules);
+
+
+// Create a type layout for a structured buffer type.
+RefPtr<StructuredBufferTypeLayout>
+createStructuredBufferTypeLayout(
+ ShaderParameterKind kind,
+ RefPtr<ExpressionType> structuredBufferType,
+ RefPtr<ExpressionType> elementType,
+ LayoutRulesImpl* rules);
+
+
+//
+
+}}
+
+#endif \ No newline at end of file