diff options
| author | Yong He <yonghe@outlook.com> | 2019-01-31 13:35:03 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2019-01-31 13:35:03 -0800 |
| commit | bcb361db7c5a6f8baa9b2012b9ee9778421f1386 (patch) | |
| tree | 9e5e1703e3b06e109bae6db136bbc2e816f96a2f | |
| parent | c1fe5f295e843d10e24ae0d053fc3813a29aec89 (diff) | |
| parent | f20c64c348393602ed2a9c873386345cc4b493e8 (diff) | |
Merge branch 'master' into crashfix
31 files changed, 2136 insertions, 1467 deletions
diff --git a/source/core/smart-pointer.h b/source/core/smart-pointer.h index dd00acabd..d026388c0 100644 --- a/source/core/smart-pointer.h +++ b/source/core/smart-pointer.h @@ -63,30 +63,31 @@ namespace Slang { return referenceCount; } - - // Use instead of dynamic_cast as it allows for replacement without using Rtti in the future - template<typename T> - SLANG_FORCE_INLINE const T* dynamicCast() const - { - return dynamic_cast<const T*>(this); - } - template<typename T> - SLANG_FORCE_INLINE T* dynamicCast() - { - return dynamic_cast<T*>(this); - } }; - inline void addReference(RefObject* obj) + SLANG_FORCE_INLINE void addReference(RefObject* obj) { if(obj) obj->addReference(); } - inline void releaseReference(RefObject* obj) + SLANG_FORCE_INLINE void releaseReference(RefObject* obj) { if(obj) obj->releaseReference(); } + // For straight dynamic cast. + // Use instead of dynamic_cast as it allows for replacement without using Rtti in the future + template <typename T> + SLANG_FORCE_INLINE T* dynamicCast(RefObject* obj) { return dynamic_cast<T*>(obj); } + template <typename T> + SLANG_FORCE_INLINE const T* dynamicCast(const RefObject* obj) { return dynamic_cast<const T*>(obj); } + + // Like a dynamicCast, but allows a type to implement a specific implementation that is suitable for it + template <typename T> + SLANG_FORCE_INLINE T* as(RefObject* obj) { return dynamicCast<T>(obj); } + template <typename T> + SLANG_FORCE_INLINE const T* as(const RefObject* obj) { return dynamicCast<T>(obj); } + // "Smart" pointer to a reference-counted object template<typename T> struct RefPtr @@ -182,9 +183,15 @@ namespace Slang } template<typename U> - RefPtr<U> As() const + RefPtr<U> dynamicCast() const + { + return RefPtr<U>(Slang::dynamicCast<U>(pointer)); + } + + template<typename U> + RefPtr<U> as() const { - return RefPtr<U>(pointer->template dynamicCast<U>()); + return RefPtr<U>(Slang::as<U>(pointer)); } ~RefPtr() @@ -238,4 +245,4 @@ namespace Slang }; } -#endif
\ No newline at end of file +#endif diff --git a/source/slang/check.cpp b/source/slang/check.cpp index 74adaccda..199e733ce 100644 --- a/source/slang/check.cpp +++ b/source/slang/check.cpp @@ -19,7 +19,7 @@ namespace Slang { // Things at the global scope are always "members" of their module. // - if(parentDecl->As<ModuleDecl>()) + if(as<ModuleDecl>(parentDecl)) return false; // Anything explicitly marked `static` and not at module scope @@ -55,7 +55,7 @@ namespace Slang // explicit representation of up-cast operations in the // AST. // - if(decl->As<TypeConstraintDecl>()) + if(as<TypeConstraintDecl>(decl)) return false; return false; @@ -74,7 +74,7 @@ namespace Slang // function for it. auto parentDecl = decl->ParentDecl; - if(auto genericDecl = parentDecl->As<GenericDecl>()) + if(auto genericDecl = as<GenericDecl>(parentDecl)) parentDecl = genericDecl->ParentDecl; return isEffectivelyStatic(decl, parentDecl); @@ -119,29 +119,31 @@ namespace Slang bool fromType(Type* typeIn) { aggVal = 0; - if (auto basicType = typeIn->AsBasicType()) + if (auto basicType = as<BasicExpressionType>(typeIn)) { data.type = (unsigned char)basicType->baseType; data.dim1 = data.dim2 = 0; } - else if (auto vectorType = typeIn->AsVectorType()) + else if (auto vectorType = as<VectorExpressionType>(typeIn)) { - if (auto elemCount = vectorType->elementCount.As<ConstantIntVal>()) + if (auto elemCount = vectorType->elementCount.dynamicCast<ConstantIntVal>()) { data.dim1 = elemCount->value - 1; - data.type = (unsigned char)vectorType->elementType->AsBasicType()->baseType; + auto elementBasicType = as<BasicExpressionType>(vectorType->elementType); + data.type = (unsigned char)elementBasicType->baseType; data.dim2 = 0; } else return false; } - else if (auto matrixType = typeIn->AsMatrixType()) + else if (auto matrixType = as<MatrixExpressionType>(typeIn)) { if (auto elemCount1 = dynamic_cast<ConstantIntVal*>(matrixType->getRowCount())) { if (auto elemCount2 = dynamic_cast<ConstantIntVal*>(matrixType->getColumnCount())) { - data.type = (unsigned char)matrixType->getElementType()->AsBasicType()->baseType; + auto elemBasicType = as<BasicExpressionType>(matrixType->getElementType()); + data.type = (unsigned char)elemBasicType->baseType; data.dim1 = elemCount1->value - 1; data.dim2 = elemCount2->value - 1; } @@ -241,16 +243,16 @@ namespace Slang // attached to an overloaded definition (filtered for // definitions that could conceivably apply to us). // - // TODO: This should really be pased on the operator name + // TODO: This should really be parsed on the operator name // plus fixity, rather than the intrinsic opcode... // // We will need to reject postfix definitions for prefix // operators, and vice versa, to ensure things work. // - auto prefixExpr = opExpr->As<PrefixExpr>(); - auto postfixExpr = opExpr->As<PostfixExpr>(); + auto prefixExpr = as<PrefixExpr>(opExpr); + auto postfixExpr = as<PostfixExpr>(opExpr); - if (auto overloadedBase = opExpr->FunctionExpr->As<OverloadedExpr>()) + if (auto overloadedBase = as<OverloadedExpr>(opExpr->FunctionExpr)) { for(auto item : overloadedBase->lookupResult2 ) { @@ -258,7 +260,7 @@ namespace Slang // see if it gives us a key to work with. // Decl* funcDecl = overloadedBase->lookupResult2.item.declRef.decl; - if (auto genDecl = funcDecl->As<GenericDecl>()) + if (auto genDecl = as<GenericDecl>(funcDecl)) funcDecl = genDecl->inner.Ptr(); // Reject definitions that have the wrong fixity. @@ -457,7 +459,7 @@ namespace Slang RefPtr<Type> ExtractTypeFromTypeRepr(const RefPtr<Expr>& typeRepr) { if (!typeRepr) return nullptr; - if (auto typeType = typeRepr->type->As<TypeType>()) + if (auto typeType = as<TypeType>(typeRepr->type)) { return typeType->type; } @@ -493,10 +495,10 @@ namespace Slang RefPtr<DeclRefType> getExprDeclRefType(Expr * expr) { - if (auto typetype = expr->type->As<TypeType>()) - return typetype->type.As<DeclRefType>(); + if (auto typetype = as<TypeType>(expr->type)) + return typetype->type.dynamicCast<DeclRefType>(); else - return expr->type->As<DeclRefType>(); + return as<DeclRefType>(expr->type); } /// Is `decl` usable as a static member? @@ -506,16 +508,16 @@ namespace Slang if(decl->HasModifier<HLSLStaticModifier>()) return true; - if(decl->As<ConstructorDecl>()) + if(as<ConstructorDecl>(decl)) return true; - if(decl->As<EnumCaseDecl>()) + if(as<EnumCaseDecl>(decl)) return true; - if(decl->As<AggTypeDeclBase>()) + if(as<AggTypeDeclBase>(decl)) return true; - if(decl->As<SimpleTypeDecl>()) + if(as<SimpleTypeDecl>(decl)) return true; return false; @@ -568,20 +570,20 @@ namespace Slang { auto exprType = expr->type.type; - if(auto declRefType = exprType->As<DeclRefType>()) + if(auto declRefType = as<DeclRefType>(exprType)) { - if(auto interfaceDeclRef = declRefType->declRef.As<InterfaceDecl>()) + if(auto interfaceDeclRef = declRefType->declRef.as<InterfaceDecl>()) { // Is there an this-type substitution being applied, so that // we are referencing the interface type through a concrete - // type (e.g., a type parameter constrainted to this interface)? + // type (e.g., a type parameter constrained to this interface)? // // Because of the way that substitutions need to mirror the nesting // hierarchy of declarations, any this-type substitution pertaining // to the chosen interface decl must be the first substitution on // the list (which is a linked list from the "inside" out). // - auto thisTypeSubst = interfaceDeclRef.substitutions.substitutions.As<ThisTypeSubstitution>(); + auto thisTypeSubst = interfaceDeclRef.substitutions.substitutions.dynamicCast<ThisTypeSubstitution>(); if(thisTypeSubst && thisTypeSubst->interfaceDecl == interfaceDeclRef.decl) { // This isn't really an existential type, because somebody @@ -647,7 +649,7 @@ namespace Slang // auto type = GetTypeForDeclRef(declRef); - // Construct an appropriate expression based on teh structured of + // Construct an appropriate expression based on the structured of // the declaration reference. // if (baseExpr) @@ -664,7 +666,7 @@ namespace Slang // form (e.g., for a member function, return a value usable // for referencing it as a free function). // - if (baseExpr->type->As<TypeType>()) + if (as<TypeType>(baseExpr->type)) { auto expr = new StaticMemberExpr(); expr->loc = loc; @@ -735,7 +737,7 @@ namespace Slang RefPtr<Expr> base, SourceLoc loc) { - auto ptrLikeType = base->type->As<PointerLikeType>(); + auto ptrLikeType = as<PointerLikeType>(base->type); SLANG_ASSERT(ptrLikeType); auto derefExpr = new DerefExpr(); @@ -800,7 +802,7 @@ namespace Slang // The member was looked up via a `this` expression, // so we need to create one here. - if (auto extensionDeclRef = breadcrumb->declRef.As<ExtensionDecl>()) + if (auto extensionDeclRef = breadcrumb->declRef.as<ExtensionDecl>()) { bb = createImplicitThisMemberExpr( GetTargetType(extensionDeclRef), @@ -883,16 +885,16 @@ namespace Slang RefPtr<Expr> ExpectATypeRepr(RefPtr<Expr> expr) { - if (auto overloadedExpr = expr.As<OverloadedExpr>()) + if (auto overloadedExpr = expr.dynamicCast<OverloadedExpr>()) { expr = ResolveOverloadedExpr(overloadedExpr, LookupMask::type); } - if (auto typeType = expr->type.type->As<TypeType>()) + if (auto typeType = as<TypeType>(expr->type.type)) { return expr; } - else if (auto errorType = expr->type.type->As<ErrorType>()) + else if (auto errorType = as<ErrorType>(expr->type.type)) { return expr; } @@ -904,7 +906,7 @@ namespace Slang RefPtr<Type> ExpectAType(RefPtr<Expr> expr) { auto typeRepr = ExpectATypeRepr(expr); - if (auto typeType = typeRepr->type->As<TypeType>()) + if (auto typeType = as<TypeType>(typeRepr->type)) { return typeType->type; } @@ -923,17 +925,17 @@ namespace Slang RefPtr<Val> ExtractGenericArgVal(RefPtr<Expr> exp) { - if (auto overloadedExpr = exp.As<OverloadedExpr>()) + if (auto overloadedExpr = exp.dynamicCast<OverloadedExpr>()) { // assume that if it is overloaded, we want a type exp = ResolveOverloadedExpr(overloadedExpr, LookupMask::type); } - if (auto typeType = exp->type->As<TypeType>()) + if (auto typeType = as<TypeType>(exp->type)) { return typeType->type; } - else if (auto errorType = exp->type->As<ErrorType>()) + else if (auto errorType = as<ErrorType>(exp->type)) { return exp->type.type; } @@ -943,7 +945,7 @@ namespace Slang } } - // Construct a type reprsenting the instantiation of + // Construct a type representing the instantiation of // the given generic declaration for the given arguments. // The arguments should already be checked against // the declaration. @@ -1041,9 +1043,9 @@ namespace Slang // this is a quick fix that at least alerts the user to how we are // interpreting their code. // - if (auto varDecl = decl.As<VarDecl>()) + if (auto varDecl = decl.dynamicCast<VarDecl>()) { - if (auto parenScope = varDecl->ParentDecl->As<ScopeDecl>()) + if (auto parenScope = as<ScopeDecl>(varDecl->ParentDecl)) { // TODO: This diagnostic should be emitted on the line that is referencing // the declaration. That requires `EnsureDecl` to take the requesting @@ -1074,7 +1076,7 @@ namespace Slang void EnusreAllDeclsRec(RefPtr<Decl> decl) { checkDecl(decl); - if (auto containerDecl = decl.As<ContainerDecl>()) + if (auto containerDecl = decl.dynamicCast<ContainerDecl>()) { for (auto m : containerDecl->Members) { @@ -1109,7 +1111,7 @@ namespace Slang Type* type = typeExp.type.Ptr(); if(!type && typeExp.exp) { - if(auto typeType = typeExp.exp->type.type.As<TypeType>()) + if(auto typeType = typeExp.exp->type.type.dynamicCast<TypeType>()) { type = typeType->type; } @@ -1124,7 +1126,7 @@ namespace Slang return false; } - if (auto genericDeclRefType = type->As<GenericDeclRefType>()) + if (auto genericDeclRefType = as<GenericDeclRefType>(type)) { // We are using a reference to a generic declaration as a concrete // type. This means we should substitute in any default parameter values @@ -1139,7 +1141,7 @@ namespace Slang List<RefPtr<Expr>> args; for (RefPtr<Decl> member : genericDeclRef.getDecl()->Members) { - if (auto typeParam = member.As<GenericTypeParamDecl>()) + if (auto typeParam = member.dynamicCast<GenericTypeParamDecl>()) { if (!typeParam->initType.exp) { @@ -1155,7 +1157,7 @@ namespace Slang if (outProperType) args.Add(typeParam->initType.exp); } - else if (auto valParam = member.As<GenericValueParamDecl>()) + else if (auto valParam = member.dynamicCast<GenericValueParamDecl>()) { if (!valParam->initExpr) { @@ -1227,7 +1229,7 @@ namespace Slang { TypeExp result = CoerceToProperType(typeExp); Type* type = result.type.Ptr(); - if (auto basicType = type->As<BasicExpressionType>()) + if (auto basicType = as<BasicExpressionType>(type)) { // TODO: `void` shouldn't be a basic type, to make this easier to avoid if (basicType->baseType == BaseType::Void) @@ -1263,7 +1265,7 @@ namespace Slang { // TODO: we may want other cases here... - if (auto errorType = expr->type.As<ErrorType>()) + if (auto errorType = as<ErrorType>(expr->type)) return true; return false; @@ -1272,11 +1274,11 @@ namespace Slang // Capture the "base" expression in case this is a member reference RefPtr<Expr> GetBaseExpr(RefPtr<Expr> expr) { - if (auto memberExpr = expr.As<MemberExpr>()) + if (auto memberExpr = expr.dynamicCast<MemberExpr>()) { return memberExpr->BaseExpression; } - else if(auto overloadedExpr = expr.As<OverloadedExpr>()) + else if(auto overloadedExpr = expr.dynamicCast<OverloadedExpr>()) { return overloadedExpr->base; } @@ -1291,17 +1293,17 @@ namespace Slang { if(left == right) return true; - if(auto leftConst = left.As<ConstantIntVal>()) + if(auto leftConst = left.dynamicCast<ConstantIntVal>()) { - if(auto rightConst = right.As<ConstantIntVal>()) + if(auto rightConst = right.dynamicCast<ConstantIntVal>()) { return leftConst->value == rightConst->value; } } - if(auto leftVar = left.As<GenericParamIntVal>()) + if(auto leftVar = left.dynamicCast<GenericParamIntVal>()) { - if(auto rightVar = right.As<GenericParamIntVal>()) + if(auto rightVar = right.dynamicCast<GenericParamIntVal>()) { return leftVar->declRef.Equals(rightVar->declRef); } @@ -1326,31 +1328,31 @@ namespace Slang bool isEffectivelyScalarForInitializerLists( RefPtr<Type> type) { - if(type->As<ArrayExpressionType>()) return false; - if(type->As<VectorExpressionType>()) return false; - if(type->As<MatrixExpressionType>()) return false; + if(as<ArrayExpressionType>(type)) return false; + if(as<VectorExpressionType>(type)) return false; + if(as<MatrixExpressionType>(type)) return false; - if(type->As<BasicExpressionType>()) + if(as<BasicExpressionType>(type)) { return true; } - if(type->As<ResourceType>()) + if(as<ResourceType>(type)) { return true; } - if(type->As<UntypedBufferResourceType>()) + if(as<UntypedBufferResourceType>(type)) { return true; } - if(type->As<SamplerStateType>()) + if(as<SamplerStateType>(type)) { return true; } - if(auto declRefType = type->As<DeclRefType>()) + if(auto declRefType = as<DeclRefType>(type)) { - if(declRefType->declRef.As<StructDecl>()) + if(declRefType->declRef.as<StructDecl>()) return false; } @@ -1364,7 +1366,7 @@ namespace Slang { // A nested initializer list should always be used directly. // - if(fromExpr.As<InitializerListExpr>()) + if(fromExpr.dynamicCast<InitializerListExpr>()) { return true; } @@ -1461,7 +1463,7 @@ namespace Slang auto toType = inToType; UInt argCount = fromInitializerListExpr->args.Count(); - // In the case where we need to build a reuslt expression, + // In the case where we need to build a result expression, // we will collect the new arguments here List<RefPtr<Expr>> coercedArgs; @@ -1492,13 +1494,13 @@ namespace Slang // synthesizing default values. } } - else if (auto toVecType = toType->As<VectorExpressionType>()) + else if (auto toVecType = as<VectorExpressionType>(toType)) { auto toElementCount = toVecType->elementCount; auto toElementType = toVecType->elementType; UInt elementCount = 0; - if (auto constElementCount = toElementCount.As<ConstantIntVal>()) + if (auto constElementCount = toElementCount.dynamicCast<ConstantIntVal>()) { elementCount = (UInt) constElementCount->value; } @@ -1533,7 +1535,7 @@ namespace Slang } } } - else if(auto toArrayType = toType->As<ArrayExpressionType>()) + else if(auto toArrayType = as<ArrayExpressionType>(toType)) { // TODO(tfoley): If we can compute the size of the array statically, // then we want to check that there aren't too many initializers present @@ -1546,7 +1548,7 @@ namespace Slang // of elements being initialized matches what was declared. // UInt elementCount = 0; - if (auto constElementCount = toElementCount.As<ConstantIntVal>()) + if (auto constElementCount = toElementCount.dynamicCast<ConstantIntVal>()) { elementCount = (UInt) constElementCount->value; } @@ -1616,7 +1618,7 @@ namespace Slang new ConstantIntVal(elementCount)); } } - else if(auto toMatrixType = toType->As<MatrixExpressionType>()) + else if(auto toMatrixType = as<MatrixExpressionType>(toType)) { // In the general case, the initializer list might comprise // both vectors and scalars. @@ -1671,10 +1673,10 @@ namespace Slang } } } - else if(auto toDeclRefType = toType->As<DeclRefType>()) + else if(auto toDeclRefType = as<DeclRefType>(toType)) { auto toTypeDeclRef = toDeclRefType->declRef; - if(auto toStructDeclRef = toTypeDeclRef.As<StructDecl>()) + if(auto toStructDeclRef = toTypeDeclRef.as<StructDecl>()) { // Trying to initialize a `struct` type given an initializer list. // We will go through the fields in order and try to match them @@ -1768,7 +1770,7 @@ namespace Slang } // If either type is an error, then let things pass. - if (toType->As<ErrorType>() || fromType->As<ErrorType>()) + if (as<ErrorType>(toType) || as<ErrorType>(fromType)) { if (outToExpr) *outToExpr = CreateImplicitCastExpr(toType, fromExpr); @@ -1778,7 +1780,7 @@ namespace Slang } // Coercion from an initializer list is allowed for many types - if( auto fromInitializerListExpr = fromExpr.As<InitializerListExpr>()) + if( auto fromInitializerListExpr = fromExpr.dynamicCast<InitializerListExpr>()) { if(!tryCoerceInitializerList(toType, outToExpr, fromInitializerListExpr)) return false; @@ -1793,10 +1795,10 @@ namespace Slang } // - if (auto toDeclRefType = toType->As<DeclRefType>()) + if (auto toDeclRefType = as<DeclRefType>(toType)) { auto toTypeDeclRef = toDeclRefType->declRef; - if (auto interfaceDeclRef = toTypeDeclRef.As<InterfaceDecl>()) + if (auto interfaceDeclRef = toTypeDeclRef.as<InterfaceDecl>()) { // Trying to convert to an interface type. // @@ -1817,12 +1819,12 @@ namespace Slang // type parameter to an interface type... // #if 0 - else if (auto genParamDeclRef = toTypeDeclRef.As<GenericTypeParamDecl>()) + else if (auto genParamDeclRef = toTypeDeclRef.as<GenericTypeParamDecl>()) { // We need to enumerate the constraints placed on this type by its outer // generic declaration, and see if any of them guarantees that we // satisfy the given interface.. - auto genericDeclRef = genParamDeclRef.GetParent().As<GenericDecl>(); + auto genericDeclRef = genParamDeclRef.GetParent().as<GenericDecl>(); SLANG_ASSERT(genericDeclRef); for (auto constraintDeclRef : getMembersOfType<GenericTypeConstraintDecl>(genericDeclRef)) @@ -1830,15 +1832,15 @@ namespace Slang auto sub = GetSub(constraintDeclRef); auto sup = GetSup(constraintDeclRef); - auto subDeclRef = sub->As<DeclRefType>(); + auto subDeclRef = as<DeclRefType>(sub); if (!subDeclRef) continue; if (subDeclRef->declRef != genParamDeclRef) continue; - auto supDeclRefType = sup->As<DeclRefType>(); + auto supDeclRefType = as<DeclRefType>(sup); if (supDeclRefType) { - auto toInterfaceDeclRef = supDeclRefType->declRef.As<InterfaceDecl>(); + auto toInterfaceDeclRef = supDeclRefType->declRef.as<InterfaceDecl>(); if (DoesTypeConformToInterface(fromType, toInterfaceDeclRef)) { if (outToExpr) @@ -1856,7 +1858,7 @@ namespace Slang } // Are we converting from a parameter group type to its element type? - if(auto fromParameterGroupType = fromType->As<ParameterGroupType>()) + if(auto fromParameterGroupType = as<ParameterGroupType>(fromType)) { auto fromElementType = fromParameterGroupType->getElementType(); @@ -2136,7 +2138,7 @@ namespace Slang void CheckVarDeclCommon(RefPtr<VarDeclBase> varDecl) { // A variable that didn't have an explicit type written must - // have its type inferred from the initial-value expresison. + // have its type inferred from the initial-value expression. // if(!varDecl->type.exp) { @@ -2210,12 +2212,12 @@ namespace Slang // Fill in default substitutions for the 'subtype' part of a type constraint decl void CheckConstraintSubType(TypeExp & typeExp) { - if (auto sharedTypeExpr = typeExp.exp.As<SharedTypeExpr>()) + if (auto sharedTypeExpr = typeExp.exp.dynamicCast<SharedTypeExpr>()) { - if (auto declRefType = sharedTypeExpr->base->AsDeclRefType()) + if (auto declRefType = as<DeclRefType>(sharedTypeExpr->base)) { declRefType->declRef.substitutions = createDefaultSubstitutions(getSession(), declRefType->declRef.getDecl()); - if (auto typetype = typeExp.exp->type.type.As<TypeType>()) + if (auto typetype = typeExp.exp->type.type.dynamicCast<TypeType>()) typetype->type = declRefType; } } @@ -2249,18 +2251,18 @@ namespace Slang // check the parameters for (auto m : genericDecl->Members) { - if (auto typeParam = m.As<GenericTypeParamDecl>()) + if (auto typeParam = as<GenericTypeParamDecl>(m)) { typeParam->initType = CheckProperType(typeParam->initType); } - else if (auto valParam = m.As<GenericValueParamDecl>()) + else if (auto valParam = as<GenericValueParamDecl>(m)) { // TODO: some real checking here... CheckVarDeclCommon(valParam); } - else if (auto constraint = m.As<GenericTypeConstraintDecl>()) + else if (auto constraint = as<GenericTypeConstraintDecl>(m)) { - CheckGenericConstraintDecl(constraint.Ptr()); + CheckGenericConstraintDecl(constraint); } } @@ -2300,9 +2302,9 @@ namespace Slang // For now we only allow inheritance from interfaces, so // we will validate that the type expression names an interface - if(auto declRefType = base.type->As<DeclRefType>()) + if(auto declRefType = as<DeclRefType>(base.type)) { - if(auto interfaceDeclRef = declRefType->declRef.As<InterfaceDecl>()) + if(auto interfaceDeclRef = declRefType->declRef.as<InterfaceDecl>()) { return; } @@ -2323,7 +2325,7 @@ namespace Slang if(!intVal) return nullptr; - auto constIntVal = intVal.As<ConstantIntVal>(); + auto constIntVal = as<ConstantIntVal>(intVal); if(!constIntVal) { getSink()->diagnose(expr->loc, Diagnostics::expectedIntegerConstantNotLiteral); @@ -2342,7 +2344,7 @@ namespace Slang // but for now we are just going to look for a direct string // literal AST node. - if(auto stringLitExpr = expr.As<StringLiteralExpr>()) + if(auto stringLitExpr = as<StringLiteralExpr>(expr)) { if(outVal) { @@ -2415,7 +2417,7 @@ namespace Slang // see if we have already created an AttributeDecl for this attribute struct for (auto alt : lookupResult.items) { - if (auto adecl = alt.declRef.As<AttributeDecl>()) + if (auto adecl = alt.declRef.as<AttributeDecl>()) return adecl.getDecl(); } } @@ -2427,7 +2429,7 @@ namespace Slang if (!userDefAttribAttrib) return nullptr; // create an AttributeDecl for the user defined attribute - auto structAttribDef = lookupResult.item.declRef.As<StructDecl>().getDecl(); + auto structAttribDef = lookupResult.item.declRef.as<StructDecl>().getDecl(); RefPtr<AttributeDecl> attribDecl = new AttributeDecl(); attribDecl->nameAndLoc = structAttribDef->nameAndLoc; attribDecl->loc = structAttribDef->loc; @@ -2443,7 +2445,7 @@ namespace Slang attribDecl->syntaxClass = getSession()->findSyntaxClass(getSession()->getNameObj("UserDefinedAttribute")); for (auto member : structAttribDef->Members) { - if (auto varMember = member.As<VarDecl>()) + if (auto varMember = as<VarDecl>(member)) { RefPtr<ParamDecl> param = new ParamDecl(); param->nameAndLoc = member->nameAndLoc; @@ -2468,7 +2470,7 @@ namespace Slang } for (int i = 0; i < numArgs; ++i) { - if (!attr->args[i]->As<IntegerLiteralExpr>()) + if (!as<IntegerLiteralExpr>(attr->args[i])) { return false; } @@ -2484,7 +2486,7 @@ namespace Slang } for (int i = 0; i < numArgs; ++i) { - if (!attr->args[i]->As<StringLiteralExpr>()) + if (!as<StringLiteralExpr>(attr->args[i])) { return false; } @@ -2514,7 +2516,7 @@ namespace Slang bool validateAttribute(RefPtr<Attribute> attr, AttributeDecl* attribClassDecl) { - if(auto numThreadsAttr = attr.As<NumThreadsAttribute>()) + if(auto numThreadsAttr = as<NumThreadsAttribute>(attr)) { SLANG_ASSERT(attr->args.Count() == 3); auto xVal = checkConstantIntVal(attr->args[0]); @@ -2529,7 +2531,7 @@ namespace Slang numThreadsAttr->y = (int32_t) yVal->value; numThreadsAttr->z = (int32_t) zVal->value; } - else if (auto bindingAttr = attr.As<GLSLBindingAttribute>()) + else if (auto bindingAttr = as<GLSLBindingAttribute>(attr)) { SLANG_ASSERT(attr->args.Count() == 2); auto binding = checkConstantIntVal(attr->args[0]); @@ -2538,7 +2540,7 @@ namespace Slang bindingAttr->binding = int32_t(binding->value); bindingAttr->set = int32_t(set->value); } - else if (auto maxVertexCountAttr = attr.As<MaxVertexCountAttribute>()) + else if (auto maxVertexCountAttr = as<MaxVertexCountAttribute>(attr)) { SLANG_ASSERT(attr->args.Count() == 1); auto val = checkConstantIntVal(attr->args[0]); @@ -2547,7 +2549,7 @@ namespace Slang maxVertexCountAttr->value = (int32_t)val->value; } - else if(auto instanceAttr = attr.As<InstanceAttribute>()) + else if(auto instanceAttr = as<InstanceAttribute>(attr)) { SLANG_ASSERT(attr->args.Count() == 1); auto val = checkConstantIntVal(attr->args[0]); @@ -2556,7 +2558,7 @@ namespace Slang instanceAttr->value = (int32_t)val->value; } - else if(auto entryPointAttr = attr.As<EntryPointAttribute>()) + else if(auto entryPointAttr = as<EntryPointAttribute>(attr)) { SLANG_ASSERT(attr->args.Count() == 1); @@ -2574,11 +2576,11 @@ namespace Slang entryPointAttr->stage = stage; } - else if ((attr.As<DomainAttribute>()) || - (attr.As<MaxTessFactorAttribute>()) || - (attr.As<OutputTopologyAttribute>()) || - (attr.As<PartitioningAttribute>()) || - (attr.As<PatchConstantFuncAttribute>())) + else if ((as<DomainAttribute>(attr)) || + (as<MaxTessFactorAttribute>(attr)) || + (as<OutputTopologyAttribute>(attr)) || + (as<PartitioningAttribute>(attr)) || + (as<PatchConstantFuncAttribute>(attr))) { // Let it go thru iff single string attribute if (!hasStringArgs(attr, 1)) @@ -2586,7 +2588,7 @@ namespace Slang getSink()->diagnose(attr, Diagnostics::expectedSingleStringArg, attr->name); } } - else if (attr.As<OutputControlPointsAttribute>()) + else if (as<OutputControlPointsAttribute>(attr)) { // Let it go thru iff single integral attribute if (!hasIntArgs(attr, 1)) @@ -2594,17 +2596,17 @@ namespace Slang getSink()->diagnose(attr, Diagnostics::expectedSingleIntArg, attr->name); } } - else if (attr.As<PushConstantAttribute>()) + else if (as<PushConstantAttribute>(attr)) { // Has no args SLANG_ASSERT(attr->args.Count() == 0); } - else if (attr.As<EarlyDepthStencilAttribute>()) + else if (as<EarlyDepthStencilAttribute>(attr)) { // Has no args SLANG_ASSERT(attr->args.Count() == 0); } - else if (auto attrUsageAttr = attr.As<AttributeUsageAttribute>()) + else if (auto attrUsageAttr = as<AttributeUsageAttribute>(attr)) { uint32_t targetClassId = (uint32_t)UserDefinedAttributeTargets::None; if (attr->args.Count() == 1) @@ -2626,7 +2628,7 @@ namespace Slang return false; } } - else if (auto userDefAttr = attr.As<UserDefinedAttribute>()) + else if (auto userDefAttr = as<UserDefinedAttribute>(attr)) { // check arguments against attribute parameters defined in attribClassDecl uint32_t paramIndex = 0; @@ -2637,7 +2639,7 @@ namespace Slang { auto & arg = attr->args[paramIndex]; bool typeChecked = false; - if (auto basicType = paramDecl->getType()->AsBasicType()) + if (auto basicType = as<BasicExpressionType>(paramDecl->getType())) { if (basicType->baseType == BaseType::Int) { @@ -2706,7 +2708,7 @@ namespace Slang } RefPtr<RefObject> attrObj = attrDecl->syntaxClass.createInstance(); - auto attr = attrObj.As<Attribute>(); + auto attr = attrObj.dynamicCast<Attribute>(); if(!attr) { SLANG_DIAGNOSE_UNEXPECTED(getSink(), attrDecl, "attribute class did not yield an attribute object"); @@ -2800,7 +2802,7 @@ namespace Slang RefPtr<Modifier> m, ModifiableSyntaxNode* syntaxNode) { - if(auto hlslUncheckedAttribute = m.As<UncheckedAttribute>()) + if(auto hlslUncheckedAttribute = m.dynamicCast<UncheckedAttribute>()) { // We have an HLSL `[name(arg,...)]` attribute, and we'd like // to check that it is provides all the expected arguments @@ -2880,7 +2882,7 @@ namespace Slang for (auto decl : programNode->Members) { auto inner = decl; - if (auto genericDecl = decl.As<GenericDecl>()) + if (auto genericDecl = as<GenericDecl>(decl)) { inner = genericDecl->inner; } @@ -2906,7 +2908,7 @@ namespace Slang registerExtension(s); for (auto & g : programNode->getMembersOfType<GenericDecl>()) { - if (auto extDecl = g->inner->As<ExtensionDecl>()) + if (auto extDecl = as<ExtensionDecl>(g->inner)) { checkGenericDeclHeader(g); registerExtension(extDecl); @@ -2915,7 +2917,7 @@ namespace Slang // check user defined attribute classes first for (auto decl : programNode->Members) { - if (auto typeMember = decl->As<StructDecl>()) + if (auto typeMember = as<StructDecl>(decl)) { bool isTypeAttributeClass = false; for (auto attrib : typeMember->GetModifiersOfType<UncheckedAttribute>()) @@ -2986,9 +2988,9 @@ namespace Slang checkExtensionConformance(s); for (auto & g : programNode->getMembersOfType<GenericDecl>()) { - if (auto innerAggDecl = g->inner->As<AggTypeDecl>()) + if (auto innerAggDecl = as<AggTypeDecl>(g->inner)) checkAggTypeConformance(innerAggDecl); - else if (auto innerExtDecl = g->inner->As<ExtensionDecl>()) + else if (auto innerExtDecl = as<ExtensionDecl>(g->inner)) checkExtensionConformance(innerExtDecl); } } @@ -3034,17 +3036,17 @@ namespace Slang { auto genMbr = genDecl.getDecl()->Members[i]; auto requiredGenMbr = genDecl.getDecl()->Members[i]; - if (auto genTypeMbr = genMbr.As<GenericTypeParamDecl>()) + if (auto genTypeMbr = genMbr.dynamicCast<GenericTypeParamDecl>()) { - if (auto requiredGenTypeMbr = requiredGenMbr.As<GenericTypeParamDecl>()) + if (auto requiredGenTypeMbr = requiredGenMbr.dynamicCast<GenericTypeParamDecl>()) { } else return false; } - else if (auto genValMbr = genMbr.As<GenericValueParamDecl>()) + else if (auto genValMbr = genMbr.dynamicCast<GenericValueParamDecl>()) { - if (auto requiredGenValMbr = requiredGenMbr.As<GenericValueParamDecl>()) + if (auto requiredGenValMbr = requiredGenMbr.dynamicCast<GenericValueParamDecl>()) { if (!genValMbr->type->Equals(requiredGenValMbr->type)) return false; @@ -3052,9 +3054,9 @@ namespace Slang else return false; } - else if (auto genTypeConstraintMbr = genMbr.As<GenericTypeConstraintDecl>()) + else if (auto genTypeConstraintMbr = genMbr.dynamicCast<GenericTypeConstraintDecl>()) { - if (auto requiredTypeConstraintMbr = requiredGenMbr.As<GenericTypeConstraintDecl>()) + if (auto requiredTypeConstraintMbr = requiredGenMbr.dynamicCast<GenericTypeConstraintDecl>()) { if (!genTypeConstraintMbr->sup->Equals(requiredTypeConstraintMbr->sup)) { @@ -3069,7 +3071,7 @@ namespace Slang // TODO: this isn't right, because we need to specialize the // declarations of the generics to a common set of substitutions, // so that their types are comparable (e.g., foo<T> and foo<U> - // need to have substutition applies so that they are both foo<X>, + // need to have substitutions applies so that they are both foo<X>, // after which uses of the type X in their parameter lists can // be compared). @@ -3118,7 +3120,7 @@ namespace Slang if(conformance) { - // If all the constraints were satsified, then the chosen + // If all the constraints were satisfied, then the chosen // type can indeed satisfy the interface requirement. witnessTable->requirementDictionary.Add( requiredAssociatedTypeDeclRef.getDecl(), @@ -3158,9 +3160,9 @@ namespace Slang // to be satisfied by any type declaration: // a typedef, a `struct`, etc. // - if (auto memberFuncDecl = memberDeclRef.As<FuncDecl>()) + if (auto memberFuncDecl = memberDeclRef.as<FuncDecl>()) { - if (auto requiredFuncDeclRef = requiredMemberDeclRef.As<FuncDecl>()) + if (auto requiredFuncDeclRef = requiredMemberDeclRef.as<FuncDecl>()) { // Check signature match. return doesSignatureMatchRequirement( @@ -3169,9 +3171,9 @@ namespace Slang witnessTable); } } - else if (auto memberInitDecl = memberDeclRef.As<ConstructorDecl>()) + else if (auto memberInitDecl = memberDeclRef.as<ConstructorDecl>()) { - if (auto requiredInitDecl = requiredMemberDeclRef.As<ConstructorDecl>()) + if (auto requiredInitDecl = requiredMemberDeclRef.as<ConstructorDecl>()) { // Check signature match. return doesSignatureMatchRequirement( @@ -3180,7 +3182,7 @@ namespace Slang witnessTable); } } - else if (auto genDecl = memberDeclRef.As<GenericDecl>()) + else if (auto genDecl = memberDeclRef.as<GenericDecl>()) { // For a generic member, we will check if it can satisfy // a generic requirement in the interface. @@ -3192,14 +3194,14 @@ namespace Slang // to require performing something akin to overload // resolution as part of requirement satisfaction. // - if (auto requiredGenDeclRef = requiredMemberDeclRef.As<GenericDecl>()) + if (auto requiredGenDeclRef = requiredMemberDeclRef.as<GenericDecl>()) { return doesGenericSignatureMatchRequirement(genDecl, requiredGenDeclRef, witnessTable); } } - else if (auto subAggTypeDeclRef = memberDeclRef.As<AggTypeDecl>()) + else if (auto subAggTypeDeclRef = memberDeclRef.as<AggTypeDecl>()) { - if(auto requiredTypeDeclRef = requiredMemberDeclRef.As<AssocTypeDecl>()) + if(auto requiredTypeDeclRef = requiredMemberDeclRef.as<AssocTypeDecl>()) { checkDecl(subAggTypeDeclRef.getDecl()); @@ -3207,11 +3209,11 @@ namespace Slang return doesTypeSatisfyAssociatedTypeRequirement(satisfyingType, requiredTypeDeclRef, witnessTable); } } - else if (auto typedefDeclRef = memberDeclRef.As<TypeDefDecl>()) + else if (auto typedefDeclRef = memberDeclRef.as<TypeDefDecl>()) { // this is a type-def decl in an aggregate type // check if the specified type satisfies the constraints defined by the associated type - if (auto requiredTypeDeclRef = requiredMemberDeclRef.As<AssocTypeDecl>()) + if (auto requiredTypeDeclRef = requiredMemberDeclRef.as<AssocTypeDecl>()) { checkDecl(typedefDeclRef.getDecl()); @@ -3278,7 +3280,7 @@ namespace Slang // full of the satisfying values for each requirement // in the inherited-from interface. // - if( auto requiredInheritanceDeclRef = requiredMemberDeclRef.As<InheritanceDecl>() ) + if( auto requiredInheritanceDeclRef = requiredMemberDeclRef.as<InheritanceDecl>() ) { // Recursively check that the type conforms // to the inherited interface. @@ -3485,10 +3487,10 @@ namespace Slang InheritanceDecl* inheritanceDecl, Type* baseType) { - if (auto baseDeclRefType = baseType->As<DeclRefType>()) + if (auto baseDeclRefType = as<DeclRefType>(baseType)) { auto baseTypeDeclRef = baseDeclRefType->declRef; - if (auto baseInterfaceDeclRef = baseTypeDeclRef.As<InterfaceDecl>()) + if (auto baseInterfaceDeclRef = baseTypeDeclRef.as<InterfaceDecl>()) { // The type is stating that it conforms to an interface. // We need to check that it provides all of the members @@ -3513,18 +3515,18 @@ namespace Slang DeclRef<AggTypeDeclBase> declRef, InheritanceDecl* inheritanceDecl) { - declRef = createDefaultSubstitutionsIfNeeded(getSession(), declRef).As<AggTypeDeclBase>(); + declRef = createDefaultSubstitutionsIfNeeded(getSession(), declRef).as<AggTypeDeclBase>(); // Don't check conformances for abstract types that // are being used to express *required* conformances. - if (auto assocTypeDeclRef = declRef.As<AssocTypeDecl>()) + if (auto assocTypeDeclRef = declRef.as<AssocTypeDecl>()) { // An associated type declaration represents a requirement // in an outer interface declaration, and its members // (type constraints) represent additional requirements. return true; } - else if (auto interfaceDeclRef = declRef.As<InterfaceDecl>()) + else if (auto interfaceDeclRef = declRef.as<InterfaceDecl>()) { // HACK: Our semantics as they stand today are that an // `extension` of an interface that adds a new inheritance @@ -3554,9 +3556,9 @@ namespace Slang void checkExtensionConformance(ExtensionDecl* decl) { - if (auto targetDeclRefType = decl->targetType->As<DeclRefType>()) + if (auto targetDeclRefType = as<DeclRefType>(decl->targetType)) { - if (auto aggTypeDeclRef = targetDeclRefType->declRef.As<AggTypeDecl>()) + if (auto aggTypeDeclRef = targetDeclRefType->declRef.as<AggTypeDecl>()) { for (auto inheritanceDecl : decl->getMembersOfType<InheritanceDecl>()) { @@ -3614,7 +3616,7 @@ namespace Slang // // TODO: We should also add a pass that takes // all the stated inheritance relationships, - // expands them to include implicitic inheritance, + // expands them to include implicit inheritance, // and then linearizes them. This would allow // later passes that need to know everything // a type inherits from to proceed linearly @@ -3655,9 +3657,9 @@ namespace Slang // as the tag type for an `enum` void validateEnumTagType(Type* type, SourceLoc const& loc) { - if(auto basicType = type->As<BasicExpressionType>()) + if(auto basicType = as<BasicExpressionType>(type)) { - // Allow the built-in intteger types. + // Allow the built-in integer types. if(isIntegerBaseType(basicType->baseType)) return; @@ -3694,14 +3696,14 @@ namespace Slang // Look at the type being inherited from. auto superType = inheritanceDecl->base.type; - if(auto errorType = superType->As<ErrorType>()) + if(auto errorType = as<ErrorType>(superType)) { // Ignore any erroneous inheritance clauses. continue; } - else if(auto declRefType = superType->As<DeclRefType>()) + else if(auto declRefType = as<DeclRefType>(superType)) { - if(auto interfaceDeclRef = declRefType->declRef.As<InterfaceDecl>()) + if(auto interfaceDeclRef = declRefType->declRef.as<InterfaceDecl>()) { // Don't consider interface bases as candidates for // the tag type. @@ -3740,7 +3742,7 @@ namespace Slang // type is suitable. (e.g., if we are going // to allow raw values for case tags to be // derived automatically, then the tag - // type needs to be some kind of interer type...) + // type needs to be some kind of integer type...) // // For now we will just be harsh and require it // to be one of a few builtin types. @@ -3772,9 +3774,9 @@ namespace Slang Name* tagAssociatedTypeName = getSession()->getNameObj("__Tag"); Decl* tagAssociatedTypeDecl = nullptr; - if(auto enumTypeTypeDeclRefType = enumTypeType.As<DeclRefType>()) + if(auto enumTypeTypeDeclRefType = enumTypeType.dynamicCast<DeclRefType>()) { - if(auto enumTypeTypeInterfaceDecl = enumTypeTypeDeclRefType->declRef.getDecl()->As<InterfaceDecl>()) + if(auto enumTypeTypeInterfaceDecl = as<InterfaceDecl>(enumTypeTypeDeclRefType->declRef.getDecl())) { for(auto memberDecl : enumTypeTypeInterfaceDecl->Members) { @@ -3791,7 +3793,7 @@ namespace Slang SLANG_DIAGNOSE_UNEXPECTED(getSink(), decl, "failed to find built-in declaration '__Tag'"); } - // Okay, add the conformance withess for `__Tag` being satisfied by `tagType` + // Okay, add the conformance witness for `__Tag` being satisfied by `tagType` witnessTable->requirementDictionary.Add(tagAssociatedTypeDecl, RequirementWitness(tagType)); // TODO: we actually also need to synthesize a witness for the conformance of `tagType` @@ -3842,7 +3844,7 @@ namespace Slang RefPtr<IntVal> explicitTagVal = TryConstantFoldExpr(explicitTagValExpr); if(explicitTagVal) { - if(auto constIntVal = explicitTagVal.As<ConstantIntVal>()) + if(auto constIntVal = explicitTagVal.dynamicCast<ConstantIntVal>()) { defaultTag = constIntVal->value; } @@ -3884,11 +3886,11 @@ namespace Slang for(auto memberDecl : decl->Members) { // Already checked inheritance declarations above. - if(auto inheritanceDecl = memberDecl->As<InheritanceDecl>()) + if(auto inheritanceDecl = as<InheritanceDecl>(memberDecl)) continue; // Already checked enum case declarations above. - if(auto caseDecl = memberDecl->As<EnumCaseDecl>()) + if(auto caseDecl = as<EnumCaseDecl>(memberDecl)) continue; // TODO: Right now we don't support other kinds of @@ -3911,7 +3913,7 @@ namespace Slang // An enum case had better appear inside an enum! // // TODO: Do we need/want to support generic cases some day? - auto parentEnumDecl = decl->ParentDecl->As<EnumDecl>(); + auto parentEnumDecl = as<EnumDecl>(decl->ParentDecl); SLANG_ASSERT(parentEnumDecl); // The tag type should have already been set by @@ -3963,7 +3965,7 @@ namespace Slang { decl->SetCheckState(DeclCheckState::CheckedHeader); // global generic param only allowed in global scope - auto program = decl->ParentDecl->As<ModuleDecl>(); + auto program = as<ModuleDecl>(decl->ParentDecl); if (!program) getSink()->diagnose(decl, Slang::Diagnostics::globalGenParamInGlobalScopeOnly); // Now check all of the member declarations. @@ -3983,7 +3985,7 @@ namespace Slang decl->SetCheckState(DeclCheckState::CheckedHeader); // assoctype only allowed in an interface - auto interfaceDecl = decl->ParentDecl->As<InterfaceDecl>(); + auto interfaceDecl = as<InterfaceDecl>(decl->ParentDecl); if (!interfaceDecl) getSink()->diagnose(decl, Slang::Diagnostics::assocTypeInInterfaceOnly); @@ -4039,11 +4041,11 @@ namespace Slang if (dd == decl->inner) continue; - if (auto typeParamDecl = dd.As<GenericTypeParamDecl>()) + if (auto typeParamDecl = as<GenericTypeParamDecl>(dd)) outParams.Add(typeParamDecl); - else if (auto valueParamDecl = dd.As<GenericValueParamDecl>()) + else if (auto valueParamDecl = as<GenericValueParamDecl>(dd)) outParams.Add(valueParamDecl); - else if (auto constraintDecl = dd.As<GenericTypeConstraintDecl>()) + else if (auto constraintDecl = as<GenericTypeConstraintDecl>(dd)) outConstraints.Add(constraintDecl); } } @@ -4056,7 +4058,7 @@ namespace Slang // in each generic signature. We will consider parameters // and constraints separately so that we are independent // of the order in which constraints are given (that is, - // a constraint like `<T : IFoo>` whould be considered + // a constraint like `<T : IFoo>` would be considered // the same as `<T>` with a later `where T : IFoo`. List<Decl*> fstParams; @@ -4202,16 +4204,16 @@ namespace Slang if (dd == genericDecl->inner) continue; - if (auto typeParam = dd.As<GenericTypeParamDecl>()) + if (auto typeParam = as<GenericTypeParamDecl>(dd)) { auto type = DeclRefType::Create(getSession(), - makeDeclRef(typeParam.Ptr())); + makeDeclRef(typeParam)); subst->args.Add(type); } - else if (auto valueParam = dd.As<GenericValueParamDecl>()) + else if (auto valueParam = as<GenericValueParamDecl>(dd)) { auto val = new GenericParamIntVal( - makeDeclRef(valueParam.Ptr())); + makeDeclRef(valueParam)); subst->args.Add(val); } // TODO: need to handle constraints here? @@ -4751,7 +4753,7 @@ namespace Slang IntegerLiteralValue GetMinBound(RefPtr<IntVal> val) { - if (auto constantVal = val.As<ConstantIntVal>()) + if (auto constantVal = as<ConstantIntVal>(val)) return constantVal->value; // TODO(tfoley): Need to track intervals so that this isn't just a lie... @@ -4761,7 +4763,7 @@ namespace Slang void maybeInferArraySizeForVariable(VarDeclBase* varDecl) { // Not an array? - auto arrayType = varDecl->type->AsArrayType(); + auto arrayType = as<ArrayExpressionType>(varDecl->type); if (!arrayType) return; // Explicit element count given? @@ -4773,7 +4775,7 @@ namespace Slang if(!initExpr) return; // Is the type of the initializer an array type? - if(auto arrayInitType = initExpr->type->As<ArrayExpressionType>()) + if(auto arrayInitType = as<ArrayExpressionType>(initExpr->type)) { elementCount = arrayInitType->ArrayLength; } @@ -4792,7 +4794,7 @@ namespace Slang void validateArraySizeForVariable(VarDeclBase* varDecl) { - auto arrayType = varDecl->type->AsArrayType(); + auto arrayType = as<ArrayExpressionType>(varDecl->type); if (!arrayType) return; auto elementCount = arrayType->ArrayLength; @@ -4893,7 +4895,7 @@ namespace Slang // // For right now we will look for calls to intrinsic functions, and then inspect // their names (this is bad and slow). - auto funcDeclRefExpr = invokeExpr->FunctionExpr.As<DeclRefExpr>(); + auto funcDeclRefExpr = invokeExpr->FunctionExpr.dynamicCast<DeclRefExpr>(); if (!funcDeclRefExpr) return nullptr; auto funcDeclRef = funcDeclRefExpr->declRef; @@ -4904,7 +4906,7 @@ namespace Slang // operation right now. // // TODO: we should really allow constant-folding for anything - // that can be lowerd to our bytecode... + // that can be lowered to our bytecode... return nullptr; } @@ -4928,7 +4930,7 @@ namespace Slang argVals[argCount] = argVal; - if (auto constArgVal = argVal.As<ConstantIntVal>()) + if (auto constArgVal = as<ConstantIntVal>(argVal)) { constArgVals[argCount] = constArgVal->value; } @@ -5025,7 +5027,7 @@ namespace Slang { auto declRef = declRefExpr->declRef; - if (auto genericValParamRef = declRef.As<GenericValueParamDecl>()) + if (auto genericValParamRef = declRef.as<GenericValueParamDecl>()) { // TODO(tfoley): handle the case of non-`int` value parameters... return new GenericParamIntVal(genericValParamRef); @@ -5033,7 +5035,7 @@ namespace Slang // We may also need to check for references to variables that // are defined in a way that can be used as a constant expression: - if(auto varRef = declRef.As<VarDeclBase>()) + if(auto varRef = declRef.as<VarDeclBase>()) { auto varDecl = varRef.getDecl(); @@ -5090,7 +5092,7 @@ namespace Slang } } - else if(auto enumRef = declRef.As<EnumCaseDecl>()) + else if(auto enumRef = declRef.as<EnumCaseDecl>()) { // The cases in an `enum` declaration can also be used as constant expressions, if(auto tagExpr = getTagExpr(enumRef)) @@ -5121,7 +5123,7 @@ namespace Slang RefPtr<IntVal> TryCheckIntegerConstantExpression(Expr* exp) { // Check if type is acceptable for an integer constant expression - if(auto basicType = exp->type.type->As<BasicExpressionType>()) + if(auto basicType = as<BasicExpressionType>(exp->type.type)) { if(!isIntegerBaseType(basicType->baseType)) return nullptr; @@ -5195,7 +5197,7 @@ namespace Slang { auto session = getSession(); auto vectorGenericDecl = findMagicDecl( - session, "Vector").As<GenericDecl>(); + session, "Vector").dynamicCast<GenericDecl>(); auto vectorTypeDecl = vectorGenericDecl->inner; auto substitutions = new GenericSubstitution(); @@ -5205,9 +5207,9 @@ namespace Slang auto declRef = DeclRef<Decl>(vectorTypeDecl.Ptr(), substitutions); - return DeclRefType::Create( + return as<VectorExpressionType>(DeclRefType::Create( session, - declRef)->As<VectorExpressionType>(); + declRef)); } RefPtr<Expr> visitIndexExpr(IndexExpr* subscriptExpr) @@ -5232,7 +5234,7 @@ namespace Slang // Otherwise, we need to look at the type of the base expression, // to figure out how subscripting should work. auto baseType = baseExpr->type.Ptr(); - if (auto baseTypeType = baseType->As<TypeType>()) + if (auto baseTypeType = as<TypeType>(baseType)) { // We are trying to "index" into a type, so we have an expression like `float[2]` // which should be interpreted as resolving to an array type. @@ -5252,19 +5254,19 @@ namespace Slang subscriptExpr->type = QualType(getTypeType(arrayType)); return subscriptExpr; } - else if (auto baseArrayType = baseType->As<ArrayExpressionType>()) + else if (auto baseArrayType = as<ArrayExpressionType>(baseType)) { return CheckSimpleSubscriptExpr( subscriptExpr, baseArrayType->baseType); } - else if (auto vecType = baseType->As<VectorExpressionType>()) + else if (auto vecType = as<VectorExpressionType>(baseType)) { return CheckSimpleSubscriptExpr( subscriptExpr, vecType->elementType); } - else if (auto matType = baseType->As<MatrixExpressionType>()) + else if (auto matType = as<MatrixExpressionType>(baseType)) { // TODO(tfoley): We shouldn't go and recompute // row types over and over like this... :( @@ -5365,11 +5367,11 @@ namespace Slang RefPtr<Expr> e = expr; for(;;) { - if(auto memberExpr = e.As<MemberExpr>()) + if(auto memberExpr = as<MemberExpr>(e)) { e = memberExpr->BaseExpression; } - else if(auto subscriptExpr = e.As<IndexExpr>()) + else if(auto subscriptExpr = as<IndexExpr>(e)) { e = subscriptExpr->BaseExpression; } @@ -5381,7 +5383,7 @@ namespace Slang // // Now we check to see if we have a `this` expression, // and if it is immutable. - if(auto thisExpr = e.As<ThisExpr>()) + if(auto thisExpr = as<ThisExpr>(e)) { if(!thisExpr->type.IsLeftValue) { @@ -5400,9 +5402,9 @@ namespace Slang if (!type.IsLeftValue) { - if (type->As<ErrorType>()) + if (as<ErrorType>(type)) { - // Don't report an l-value issue on an errorneous expression + // Don't report an l-value issue on an erroneous expression } else { @@ -5431,10 +5433,10 @@ namespace Slang // TODO: need to check that the target type names a declaration... - if (auto targetDeclRefType = decl->targetType->As<DeclRefType>()) + if (auto targetDeclRefType = as<DeclRefType>(decl->targetType)) { // Attach our extension to that type as a candidate... - if (auto aggTypeDeclRef = targetDeclRefType->declRef.As<AggTypeDecl>()) + if (auto aggTypeDeclRef = targetDeclRefType->declRef.as<AggTypeDecl>()) { auto aggTypeDecl = aggTypeDeclRef.getDecl(); decl->nextCandidateExtension = aggTypeDecl->candidateExtensions; @@ -5449,7 +5451,7 @@ namespace Slang { if (decl->IsChecked(getCheckedState())) return; - if (!decl->targetType->As<DeclRefType>()) + if (!as<DeclRefType>(decl->targetType)) { getSink()->diagnose(decl->targetType.exp, Diagnostics::unimplemented, "expected a nominal type here"); } @@ -5573,7 +5575,7 @@ namespace Slang { if (checkingPhase == CheckingPhase::Header) { - // An acessor must appear nested inside a subscript declaration (today), + // An accessor must appear nested inside a subscript declaration (today), // or a property declaration (when we add them). It will derive // its return type from the outer declaration, so we handle both // of these checks at the same place. @@ -5606,8 +5608,8 @@ namespace Slang bool satisfied = false; // Has this constraint been met? }; - // A collection of constraints that will need to be satisified (solved) - // in order for checking to suceed. + // A collection of constraints that will need to be satisfied (solved) + // in order for checking to succeed. struct ConstraintSystem { // A source location to use in reporting any issues @@ -5710,7 +5712,7 @@ namespace Slang RefPtr<SubtypeWitness>* link = &witness; // As long as there is more than one breadcrumb, we - // need to be creating transitie witnesses. + // need to be creating transitive witnesses. while(bb->prev) { // On the first iteration when processing the list @@ -5780,7 +5782,7 @@ namespace Slang DeclRef<InterfaceDecl> interfaceDeclRef, DeclRef<Decl> requirementDeclRef) { - if(auto callableDeclRef = requirementDeclRef.As<CallableDecl>()) + if(auto callableDeclRef = requirementDeclRef.as<CallableDecl>()) { // A `static` method requirement can't be satisfied by a // tagged union, because there is no tag to dispatch on. @@ -5812,7 +5814,7 @@ namespace Slang TypeWitnessBreadcrumb* inBreadcrumbs) { // for now look up a conformance member... - if(auto declRefType = type->As<DeclRefType>()) + if(auto declRefType = as<DeclRefType>(type)) { auto declRef = declRefType->declRef; @@ -5830,7 +5832,7 @@ namespace Slang return true; } - if( auto aggTypeDeclRef = declRef.As<AggTypeDecl>() ) + if( auto aggTypeDeclRef = declRef.as<AggTypeDecl>() ) { checkDecl(aggTypeDeclRef.getDecl()); @@ -5842,9 +5844,9 @@ namespace Slang // that is being inherited from. This is dangerous because // it might lead to infinite loops. // - // TODO: A better appraoch would be to create a linearized list - // of all the interfaces that a given type direclty or indirectly - // inheirts, and store it with the type, so that we don't have + // TODO: A better approach would be to create a linearized list + // of all the interfaces that a given type directly or indirectly + // inherits, and store it with the type, so that we don't have // to recurse in places like this (and can maybe catch infinite // loops better). This would also help avoid checking multiply-inherited // conformances multiple times. @@ -5882,12 +5884,12 @@ namespace Slang } } } - else if( auto genericTypeParamDeclRef = declRef.As<GenericTypeParamDecl>() ) + else if( auto genericTypeParamDeclRef = declRef.as<GenericTypeParamDecl>() ) { // We need to enumerate the constraints placed on this type by its outer // generic declaration, and see if any of them guarantees that we // satisfy the given interface.. - auto genericDeclRef = genericTypeParamDeclRef.GetParent().As<GenericDecl>(); + auto genericDeclRef = genericTypeParamDeclRef.GetParent().as<GenericDecl>(); SLANG_ASSERT(genericDeclRef); for( auto constraintDeclRef : getMembersOfType<GenericTypeConstraintDecl>(genericDeclRef) ) @@ -5895,7 +5897,7 @@ namespace Slang auto sub = GetSub(constraintDeclRef); auto sup = GetSup(constraintDeclRef); - auto subDeclRef = sub->As<DeclRefType>(); + auto subDeclRef = as<DeclRefType>(sub); if(!subDeclRef) continue; if(subDeclRef->declRef != genericTypeParamDeclRef) @@ -5918,7 +5920,7 @@ namespace Slang } } } - else if(auto taggedUnionType = type->As<TaggedUnionType>()) + else if(auto taggedUnionType = as<TaggedUnionType>(type)) { // A tagged union type conforms to an interface if all of // the constituent types in the tagged union conform. @@ -6049,7 +6051,7 @@ namespace Slang // through types `X` that are also builtin scalar types. // RefPtr<Type> bestType; - if(auto basicType = type.As<BasicExpressionType>()) + if(auto basicType = type.dynamicCast<BasicExpressionType>()) { for(Int baseTypeFlavorIndex = 0; baseTypeFlavorIndex < Int(BaseType::CountOf); baseTypeFlavorIndex++) { @@ -6092,7 +6094,7 @@ namespace Slang { // Our candidate can convert to the current "best" type, so // it is logically a more specific type that satisfies our - // constraints, thereforce we should keep it. + // constraints, therefore we should keep it. // bestType = candidateType; } @@ -6127,9 +6129,9 @@ namespace Slang return left; // We can join two basic types by picking the "better" of the two - if (auto leftBasic = left->As<BasicExpressionType>()) + if (auto leftBasic = as<BasicExpressionType>(left)) { - if (auto rightBasic = right->As<BasicExpressionType>()) + if (auto rightBasic = as<BasicExpressionType>(right)) { auto leftFlavor = leftBasic->baseType; auto rightFlavor = rightBasic->baseType; @@ -6149,7 +6151,7 @@ namespace Slang } // We can also join a vector and a scalar - if(auto rightVector = right->As<VectorExpressionType>()) + if(auto rightVector = as<VectorExpressionType>(right)) { return TryJoinVectorAndScalarType(rightVector, leftBasic); } @@ -6157,9 +6159,9 @@ namespace Slang // We can join two vector types by joining their element types // (and also their sizes...) - if( auto leftVector = left->As<VectorExpressionType>()) + if( auto leftVector = as<VectorExpressionType>(left)) { - if(auto rightVector = right->As<VectorExpressionType>()) + if(auto rightVector = as<VectorExpressionType>(right)) { // Check if the vector sizes match if(!leftVector->elementCount->EqualsVal(rightVector->elementCount.Ptr())) @@ -6178,24 +6180,24 @@ namespace Slang } // We can also join a vector and a scalar - if(auto rightBasic = right->As<BasicExpressionType>()) + if(auto rightBasic = as<BasicExpressionType>(right)) { return TryJoinVectorAndScalarType(leftVector, rightBasic); } } // HACK: trying to work trait types in here... - if(auto leftDeclRefType = left->As<DeclRefType>()) + if(auto leftDeclRefType = as<DeclRefType>(left)) { - if( auto leftInterfaceRef = leftDeclRefType->declRef.As<InterfaceDecl>() ) + if( auto leftInterfaceRef = leftDeclRefType->declRef.as<InterfaceDecl>() ) { // return TryJoinTypeWithInterface(right, leftInterfaceRef); } } - if(auto rightDeclRefType = right->As<DeclRefType>()) + if(auto rightDeclRefType = as<DeclRefType>(right)) { - if( auto rightInterfaceRef = rightDeclRefType->declRef.As<InterfaceDecl>() ) + if( auto rightInterfaceRef = rightDeclRefType->declRef.as<InterfaceDecl>() ) { // return TryJoinTypeWithInterface(left, rightInterfaceRef); @@ -6248,7 +6250,7 @@ namespace Slang List<RefPtr<Val>> args; for (auto m : getMembers(genericDeclRef)) { - if (auto typeParam = m.As<GenericTypeParamDecl>()) + if (auto typeParam = m.as<GenericTypeParamDecl>()) { RefPtr<Type> type = nullptr; for (auto& c : system->constraints) @@ -6256,7 +6258,7 @@ namespace Slang if (c.decl != typeParam.getDecl()) continue; - auto cType = c.val.As<Type>(); + auto cType = c.val.dynamicCast<Type>(); SLANG_RELEASE_ASSERT(cType.Ptr()); if (!type) @@ -6284,7 +6286,7 @@ namespace Slang } args.Add(type); } - else if (auto valParam = m.As<GenericValueParamDecl>()) + else if (auto valParam = m.as<GenericValueParamDecl>()) { // TODO(tfoley): maybe support more than integers some day? // TODO(tfoley): figure out how this needs to interact with @@ -6295,7 +6297,7 @@ namespace Slang if (c.decl != valParam.getDecl()) continue; - auto cVal = c.val.As<IntVal>(); + auto cVal = c.val.dynamicCast<IntVal>(); SLANG_RELEASE_ASSERT(cVal.Ptr()); if (!val) @@ -6485,7 +6487,7 @@ namespace Slang ParamCounts counts = { 0, 0 }; for (auto m : genericRef.getDecl()->Members) { - if (auto typeParam = m.As<GenericTypeParamDecl>()) + if (auto typeParam = as<GenericTypeParamDecl>(m)) { counts.allowed++; if (!typeParam->initType.Ptr()) @@ -6493,7 +6495,7 @@ namespace Slang counts.required++; } } - else if (auto valParam = m.As<GenericValueParamDecl>()) + else if (auto valParam = as<GenericValueParamDecl>(m)) { counts.allowed++; if (!valParam->initExpr) @@ -6514,11 +6516,11 @@ namespace Slang switch (candidate.flavor) { case OverloadCandidate::Flavor::Func: - paramCounts = CountParameters(GetParameters(candidate.item.declRef.As<CallableDecl>())); + paramCounts = CountParameters(GetParameters(candidate.item.declRef.as<CallableDecl>())); break; case OverloadCandidate::Flavor::Generic: - paramCounts = CountParameters(candidate.item.declRef.As<GenericDecl>()); + paramCounts = CountParameters(candidate.item.declRef.as<GenericDecl>()); break; default: @@ -6554,7 +6556,7 @@ namespace Slang auto decl = candidate.item.declRef.decl; - if(auto prefixExpr = expr.As<PrefixExpr>()) + if(auto prefixExpr = as<PrefixExpr>(expr)) { if(decl->HasModifier<PrefixModifier>()) return true; @@ -6567,7 +6569,7 @@ namespace Slang return false; } - else if(auto postfixExpr = expr.As<PostfixExpr>()) + else if(auto postfixExpr = as<PostfixExpr>(expr)) { if(decl->HasModifier<PostfixModifier>()) return true; @@ -6592,7 +6594,7 @@ namespace Slang OverloadResolveContext& context, OverloadCandidate& candidate) { - auto genericDeclRef = candidate.item.declRef.As<GenericDecl>(); + auto genericDeclRef = candidate.item.declRef.as<GenericDecl>(); // We will go ahead and hang onto the arguments that we've // already checked, since downstream validation might need @@ -6604,7 +6606,7 @@ namespace Slang uint32_t aa = 0; for (auto memberRef : getMembers(genericDeclRef)) { - if (auto typeParamRef = memberRef.As<GenericTypeParamDecl>()) + if (auto typeParamRef = memberRef.as<GenericTypeParamDecl>()) { if (aa >= context.argCount) { @@ -6627,7 +6629,7 @@ namespace Slang } checkedArgs.Add(typeExp.type); } - else if (auto valParamRef = memberRef.As<GenericValueParamDecl>()) + else if (auto valParamRef = memberRef.as<GenericValueParamDecl>()) { auto arg = context.getArg(aa++); @@ -6665,7 +6667,7 @@ namespace Slang switch (candidate.flavor) { case OverloadCandidate::Flavor::Func: - params = GetParameters(candidate.item.declRef.As<CallableDecl>()).ToArray(); + params = GetParameters(candidate.item.declRef.as<CallableDecl>()).ToArray(); break; case OverloadCandidate::Flavor::Generic: @@ -6741,10 +6743,10 @@ namespace Slang return createTypeEqualityWitness(sub); } - if(auto supDeclRefType = sup->As<DeclRefType>()) + if(auto supDeclRefType = as<DeclRefType>(sup)) { auto supDeclRef = supDeclRefType->declRef; - if(auto supInterfaceDeclRef = supDeclRef.As<InterfaceDecl>()) + if(auto supInterfaceDeclRef = supDeclRef.as<InterfaceDecl>()) { if(auto witness = tryGetInterfaceConformanceWitness(sub, supInterfaceDeclRef)) { @@ -6773,13 +6775,13 @@ namespace Slang if(candidate.flavor != OverloadCandidate::Flavor::Generic) return true; - auto genericDeclRef = candidate.item.declRef.As<GenericDecl>(); + auto genericDeclRef = candidate.item.declRef.as<GenericDecl>(); SLANG_ASSERT(genericDeclRef); // otherwise we wouldn't be a generic candidate... // We should have the existing arguments to the generic // handy, so that we can construct a substitution list. - RefPtr<GenericSubstitution> subst = candidate.subst.As<GenericSubstitution>(); + RefPtr<GenericSubstitution> subst = candidate.subst.dynamicCast<GenericSubstitution>(); SLANG_ASSERT(subst); subst->genericDecl = genericDeclRef.getDecl(); @@ -6849,13 +6851,13 @@ namespace Slang RefPtr<Expr> originalExpr, RefPtr<GenericSubstitution> subst) { - auto baseDeclRefExpr = baseExpr.As<DeclRefExpr>(); + auto baseDeclRefExpr = baseExpr.dynamicCast<DeclRefExpr>(); if (!baseDeclRefExpr) { SLANG_DIAGNOSE_UNEXPECTED(getSink(), baseExpr, "expected a reference to a generic declaration"); return CreateErrorExpr(originalExpr); } - auto baseGenericRef = baseDeclRefExpr->declRef.As<GenericDecl>(); + auto baseGenericRef = baseDeclRefExpr->declRef.as<GenericDecl>(); if (!baseGenericRef) { SLANG_DIAGNOSE_UNEXPECTED(getSink(), baseExpr, "expected a reference to a generic declaration"); @@ -6868,7 +6870,7 @@ namespace Slang DeclRef<Decl> innerDeclRef(GetInner(baseGenericRef), subst); RefPtr<Expr> base; - if (auto mbrExpr = baseExpr.As<MemberExpr>()) + if (auto mbrExpr = as<MemberExpr>(baseExpr)) base = mbrExpr->BaseExpression; return ConstructDeclRefExpr( @@ -6926,7 +6928,7 @@ namespace Slang { case OverloadCandidate::Flavor::Func: { - RefPtr<AppExprBase> callExpr = context.originalExpr.As<InvokeExpr>(); + RefPtr<AppExprBase> callExpr = context.originalExpr.as<InvokeExpr>(); if(!callExpr) { callExpr = new InvokeExpr(); @@ -6941,7 +6943,7 @@ namespace Slang callExpr->type = QualType(candidate.resultType); // A call may yield an l-value, and we should take a look at the candidate to be sure - if(auto subscriptDeclRef = candidate.item.declRef.As<SubscriptDecl>()) + if(auto subscriptDeclRef = candidate.item.declRef.as<SubscriptDecl>()) { for(auto setter : subscriptDeclRef.getDecl()->getMembersOfType<SetterDecl>()) { @@ -6964,7 +6966,7 @@ namespace Slang return createGenericDeclRef( baseExpr, context.originalExpr, - candidate.subst.As<GenericSubstitution>()); + candidate.subst.as<GenericSubstitution>()); break; default: @@ -7084,7 +7086,7 @@ namespace Slang } else { - // This is the only candidate worthe keeping track of right now + // This is the only candidate worth keeping track of right now context.bestCandidateStorage = candidate; context.bestCandidate = &context.bestCandidateStorage; } @@ -7207,30 +7209,30 @@ namespace Slang RefPtr<Val> snd) { // if both values are types, then unify types - if (auto fstType = fst.As<Type>()) + if (auto fstType = dynamicCast<Type>(fst)) { - if (auto sndType = snd.As<Type>()) + if (auto sndType = dynamicCast<Type>(snd)) { return TryUnifyTypes(constraints, fstType, sndType); } } // if both values are constant integers, then compare them - if (auto fstIntVal = fst.As<ConstantIntVal>()) + if (auto fstIntVal = dynamicCast<ConstantIntVal>(fst)) { - if (auto sndIntVal = snd.As<ConstantIntVal>()) + if (auto sndIntVal = dynamicCast<ConstantIntVal>(snd)) { return fstIntVal->value == sndIntVal->value; } } // Check if both are integer values in general - if (auto fstInt = fst.As<IntVal>()) + if (auto fstInt = as<IntVal>(fst)) { - if (auto sndInt = snd.As<IntVal>()) + if (auto sndInt = as<IntVal>(snd)) { - auto fstParam = fstInt.As<GenericParamIntVal>(); - auto sndParam = sndInt.As<GenericParamIntVal>(); + auto fstParam = as<GenericParamIntVal>(fstInt); + auto sndParam = as<GenericParamIntVal>(sndInt); bool okay = false; if (fstParam) @@ -7247,12 +7249,12 @@ namespace Slang } } - if (auto fstWit = fst.As<DeclaredSubtypeWitness>()) + if (auto fstWit = as<DeclaredSubtypeWitness>(fst)) { - if (auto sndWit = snd.As<DeclaredSubtypeWitness>()) + if (auto sndWit = as<DeclaredSubtypeWitness>(snd)) { - auto constraintDecl1 = fstWit->declRef.As<TypeConstraintDecl>(); - auto constraintDecl2 = sndWit->declRef.As<TypeConstraintDecl>(); + auto constraintDecl1 = fstWit->declRef.as<TypeConstraintDecl>(); + auto constraintDecl2 = sndWit->declRef.as<TypeConstraintDecl>(); SLANG_ASSERT(constraintDecl1); SLANG_ASSERT(constraintDecl2); return TryUnifyTypes(constraints, @@ -7276,9 +7278,9 @@ namespace Slang if (!fst || !snd) return !fst && !snd; - if(auto fstGeneric = fst.As<GenericSubstitution>()) + if(auto fstGeneric = as<GenericSubstitution>(fst)) { - if(auto sndGeneric = snd.As<GenericSubstitution>()) + if(auto sndGeneric = as<GenericSubstitution>(snd)) { return tryUnifyGenericSubstitutions( constraints, @@ -7371,7 +7373,7 @@ namespace Slang DeclRef<VarDeclBase> const& varRef, RefPtr<IntVal> val) { - if(auto genericValueParamRef = varRef.As<GenericValueParamDecl>()) + if(auto genericValueParamRef = varRef.as<GenericValueParamDecl>()) { return TryUnifyIntParam(constraints, RefPtr<GenericValueParamDecl>(genericValueParamRef.getDecl()), val); } @@ -7386,25 +7388,25 @@ namespace Slang RefPtr<Type> fst, RefPtr<Type> snd) { - if (auto fstDeclRefType = fst->As<DeclRefType>()) + if (auto fstDeclRefType = as<DeclRefType>(fst)) { auto fstDeclRef = fstDeclRefType->declRef; if (auto typeParamDecl = dynamic_cast<GenericTypeParamDecl*>(fstDeclRef.getDecl())) return TryUnifyTypeParam(constraints, typeParamDecl, snd); - if (auto sndDeclRefType = snd->As<DeclRefType>()) + if (auto sndDeclRefType = as<DeclRefType>(snd)) { auto sndDeclRef = sndDeclRefType->declRef; if (auto typeParamDecl = dynamic_cast<GenericTypeParamDecl*>(sndDeclRef.getDecl())) return TryUnifyTypeParam(constraints, typeParamDecl, fst); - // can't be unified if they refer to differnt declarations. + // can't be unified if they refer to different declarations. if (fstDeclRef.getDecl() != sndDeclRef.getDecl()) return false; // next we need to unify the substitutions applied - // to each decalration reference. + // to each declaration reference. if (!tryUnifySubstitutions( constraints, fstDeclRef.substitutions.substitutions, @@ -7429,17 +7431,17 @@ namespace Slang // An error type can unify with anything, just so we avoid cascading errors. - if (auto fstErrorType = fst->As<ErrorType>()) + if (auto fstErrorType = as<ErrorType>(fst)) return true; - if (auto sndErrorType = snd->As<ErrorType>()) + if (auto sndErrorType = as<ErrorType>(snd)) return true; // A generic parameter type can unify with anything. // TODO: there actually needs to be some kind of "occurs check" sort // of thing here... - if (auto fstDeclRefType = fst->As<DeclRefType>()) + if (auto fstDeclRefType = as<DeclRefType>(fst)) { auto fstDeclRef = fstDeclRefType->declRef; @@ -7450,7 +7452,7 @@ namespace Slang } } - if (auto sndDeclRefType = snd->As<DeclRefType>()) + if (auto sndDeclRefType = as<DeclRefType>(snd)) { auto sndDeclRef = sndDeclRefType->declRef; @@ -7470,9 +7472,9 @@ namespace Slang // in a completely ad hoc fashion, but eventually we'd // want to do it more formally. - if(auto fstVectorType = fst->As<VectorExpressionType>()) + if(auto fstVectorType = as<VectorExpressionType>(fst)) { - if(auto sndScalarType = snd->As<BasicExpressionType>()) + if(auto sndScalarType = as<BasicExpressionType>(snd)) { return TryUnifyTypes( constraints, @@ -7481,9 +7483,9 @@ namespace Slang } } - if(auto fstScalarType = fst->As<BasicExpressionType>()) + if(auto fstScalarType = as<BasicExpressionType>(fst)) { - if(auto sndVectorType = snd->As<VectorExpressionType>()) + if(auto sndVectorType = as<VectorExpressionType>(snd)) { return TryUnifyTypes( constraints, @@ -7505,7 +7507,7 @@ namespace Slang DeclRef<ExtensionDecl> extDeclRef = makeDeclRef(extDecl); // If the extension is a generic extension, then we - // need to infer type argumenst that will give + // need to infer type arguments that will give // us a target type that matches `type`. // if (auto extGenericDecl = GetOuterGeneric(extDecl)) @@ -7517,15 +7519,15 @@ namespace Slang if (!TryUnifyTypes(constraints, extDecl->targetType.Ptr(), type)) return DeclRef<ExtensionDecl>(); - auto constraintSubst = TrySolveConstraintSystem(&constraints, DeclRef<Decl>(extGenericDecl, nullptr).As<GenericDecl>()); + auto constraintSubst = TrySolveConstraintSystem(&constraints, DeclRef<Decl>(extGenericDecl, nullptr).as<GenericDecl>()); if (!constraintSubst) { return DeclRef<ExtensionDecl>(); } - // Consruct a reference to the extension with our constraint variables + // Construct a reference to the extension with our constraint variables // set as they were found by solving the constraint system. - extDeclRef = DeclRef<Decl>(extDecl, constraintSubst).As<ExtensionDecl>(); + extDeclRef = DeclRef<Decl>(extDecl, constraintSubst).as<ExtensionDecl>(); } // Now extract the target type from our (possibly specialized) extension decl-ref. @@ -7535,29 +7537,29 @@ namespace Slang // an interface, and the `type` we are trying to match up has a this-type // substitution for that interface, then we want to attach a matching // substitution to the extension decl-ref. - if(auto targetDeclRefType = targetType->As<DeclRefType>()) + if(auto targetDeclRefType = as<DeclRefType>(targetType)) { - if(auto targetInterfaceDeclRef = targetDeclRefType->declRef.As<InterfaceDecl>()) + if(auto targetInterfaceDeclRef = targetDeclRefType->declRef.as<InterfaceDecl>()) { // Okay, the target type is an interface. // // Is the type we want to apply to also an interface? - if(auto appDeclRefType = type->As<DeclRefType>()) + if(auto appDeclRefType = as<DeclRefType>(type)) { - if(auto appInterfaceDeclRef = appDeclRefType->declRef.As<InterfaceDecl>()) + if(auto appInterfaceDeclRef = appDeclRefType->declRef.as<InterfaceDecl>()) { if(appInterfaceDeclRef.getDecl() == targetInterfaceDeclRef.getDecl()) { // Looks like we have a match in the types, // now let's see if we have a this-type substitution. - if(auto appThisTypeSubst = appInterfaceDeclRef.substitutions.substitutions.As<ThisTypeSubstitution>()) + if(auto appThisTypeSubst = appInterfaceDeclRef.substitutions.substitutions.as<ThisTypeSubstitution>()) { if(appThisTypeSubst->interfaceDecl == appInterfaceDeclRef.getDecl()) { // The type we want to apply to has a this-type substitution, // and (by construction) the target type currently does not. // - SLANG_ASSERT(!targetInterfaceDeclRef.substitutions.substitutions.As<ThisTypeSubstitution>()); + SLANG_ASSERT(!targetInterfaceDeclRef.substitutions.substitutions.as<ThisTypeSubstitution>()); // We will create a new substitution to apply to the target type. RefPtr<ThisTypeSubstitution> newTargetSubst = new ThisTypeSubstitution(); @@ -7647,7 +7649,7 @@ namespace Slang // Check what type of declaration we are dealing with, and then try // to match it up with the arguments accordingly... - if (auto funcDeclRef = unspecializedInnerRef.As<CallableDecl>()) + if (auto funcDeclRef = unspecializedInnerRef.as<CallableDecl>()) { auto params = GetParameters(funcDeclRef).ToArray(); @@ -7733,13 +7735,13 @@ namespace Slang // for (auto genericDeclRef : getMembersOfType<GenericDecl>(aggTypeDeclRef)) { - if (auto ctorDecl = genericDeclRef.getDecl()->inner.As<ConstructorDecl>()) + if (auto ctorDecl = genericDeclRef.getDecl()->inner.as<ConstructorDecl>()) { DeclRef<Decl> innerRef = SpecializeGenericForOverload(genericDeclRef, context); if (!innerRef) continue; - DeclRef<ConstructorDecl> innerCtorRef = innerRef.As<ConstructorDecl>(); + DeclRef<ConstructorDecl> innerCtorRef = innerRef.as<ConstructorDecl>(); AddCtorOverloadCandidate(typeItem, type, innerCtorRef, context, resultType); } } @@ -7762,13 +7764,13 @@ namespace Slang // Also check for generic constructors for (auto genericDeclRef : getMembersOfType<GenericDecl>(extDeclRef)) { - if (auto ctorDecl = genericDeclRef.getDecl()->inner.As<ConstructorDecl>()) + if (auto ctorDecl = genericDeclRef.getDecl()->inner.as<ConstructorDecl>()) { DeclRef<Decl> innerRef = SpecializeGenericForOverload(genericDeclRef, context); if (!innerRef) continue; - DeclRef<ConstructorDecl> innerCtorRef = innerRef.As<ConstructorDecl>(); + DeclRef<ConstructorDecl> innerCtorRef = innerRef.as<ConstructorDecl>(); AddCtorOverloadCandidate(typeItem, type, innerCtorRef, context, resultType); @@ -7788,7 +7790,7 @@ namespace Slang // interfaces that the type must conform to. // We expect the parent of the generic type parameter to be a generic... - auto genericDeclRef = typeDeclRef.GetParent().As<GenericDecl>(); + auto genericDeclRef = typeDeclRef.GetParent().as<GenericDecl>(); SLANG_ASSERT(genericDeclRef); for(auto constraintDeclRef : getMembersOfType<GenericTypeConstraintDecl>(genericDeclRef)) @@ -7799,7 +7801,7 @@ namespace Slang // generic parameter in question, and `Foo` is whatever we are // constraining it to. auto subType = GetSub(constraintDeclRef); - auto subDeclRefType = subType->As<DeclRefType>(); + auto subDeclRefType = as<DeclRefType>(subType); if(!subDeclRefType) continue; if(!subDeclRefType->declRef.Equals(typeDeclRef)) @@ -7821,14 +7823,14 @@ namespace Slang OverloadResolveContext& context, RefPtr<Type> resultType) { - if (auto declRefType = type->As<DeclRefType>()) + if (auto declRefType = as<DeclRefType>(type)) { auto declRef = declRefType->declRef; - if (auto aggTypeDeclRef = declRef.As<AggTypeDecl>()) + if (auto aggTypeDeclRef = declRef.as<AggTypeDecl>()) { AddAggTypeOverloadCandidates(LookupResultItem(aggTypeDeclRef), type, aggTypeDeclRef, context, resultType); } - else if(auto genericTypeParamDeclRef = declRef.As<GenericTypeParamDecl>()) + else if(auto genericTypeParamDeclRef = declRef.as<GenericTypeParamDecl>()) { addGenericTypeParamOverloadCandidates( genericTypeParamDeclRef, @@ -7844,18 +7846,18 @@ namespace Slang { auto declRef = item.declRef; - if (auto funcDeclRef = item.declRef.As<CallableDecl>()) + if (auto funcDeclRef = item.declRef.as<CallableDecl>()) { AddFuncOverloadCandidate(item, funcDeclRef, context); } - else if (auto aggTypeDeclRef = item.declRef.As<AggTypeDecl>()) + else if (auto aggTypeDeclRef = item.declRef.as<AggTypeDecl>()) { auto type = DeclRefType::Create( getSession(), aggTypeDeclRef); AddAggTypeOverloadCandidates(item, type, aggTypeDeclRef, context, type); } - else if (auto genericDeclRef = item.declRef.As<GenericDecl>()) + else if (auto genericDeclRef = item.declRef.as<GenericDecl>()) { // Try to infer generic arguments, based on the context DeclRef<Decl> innerRef = SpecializeGenericForOverload(genericDeclRef, context); @@ -7884,12 +7886,12 @@ namespace Slang AddOverloadCandidateInner(context, candidate); } } - else if( auto typeDefDeclRef = item.declRef.As<TypeDefDecl>() ) + else if( auto typeDefDeclRef = item.declRef.as<TypeDefDecl>() ) { auto type = getNamedType(getSession(), typeDefDeclRef); AddTypeOverloadCandidates(GetType(typeDefDeclRef), context, type); } - else if( auto genericTypeParamDeclRef = item.declRef.As<GenericTypeParamDecl>() ) + else if( auto genericTypeParamDeclRef = item.declRef.as<GenericTypeParamDecl>() ) { auto type = DeclRefType::Create( getSession(), @@ -7908,19 +7910,19 @@ namespace Slang { auto funcExprType = funcExpr->type; - if (auto declRefExpr = funcExpr.As<DeclRefExpr>()) + if (auto declRefExpr = as<DeclRefExpr>(funcExpr)) { // The expression directly referenced a declaration, // so we can use that declaration directly to look // for anything applicable. AddDeclRefOverloadCandidates(LookupResultItem(declRefExpr->declRef), context); } - else if (auto funcType = funcExprType.As<FuncType>()) + else if (auto funcType = as<FuncType>(funcExprType)) { // TODO(tfoley): deprecate this path... AddFuncOverloadCandidate(funcType, context); } - else if (auto overloadedExpr = funcExpr.As<OverloadedExpr>()) + else if (auto overloadedExpr = as<OverloadedExpr>(funcExpr)) { auto lookupResult = overloadedExpr->lookupResult2; SLANG_RELEASE_ASSERT(lookupResult.isOverloaded()); @@ -7929,14 +7931,14 @@ namespace Slang AddDeclRefOverloadCandidates(item, context); } } - else if (auto overloadedExpr2 = funcExpr.As<OverloadedExpr2>()) + else if (auto overloadedExpr2 = as<OverloadedExpr2>(funcExpr)) { for (auto item : overloadedExpr2->candidiateExprs) { AddOverloadCandidates(item, context); } } - else if (auto typeType = funcExprType.As<TypeType>()) + else if (auto typeType = as<TypeType>(funcExprType)) { // If none of the above cases matched, but we are // looking at a type, then I suppose we have @@ -7967,14 +7969,14 @@ namespace Slang // If the immediate parent is a generic, then we probably // want the declaration above that... - auto parentGenericDeclRef = parentDeclRef.As<GenericDecl>(); + auto parentGenericDeclRef = parentDeclRef.as<GenericDecl>(); if(parentGenericDeclRef) { parentDeclRef = parentGenericDeclRef.GetParent(); } // Depending on what the parent is, we may want to format things specially - if(auto aggTypeDeclRef = parentDeclRef.As<AggTypeDecl>()) + if(auto aggTypeDeclRef = parentDeclRef.as<AggTypeDecl>()) { formatDeclPath(sb, aggTypeDeclRef); sb << "."; @@ -7986,7 +7988,7 @@ namespace Slang // signature if( parentGenericDeclRef ) { - auto genSubst = declRef.substitutions.substitutions.As<GenericSubstitution>(); + auto genSubst = declRef.substitutions.substitutions.as<GenericSubstitution>(); SLANG_RELEASE_ASSERT(genSubst); SLANG_RELEASE_ASSERT(genSubst->genericDecl == parentGenericDeclRef.getDecl()); @@ -8004,7 +8006,7 @@ namespace Slang void formatDeclParams(StringBuilder& sb, DeclRef<Decl> declRef) { - if (auto funcDeclRef = declRef.As<CallableDecl>()) + if (auto funcDeclRef = declRef.as<CallableDecl>()) { // This is something callable, so we need to also print parameter types for overloading @@ -8023,20 +8025,20 @@ namespace Slang sb << ")"; } - else if(auto genericDeclRef = declRef.As<GenericDecl>()) + else if(auto genericDeclRef = declRef.as<GenericDecl>()) { sb << "<"; bool first = true; for (auto paramDeclRef : getMembers(genericDeclRef)) { - if(auto genericTypeParam = paramDeclRef.As<GenericTypeParamDecl>()) + if(auto genericTypeParam = paramDeclRef.as<GenericTypeParamDecl>()) { if (!first) sb << ", "; first = false; sb << getText(genericTypeParam.GetName()); } - else if(auto genericValParam = paramDeclRef.As<GenericValueParamDecl>()) + else if(auto genericValParam = paramDeclRef.as<GenericValueParamDecl>()) { if (!first) sb << ", "; first = false; @@ -8106,7 +8108,7 @@ namespace Slang bool shouldAddToCache = false; OperatorOverloadCacheKey key; TypeCheckingCache* typeCheckingCache = getSession()->getTypeCheckingCache(); - if (auto opExpr = expr->As<OperatorExpr>()) + if (auto opExpr = as<OperatorExpr>(expr)) { if (key.fromOperatorExpr(opExpr)) { @@ -8127,7 +8129,7 @@ namespace Slang auto funcExpr = expr->FunctionExpr; auto funcExprType = funcExpr->type; - // If we are trying to apply an erroroneous expression, then just bail out now. + // If we are trying to apply an erroneous expression, then just bail out now. if(IsErrorExpr(funcExpr)) { return CreateErrorExpr(expr); @@ -8148,15 +8150,15 @@ namespace Slang context.args = expr->Arguments.Buffer(); context.loc = expr->loc; - if (auto funcMemberExpr = funcExpr.As<MemberExpr>()) + if (auto funcMemberExpr = as<MemberExpr>(funcExpr)) { context.baseExpr = funcMemberExpr->BaseExpression; } - else if (auto funcOverloadExpr = funcExpr.As<OverloadedExpr>()) + else if (auto funcOverloadExpr = as<OverloadedExpr>(funcExpr)) { context.baseExpr = funcOverloadExpr->base; } - else if (auto funcOverloadExpr2 = funcExpr.As<OverloadedExpr2>()) + else if (auto funcOverloadExpr2 = as<OverloadedExpr2>(funcExpr)) { context.baseExpr = funcOverloadExpr2->base; } @@ -8185,16 +8187,16 @@ namespace Slang } Name* funcName = nullptr; - if (auto baseVar = funcExpr.As<VarExpr>()) + if (auto baseVar = as<VarExpr>(funcExpr)) funcName = baseVar->name; - else if(auto baseMemberRef = funcExpr.As<MemberExpr>()) + else if(auto baseMemberRef = as<MemberExpr>(funcExpr)) funcName = baseMemberRef->name; String argsList = getCallSignatureString(context); if (context.bestCandidates[0].status != OverloadCandidate::Status::Appicable) { - // There were multple equally-good candidates, but none actually usable. + // There were multiple equally-good candidates, but none actually usable. // We will construct a diagnostic message to help out. if (funcName) @@ -8280,7 +8282,7 @@ namespace Slang LookupResultItem baseItem, OverloadResolveContext& context) { - if (auto genericDeclRef = baseItem.declRef.As<GenericDecl>()) + if (auto genericDeclRef = baseItem.declRef.as<GenericDecl>()) { checkDecl(genericDeclRef.getDecl()); @@ -8297,12 +8299,12 @@ namespace Slang RefPtr<Expr> baseExpr, OverloadResolveContext& context) { - if(auto baseDeclRefExpr = baseExpr.As<DeclRefExpr>()) + if(auto baseDeclRefExpr = as<DeclRefExpr>(baseExpr)) { auto declRef = baseDeclRefExpr->declRef; AddGenericOverloadCandidate(LookupResultItem(declRef), context); } - else if (auto overloadedExpr = baseExpr.As<OverloadedExpr>()) + else if (auto overloadedExpr = as<OverloadedExpr>(baseExpr)) { // We are referring to a bunch of declarations, each of which might be generic LookupResult result; @@ -8364,7 +8366,7 @@ namespace Slang // Things were ambiguous. if (context.bestCandidates[0].status != OverloadCandidate::Status::Appicable) { - // There were multple equally-good candidates, but none actually usable. + // There were multiple equally-good candidates, but none actually usable. // We will construct a diagnostic message to help out. // TODO(tfoley): print a reasonable message here... @@ -8450,13 +8452,13 @@ namespace Slang if (auto invoke = dynamic_cast<InvokeExpr*>(rs.Ptr())) { // if this is still an invoke expression, test arguments passed to inout/out parameter are LValues - if(auto funcType = invoke->FunctionExpr->type->As<FuncType>()) + if(auto funcType = as<FuncType>(invoke->FunctionExpr->type)) { UInt paramCount = funcType->getParamCount(); for (UInt pp = 0; pp < paramCount; ++pp) { auto paramType = funcType->getParamType(pp); - if (paramType->As<OutTypeBase>() || paramType->As<RefType>()) + if (as<OutTypeBase>(paramType) || as<RefType>(paramType)) { // `out`, `inout`, and `ref` parameters currently require // an *exact* match on the type of the argument. @@ -8475,7 +8477,7 @@ namespace Slang Diagnostics::argumentExpectedLValue, pp); - if( auto implicitCastExpr = argExpr.As<ImplicitCastExpr>() ) + if( auto implicitCastExpr = as<ImplicitCastExpr>(argExpr) ) { getSink()->diagnose( argExpr, @@ -8614,7 +8616,7 @@ namespace Slang for (;;) { auto baseType = expr->type; - if (auto pointerLikeType = baseType->As<PointerLikeType>()) + if (auto pointerLikeType = as<PointerLikeType>(baseType)) { auto elementType = QualType(pointerLikeType->elementType); elementType.IsLeftValue = baseType.IsLeftValue; @@ -8731,7 +8733,7 @@ namespace Slang RefPtr<Type> baseElementType, RefPtr<IntVal> baseElementCount) { - if (auto constantElementCount = baseElementCount.As<ConstantIntVal>()) + if (auto constantElementCount = as<ConstantIntVal>(baseElementCount)) { return CheckSwizzleExpr(memberRefExpr, baseElementType, constantElementCount->value); } @@ -8783,14 +8785,14 @@ namespace Slang // members via extension, for vector or scalar types. // // TODO: Matrix swizzles probably need to be handled at some point. - if (auto baseVecType = baseType->AsVectorType()) + if (auto baseVecType = as<VectorExpressionType>(baseType)) { return CheckSwizzleExpr( expr, baseVecType->elementType, baseVecType->elementCount); } - else if(auto baseScalarType = baseType->AsBasicType()) + else if(auto baseScalarType = as<BasicExpressionType>(baseType)) { // Treat scalar like a 1-element vector when swizzling return CheckSwizzleExpr( @@ -8798,7 +8800,7 @@ namespace Slang baseScalarType, 1); } - else if(auto typeType = baseType->As<TypeType>()) + else if(auto typeType = as<TypeType>(baseType)) { // We are looking up a member inside a type. // We want to be careful here because we should only find members @@ -8808,7 +8810,7 @@ namespace Slang // We need to fix that. auto type = typeType->type; - if (type->As<ErrorType>()) + if (as<ErrorType>(type)) { return CreateErrorExpr(expr); } @@ -8837,9 +8839,9 @@ namespace Slang // The biggest challenge there is that we'd need to arrange // to generate "dispatcher" functions that could be used // to implement that function, in the case where we are - // making a static reference to some kind of polymoprhic declaration. + // making a static reference to some kind of polymorphic declaration. // - // (Also, static refernces to fields/properties would get even + // (Also, static references to fields/properties would get even // harder, because you'd have to know whether a getter/setter/ref-er // is needed). // @@ -8914,7 +8916,7 @@ namespace Slang expr->BaseExpression, expr->loc); } - else if (baseType->As<ErrorType>()) + else if (as<ErrorType>(baseType)) { return CreateErrorExpr(expr); } @@ -9039,14 +9041,14 @@ namespace Slang { auto containerDecl = scope->containerDecl; - if( auto funcDeclBase = containerDecl->As<FunctionDeclBase>() ) + if( auto funcDeclBase = as<FunctionDeclBase>(containerDecl) ) { if( funcDeclBase->HasModifier<MutatingAttribute>() ) { expr->type.IsLeftValue = true; } } - else if (auto aggTypeDecl = containerDecl->As<AggTypeDecl>()) + else if (auto aggTypeDecl = as<AggTypeDecl>(containerDecl)) { checkDecl(aggTypeDecl); @@ -9058,7 +9060,7 @@ namespace Slang makeDeclRef(aggTypeDecl)); return expr; } - else if (auto extensionDecl = containerDecl->As<ExtensionDecl>()) + else if (auto extensionDecl = as<ExtensionDecl>(containerDecl)) { checkDecl(extensionDecl); @@ -9238,7 +9240,7 @@ namespace Slang } Expr* expr = attr->args[0]; - StringLiteralExpr* stringLit = expr->As<StringLiteralExpr>(); + StringLiteralExpr* stringLit = as<StringLiteralExpr>(expr); if (!stringLit) { @@ -9415,7 +9417,7 @@ namespace Slang // as a (top-level) argument for a generic type parameter, so that we // can check for them here and cache them on the entry point request. // - if( auto taggedUnionType = type->As<TaggedUnionType>() ) + if( auto taggedUnionType = as<TaggedUnionType>(type) ) { entryPoint->taggedUnionTypes.Add(taggedUnionType); } @@ -9426,7 +9428,7 @@ namespace Slang // validate global type arguments only when we are generating code if ((entryPoint->compileRequest->compileFlags & SLANG_COMPILE_FLAG_NO_CODEGEN) == 0) { - // check that user-provioded type arguments conforms to the generic type + // check that user-provided type arguments conforms to the generic type // parameter declaration of this translation unit // collect global generic parameters from all imported modules @@ -9491,10 +9493,10 @@ namespace Slang // As a quick sanity check, see if the argument that is being supplied for a parameter // is just the parameter itself, because this should always be an error: // - if( auto argDeclRefType = globalGenericArg->As<DeclRefType>() ) + if( auto argDeclRefType = as<DeclRefType>(globalGenericArg) ) { auto argDeclRef = argDeclRefType->declRef; - if(auto argGenericParamDeclRef = argDeclRef.As<GlobalGenericParamDecl>()) + if(auto argGenericParamDeclRef = argDeclRef.as<GlobalGenericParamDecl>()) { if(argGenericParamDeclRef.getDecl() == globalGenericParam) { @@ -9639,12 +9641,12 @@ namespace Slang for( auto globalDecl : translationUnit->SyntaxNode->Members ) { auto maybeFuncDecl = globalDecl; - if( auto genericDecl = maybeFuncDecl->As<GenericDecl>() ) + if( auto genericDecl = as<GenericDecl>(maybeFuncDecl) ) { maybeFuncDecl = genericDecl->inner; } - auto funcDecl = maybeFuncDecl->As<FuncDecl>(); + auto funcDecl = as<FuncDecl>(maybeFuncDecl); if(!funcDecl) continue; @@ -9718,7 +9720,7 @@ namespace Slang // We need to insert an appropriate type for the expression, based on // what we found. - if (auto varDeclRef = declRef.As<VarDeclBase>()) + if (auto varDeclRef = declRef.as<VarDeclBase>()) { QualType qualType; qualType.type = GetType(varDeclRef); @@ -9740,16 +9742,16 @@ namespace Slang isLValue = false; // Variables declared with `let` are always immutable. - if(varDeclRef.As<LetDecl>()) + if(varDeclRef.as<LetDecl>()) isLValue = false; // Generic value parameters are always immutable - if(varDeclRef.As<GenericValueParamDecl>()) + if(varDeclRef.as<GenericValueParamDecl>()) isLValue = false; // Function parameters declared in the "modern" style // are immutable unless they have an `out` or `inout` modifier. - if( varDeclRef.As<ModernParamDecl>() ) + if( varDeclRef.as<ModernParamDecl>() ) { // Note: the `inout` modifier AST class inherits from // the class for the `out` modifier so that we can @@ -9764,43 +9766,43 @@ namespace Slang qualType.IsLeftValue = isLValue; return qualType; } - else if( auto enumCaseDeclRef = declRef.As<EnumCaseDecl>() ) + else if( auto enumCaseDeclRef = declRef.as<EnumCaseDecl>() ) { QualType qualType; qualType.type = getType(enumCaseDeclRef); qualType.IsLeftValue = false; return qualType; } - else if (auto typeAliasDeclRef = declRef.As<TypeDefDecl>()) + else if (auto typeAliasDeclRef = declRef.as<TypeDefDecl>()) { auto type = getNamedType(session, typeAliasDeclRef); *outTypeResult = type; return QualType(getTypeType(type)); } - else if (auto aggTypeDeclRef = declRef.As<AggTypeDecl>()) + else if (auto aggTypeDeclRef = declRef.as<AggTypeDecl>()) { auto type = DeclRefType::Create(session, aggTypeDeclRef); *outTypeResult = type; return QualType(getTypeType(type)); } - else if (auto simpleTypeDeclRef = declRef.As<SimpleTypeDecl>()) + else if (auto simpleTypeDeclRef = declRef.as<SimpleTypeDecl>()) { auto type = DeclRefType::Create(session, simpleTypeDeclRef); *outTypeResult = type; return QualType(getTypeType(type)); } - else if (auto genericDeclRef = declRef.As<GenericDecl>()) + else if (auto genericDeclRef = declRef.as<GenericDecl>()) { auto type = getGenericDeclRefType(session, genericDeclRef); *outTypeResult = type; return QualType(getTypeType(type)); } - else if (auto funcDeclRef = declRef.As<CallableDecl>()) + else if (auto funcDeclRef = declRef.as<CallableDecl>()) { auto type = getFuncType(session, funcDeclRef); return QualType(type); } - else if (auto constraintDeclRef = declRef.As<TypeConstraintDecl>()) + else if (auto constraintDeclRef = declRef.as<TypeConstraintDecl>()) { // When we access a constraint or an inheritance decl (as a member), // we are conceptually performing a "cast" to the given super-type, @@ -9845,23 +9847,23 @@ namespace Slang for( auto mm : genericDecl->Members ) { - if( auto genericTypeParamDecl = mm.As<GenericTypeParamDecl>() ) + if( auto genericTypeParamDecl = as<GenericTypeParamDecl>(mm) ) { - genericSubst->args.Add(DeclRefType::Create(session, DeclRef<Decl>(genericTypeParamDecl.Ptr(), outerSubst))); + genericSubst->args.Add(DeclRefType::Create(session, DeclRef<Decl>(genericTypeParamDecl, outerSubst))); } - else if( auto genericValueParamDecl = mm.As<GenericValueParamDecl>() ) + else if( auto genericValueParamDecl = as<GenericValueParamDecl>(mm) ) { - genericSubst->args.Add(new GenericParamIntVal(DeclRef<GenericValueParamDecl>(genericValueParamDecl.Ptr(), outerSubst))); + genericSubst->args.Add(new GenericParamIntVal(DeclRef<GenericValueParamDecl>(genericValueParamDecl, outerSubst))); } } // create default substitution arguments for constraints for (auto mm : genericDecl->Members) { - if (auto genericTypeConstraintDecl = mm.As<GenericTypeConstraintDecl>()) + if (auto genericTypeConstraintDecl = as<GenericTypeConstraintDecl>(mm)) { RefPtr<DeclaredSubtypeWitness> witness = new DeclaredSubtypeWitness(); - witness->declRef = DeclRef<Decl>(genericTypeConstraintDecl.Ptr(), outerSubst); + witness->declRef = DeclRef<Decl>(genericTypeConstraintDecl, outerSubst); witness->sub = genericTypeConstraintDecl->sub.type; witness->sup = genericTypeConstraintDecl->sup.type; genericSubst->args.Add(witness); diff --git a/source/slang/emit.cpp b/source/slang/emit.cpp index 887a62974..224fa3a28 100644 --- a/source/slang/emit.cpp +++ b/source/slang/emit.cpp @@ -3,6 +3,7 @@ #include "../core/slang-writer.h" #include "ir-dce.h" +#include "ir-entry-point-uniforms.h" #include "ir-glsl-legalize.h" #include "ir-insts.h" #include "ir-link.h" @@ -4237,11 +4238,11 @@ struct EmitVisitor if(auto layoutDecoration = inst->findDecoration<IRLayoutDecoration>()) { auto layout = layoutDecoration->getLayout(); - if(auto varLayout = layout->dynamicCast<VarLayout>()) + if(auto varLayout = dynamicCast<VarLayout>(layout)) { emitIRSemantics(ctx, varLayout); } - else if (auto entryPointLayout = layout->dynamicCast<EntryPointLayout>()) + else if (auto entryPointLayout = dynamicCast<EntryPointLayout>(layout)) { if(auto resultLayout = entryPointLayout->resultLayout) { @@ -4603,7 +4604,7 @@ struct EmitVisitor Expr* expr = attrib->args[0]; - auto stringLitExpr = expr->As<StringLiteralExpr>(); + auto stringLitExpr = as<StringLiteralExpr>(expr); if (!stringLitExpr) { SLANG_DIAGNOSE_UNEXPECTED(getSink(), entryPoint->loc, "Attribute parameter expecting to be a string "); @@ -4630,7 +4631,7 @@ struct EmitVisitor Expr* expr = attrib->args[0]; - auto intLitExpr = expr->As<IntegerLiteralExpr>(); + auto intLitExpr = as<IntegerLiteralExpr>(expr); if (!intLitExpr) { SLANG_DIAGNOSE_UNEXPECTED(getSink(), entryPoint->loc, "Attribute expects an int"); @@ -4840,39 +4841,39 @@ struct EmitVisitor { if(auto inputPrimitiveTypeModifier = pp->FindModifier<HLSLGeometryShaderInputPrimitiveTypeModifier>()) { - if(inputPrimitiveTypeModifier->As<HLSLTriangleModifier>()) + if(as<HLSLTriangleModifier>(inputPrimitiveTypeModifier)) { emit("layout(triangles) in;\n"); } - else if(inputPrimitiveTypeModifier->As<HLSLLineModifier>()) + else if(as<HLSLLineModifier>(inputPrimitiveTypeModifier)) { emit("layout(lines) in;\n"); } - else if(inputPrimitiveTypeModifier->As<HLSLLineAdjModifier>()) + else if(as<HLSLLineAdjModifier>(inputPrimitiveTypeModifier)) { emit("layout(lines_adjacency) in;\n"); } - else if(inputPrimitiveTypeModifier->As<HLSLPointModifier>()) + else if(as<HLSLPointModifier>(inputPrimitiveTypeModifier)) { emit("layout(points) in;\n"); } - else if(inputPrimitiveTypeModifier->As<HLSLTriangleAdjModifier>()) + else if(as<HLSLTriangleAdjModifier>(inputPrimitiveTypeModifier)) { emit("layout(triangles_adjacency) in;\n"); } } - if(auto outputStreamType = pp->type->As<HLSLStreamOutputType>()) + if(auto outputStreamType = as<HLSLStreamOutputType>(pp->type)) { - if(outputStreamType->As<HLSLTriangleStreamType>()) + if(as<HLSLTriangleStreamType>(outputStreamType)) { emit("layout(triangle_strip) out;\n"); } - else if(outputStreamType->As<HLSLLineStreamType>()) + else if(as<HLSLLineStreamType>(outputStreamType)) { emit("layout(line_strip) out;\n"); } - else if(outputStreamType->As<HLSLPointStreamType>()) + else if(as<HLSLPointStreamType>(outputStreamType)) { emit("layout(points) out;\n"); } @@ -5158,7 +5159,7 @@ struct EmitVisitor { if( auto layoutDecoration = func->findDecoration<IRLayoutDecoration>() ) { - return layoutDecoration->getLayout()->dynamicCast<EntryPointLayout>(); + return dynamicCast<EntryPointLayout>(layoutDecoration->getLayout()); } return nullptr; } @@ -5167,7 +5168,7 @@ struct EmitVisitor { if (auto layoutDecoration = func->findDecoration<IRLayoutDecoration>()) { - if (auto entryPointLayout = layoutDecoration->getLayout()->dynamicCast<EntryPointLayout>()) + if (auto entryPointLayout = dynamicCast<EntryPointLayout>(layoutDecoration->getLayout())) { return entryPointLayout; } @@ -5291,10 +5292,10 @@ struct EmitVisitor // auto typeLayout = layout->typeLayout; - while(auto arrayTypeLayout = typeLayout.As<ArrayTypeLayout>()) + while(auto arrayTypeLayout = as<ArrayTypeLayout>(typeLayout)) typeLayout = arrayTypeLayout->elementTypeLayout; - if (auto matrixTypeLayout = typeLayout.As<MatrixTypeLayout>()) + if (auto matrixTypeLayout = typeLayout.as<MatrixTypeLayout>()) { auto target = ctx->shared->target; @@ -5707,7 +5708,7 @@ struct EmitVisitor EmitVarChain elementChain = blockChain; auto typeLayout = varLayout->typeLayout; - if( auto parameterGroupTypeLayout = typeLayout.As<ParameterGroupTypeLayout>() ) + if( auto parameterGroupTypeLayout = as<ParameterGroupTypeLayout>(typeLayout) ) { containerChain = EmitVarChain(parameterGroupTypeLayout->containerVarLayout, &blockChain); elementChain = EmitVarChain(parameterGroupTypeLayout->elementVarLayout, &blockChain); @@ -5797,7 +5798,7 @@ struct EmitVisitor EmitVarChain elementChain = blockChain; auto typeLayout = varLayout->typeLayout->unwrapArray(); - if( auto parameterGroupTypeLayout = typeLayout.As<ParameterGroupTypeLayout>() ) + if( auto parameterGroupTypeLayout = as<ParameterGroupTypeLayout>(typeLayout) ) { containerChain = EmitVarChain(parameterGroupTypeLayout->containerVarLayout, &blockChain); elementChain = EmitVarChain(parameterGroupTypeLayout->elementVarLayout, &blockChain); @@ -6511,47 +6512,34 @@ EntryPointLayout* findEntryPointLayout( return nullptr; } -// Given a layout computed for a whole program, find -// the corresponding layout to use when looking up -// variables at the global scope. -// -// It might be that the global scope was logically -// mapped to a constant buffer, so that we need -// to "unwrap" that declaration to get at the -// actual struct type inside. -StructTypeLayout* getGlobalStructLayout( - ProgramLayout* programLayout) + /// Given a layout computed for a scope, get the layout to use when lookup up variables. + /// + /// A scope (such as the global scope of a program) groups its + /// parameters into a pseudo-`struct` type for layout purposes, + /// and in some cases that type will in turn be wrapped in a + /// `ConstantBuffer` type to indicate that the parameters needed + /// an implicit constant buffer to be allocated. + /// + /// This function "unwraps" the type layout to find the structure + /// type layout that must be stored inside. + /// +StructTypeLayout* getScopeStructLayout( + ScopeLayout* scopeLayout) { - auto globalScopeLayout = programLayout->globalScopeLayout->typeLayout; - if( auto gs = globalScopeLayout.As<StructTypeLayout>() ) + auto scopeTypeLayout = scopeLayout->parametersLayout->typeLayout; + if( auto structTypeLayout = as<StructTypeLayout>(scopeTypeLayout) ) { - return gs.Ptr(); + return structTypeLayout; } - else if( auto globalConstantBufferLayout = globalScopeLayout.As<ParameterGroupTypeLayout>() ) + else if( auto constantBufferTypeLayout = as<ParameterGroupTypeLayout>(scopeTypeLayout) ) { - // TODO: the `cbuffer` case really needs to be emitted very - // carefully, but that is beyond the scope of what a simple rewriter - // can easily do (without semantic analysis, etc.). - // - // The crux of the problem is that we need to collect all the - // global-scope uniforms (but not declarations that don't involve - // uniform storage...) and put them in a single `cbuffer` declaration, - // so that we can give it an explicit location. The fields in that - // declaration might use various type declarations, so we'd really - // need to emit all the type declarations first, and that involves - // some large scale reorderings. - // - // For now we will punt and just emit the declarations normally, - // and hope that the global-scope block (`$Globals`) gets auto-assigned - // the same location that we manually asigned it. - - auto elementTypeLayout = globalConstantBufferLayout->offsetElementTypeLayout; - auto elementTypeStructLayout = elementTypeLayout.As<StructTypeLayout>(); + auto elementTypeLayout = constantBufferTypeLayout->offsetElementTypeLayout; + auto elementTypeStructLayout = as<StructTypeLayout>(elementTypeLayout); // We expect all constant buffers to contain `struct` types for now SLANG_RELEASE_ASSERT(elementTypeStructLayout); - return elementTypeStructLayout.Ptr(); + return elementTypeStructLayout; } else { @@ -6560,6 +6548,16 @@ StructTypeLayout* getGlobalStructLayout( } } + /// Given a layout computed for a program, get the layout to use when lookup up variables. + /// + /// This is just an alias of `getScopeStructLayout`. + /// +StructTypeLayout* getGlobalStructLayout( + ProgramLayout* programLayout) +{ + return getScopeStructLayout(programLayout); +} + void legalizeTypes( TypeLegalizationContext* context, IRModule* module); @@ -6657,6 +6655,18 @@ String emitEntryPoint( // un-specialized IR. dumpIRIfEnabled(compileRequest, irModule); + // Now that we've linked the IR code, any layout/binding + // information has been attached to shader parameters + // and entry points. Now we are safe to make transformations + // that might move code without worrying about losing + // the connection between a parameter and its layout. + // + // An easy transformation of this kind is to take uniform + // parameters of a shader entry point and move them into + // the global scope instead. + // + moveEntryPointUniformParamsToGlobalScope(irModule); + // Desguar any union types, since these will be illegal on // various targets. // diff --git a/source/slang/ir-entry-point-uniforms.cpp b/source/slang/ir-entry-point-uniforms.cpp new file mode 100644 index 000000000..64deec1c5 --- /dev/null +++ b/source/slang/ir-entry-point-uniforms.cpp @@ -0,0 +1,423 @@ +// ir-entry-point-uniforms.cpp +#include "ir-entry-point-uniforms.h" + +#include "ir.h" +#include "ir-insts.h" + +#include "mangle.h" + +namespace Slang +{ + + +// The transformation in this file will solve the problem of taking +// code like the following: +// +// float4 fragmentMain( +// uniform Texture2D t, +// uniform SamplerState s; +// uniform float4 c, +// float2 uv : UV) : SV_Target +// { +// return t.Sample(s, uv) + c; +// } +// +// and transforming into code like this: +// +// struct Params +// { +// Texture2D t; +// SamplerState s; +// float4 c; +// } +// ConstantBuffer<Params> params; +// +// float4 fragmentMain( +// float2 uv : UV) : SV_Target +// { +// return params.t.Sample(params.s, uv) + params.c; +// } +// +// As can be seen in this example, the `uniform` parameters +// declared as entry point parameters have been moved into +// a `struct` declaration that we then use to declare a global +// shader parameter that is a `ConstantBuffer`. We then +// rewrite references to those parameters to refer to the +// contents of the new constant buffer instead. +// +// We perform this transformation after the target-specific +// linking step, because that will have attached layout information +// to the entry point and its parameters. We need that layout +// information so that we can: +// +// * Identify which parameters are uniform vs. varying. +// * Have an appropriate layout to attached to the synthesized +// global shader parameter `params`. +// +// One additional wrinkle this pass has to deal with is that +// in the case where the shader doesn't have any "ordinary" +// uniform parameters like `c` (e.g., it only has resource/object +// parameters), we do *not* wrap the parameter `struct` in +// a `ConstantBuffer`. For example, suppose we have: +// +// float4 fragmentMain( +// uniform Texture2D t, +// uniform SamplerState s; +// float2 uv : UV) : SV_Target +// { +// return t.Sample(s, uv); +// } +// +// In this case the output of the transformation shold be: +// +// struct Params +// { +// Texture2D t; +// SamplerState s; +// } +// Params params; +// +// float4 fragmentMain( +// float2 uv : UV) : SV_Target +// { +// return params.t.Sample(params.s, uv) + params.c; +// } +// +// Note that this pass should always come before type legalization, +// which will take responsibility for turning a variable like +// `params` above into individual variables for the `t` and +// `s` fields. + +// The overall structure here is similar to many other IR passes. +// We define a "context" structure to encapsulate the pass. +// +struct MoveEntryPointUniformParametersToGlobalScope +{ + // We'll hang on to the module we are processing, + // so that we can refer to it when setting up `IRBuilder`s. + // + IRModule* module; + + // We will process a whole module by visiting all + // its global functions, looking for entry points. + // + void processModule() + { + // Note that we are only looking at true global-scope + // functions and not functions nested inside of + // IR generics. When using generic entry points, this + // pass should be run after the entry point(s) have + // been specialized to their generic type parameters. + + for( auto inst : module->getGlobalInsts() ) + { + // We are only interested in entry points. + // + // Every entry point must be a function. + // + auto func = as<IRFunc>(inst); + if( !func ) + continue; + + // Entry points will always have the `[entryPoint]` + // decoration to differentiate them from ordinary + // functions. + // + // TODO: we could make `IREntryPoint` a subclass of + // `IRFunc` if desired, to avoid having to attach + // an explicit decoration to identify them. + // + if( !func->findDecorationImpl(kIROp_EntryPointDecoration) ) + continue; + + // If we fine a candidate entry point, then we + // will process it. + // + processEntryPoint(func); + } + } + + void processEntryPoint(IRFunc* func) + { + // We expect all entry points to have explicit layout information attached. + // + // We will assert that we have the information we need, but try to be + // defensive and bail out in the failure case in release builds. + // + auto funcLayoutDecoration = func->findDecoration<IRLayoutDecoration>(); + SLANG_ASSERT(funcLayoutDecoration); + if(!funcLayoutDecoration) + return; + + auto entryPointLayout = dynamic_cast<EntryPointLayout*>(funcLayoutDecoration->getLayout()); + SLANG_ASSERT(entryPointLayout); + if(!entryPointLayout) + return; + + // The parameter layout for an entry point will either be a structure + // type layout, or a constant buffer (a case of parameter group) + // wrapped around such a structure. + // + // If we are in the latter case we will need to make sure to allocate + // an explicit IR constant buffer for that wrapper, + // + auto entryPointParamsLayout = entryPointLayout->parametersLayout; + bool needConstantBuffer = entryPointParamsLayout->typeLayout.as<ParameterGroupTypeLayout>() != nullptr; + + // We will set up an IR builder so that we are ready to generate code. + // + SharedIRBuilder sharedBuilderStorage; + auto sharedBuilder = &sharedBuilderStorage; + sharedBuilder->module = module; + sharedBuilder->session = module->getSession(); + + IRBuilder builderStorage; + auto builder = &builderStorage; + builder->sharedBuilder = sharedBuilder; + + // *If* the entry point has any uniform parameter then we want to create a + // structure type to house them, and a global shader parameter (either + // an instance of that type or a constant buffer). + // + // We only want to create these if actually needed, so we will declare + // them here and then initialize them on-demand. + // + IRStructType* paramStructType = nullptr; + IRGlobalParam* globalParam = nullptr; + + // We will be removing any uniform parameters we run into, so we + // need to iterate the parameter list carefully to deal with + // us modifying it along the way. + // + IRParam* nextParam = nullptr; + for( IRParam* param = func->getFirstParam(); param; param = nextParam ) + { + nextParam = param->getNextParam(); + + // We expect all entry-point parameters to have layout information, + // but we will be defensive and skip parameters without the required + // information when we are in a release build. + // + auto layoutDecoration = param->findDecoration<IRLayoutDecoration>(); + SLANG_ASSERT(layoutDecoration); + if(!layoutDecoration) + continue; + auto paramLayout = dynamic_cast<VarLayout*>(layoutDecoration->getLayout()); + SLANG_ASSERT(paramLayout); + if(!paramLayout) + continue; + + // A parameter that has varying input/output behavior should be left alone, + // since this pass is only supposed to apply to uniform (non-varying) + // parameters. + // + if(isVaryingParameter(paramLayout)) + continue; + + // At this point we know that `param` is not a varying shader parameter, + // so that we want to turn it into an equivalent global shader parameter. + // + // If this is the first parameter we are running into, then we need + // to deal with creating the structure type and global shader + // parameter that our transformed entry point will use. + // + if( !paramStructType ) + { + // First we create the structure to hold the parameters. + // + builder->setInsertBefore(func); + paramStructType = builder->createStructType(); + + if( needConstantBuffer ) + { + // If we need a constant buffer, then the global + // shader parameter will be a `ConstantBuffer<paramStructType>` + // + auto constantBufferType = builder->getConstantBufferType(paramStructType); + globalParam = builder->createGlobalParam(constantBufferType); + } + else + { + // Otherwise, the global shader parameter is just + // an instance of `paramStructType`. + // + globalParam = builder->createGlobalParam(paramStructType); + } + + // No matter what, the global shader parameter should have the layout + // information from the entry point attached to it, so that the + // contained parameters will end up in the right place(s). + // + builder->addLayoutDecoration(globalParam, entryPointParamsLayout); + } + + // Now that we've ensured the global `struct` type and shader paramter + // exist, we need to add a field to the `struct` to represent the + // current parameter. + // + + auto paramType = param->getFullType(); + + builder->setInsertBefore(paramStructType); + auto paramFieldKey = builder->createStructKey(); + auto paramField = builder->createStructField(paramStructType, paramFieldKey, paramType); + SLANG_UNUSED(paramField); + + // We will transfer all decorations on the parameter over to the key + // so that they can affect downstream emit logic. + // + // TODO: We should double-check whether any of the decorations should + // be moved to the *field* instead. + // + param->transferDecorationsTo(paramFieldKey); + + // There is a bit of a hacky issue, where downstream passes (notably + // type legalization) require the field keys for `struct` types to + // have mangled names, because those mangled names will be used to + // lookup field layout information inside of the layout information + // for the `struct` type. + // + // TODO: We should fix that design choice in how layout information + // is stored, to avoid the reliance on name strings. + // + builder->addExportDecoration(paramFieldKey, getMangledName(paramLayout->varDecl).getUnownedSlice()); + + // At this point we want to eliminate the original entry point + // parameter, in favor of the `struct` field we declared. + // That required replacing any uses of the parameter with + // appropriate code to pull out the field. + // + // We *could* extract the field at the start of the shader + // and then do a `replaceAllUsesWith` to propragate it + // down, but in practice we expect that it is better for + // performance to "rematerialize" the value of a shader + // parameter as close to where it is used as possible. + // + // We are therefore going to replace the uses one at a time. + // + while(auto use = param->firstUse ) + { + // Given a `use` of the paramter, we will insert + // the replacement code right before the instruction + // that is doing the using. + // + builder->setInsertBefore(use->getUser()); + + // The way to extract the field that corresponds + // to the parameter depends on whether or not + // we generated a constant buffer. + // + IRInst* fieldVal = nullptr; + if( needConstantBuffer ) + { + // A constant buffer behaves like a pointer + // at the IR level, so we first do a pointer + // offset operation to compute what amounts + // to `&cb->field`, and then load from that address. + // + auto fieldAddress = builder->emitFieldAddress( + builder->getPtrType(paramType), + globalParam, + paramFieldKey); + fieldVal = builder->emitLoad(fieldAddress); + } + else + { + // In the ordinary struct case, the parameter + // has an ordinary `struct` type (not a pointer), + // so we just extract the field directly. + // + fieldVal = builder->emitFieldExtract( + paramType, + globalParam, + paramFieldKey); + } + + // We replace the value used at this use site, which + // will have a side effect of making `use` no longer + // be on the list of uses for `param`, so that when + // we get back to the top of the loop the list of + // uses will be shorter. + // + use->set(fieldVal); + } + + // Once we've replaced all the uses of `param`, we + // can go ahead and remove it completely. + // + param->removeAndDeallocate(); + } + } + + // We need to be able to determine if a parameter is logically + // a "varying" parameter based on its layout. + // + bool isVaryingParameter(VarLayout* layout) + { + // If *any* of the resources consumed by the parameter + // is a varying resource kind (e.g., varying input) then + // we consider the whole parameter to be varying. + // + // This is reasonable because there is no way to declare + // a parameter that mixes varying and non-varying fields. + // + for( auto resInfo : layout->resourceInfos ) + { + if(isVaryingResourceKind(resInfo.kind)) + return true; + } + + // Varying parameters with "system value" semantics currently show up as + // consuming no resources, so we need to special-case that here. + // + // Note: an empty `struct` parameter would also show up the same way, but + // we should eliminate any such parameters later on during type legalization. + // + if(layout->resourceInfos.Count() == 0) + return true; + + // if none of the above tests determined that the + // parameter was varying, then we can safely consider + // it to be non-varying (uniform): + return false; + } + + // In order to determine whether a parameter is varying based on its + // layout, we need to know which resource kinds represent varying + // shader parameters. + // + bool isVaryingResourceKind(LayoutResourceKind kind) + { + switch( kind ) + { + default: + return false; + + // Note: The set of cases that are considered + // varying here would need to be extended if we + // add more fine-grained resource kinds (e.g., + // if we ever add an explicit resource kind + // for geometry shader output streams). + // + // Ordinary varying input/output: + case LayoutResourceKind::VaryingInput: + case LayoutResourceKind::VaryingOutput: + // + // Ray-tracing shader input/output: + case LayoutResourceKind::CallablePayload: + case LayoutResourceKind::HitAttributes: + case LayoutResourceKind::RayPayload: + return true; + } + } +}; + +void moveEntryPointUniformParamsToGlobalScope( + IRModule* module) +{ + MoveEntryPointUniformParametersToGlobalScope context; + context.module = module; + context.processModule(); +} + +} diff --git a/source/slang/ir-entry-point-uniforms.h b/source/slang/ir-entry-point-uniforms.h new file mode 100644 index 000000000..5fcfab167 --- /dev/null +++ b/source/slang/ir-entry-point-uniforms.h @@ -0,0 +1,12 @@ +// ir-entry-point-uniform.h +#pragma once + +namespace Slang +{ +struct IRModule; + + /// Move any uniform parameters of entry points to the global scope instead. +void moveEntryPointUniformParamsToGlobalScope( + IRModule* module); + +} diff --git a/source/slang/ir-glsl-legalize.cpp b/source/slang/ir-glsl-legalize.cpp index d2b696c5b..7dc88a0fa 100644 --- a/source/slang/ir-glsl-legalize.cpp +++ b/source/slang/ir-glsl-legalize.cpp @@ -744,7 +744,7 @@ ScalarizedVal extractField( case ScalarizedVal::Flavor::tuple: { - auto tupleVal = val.impl.As<ScalarizedTupleValImpl>(); + auto tupleVal = as<ScalarizedTupleValImpl>(val.impl); return tupleVal->elements[fieldIndex].val; } @@ -821,7 +821,7 @@ void assign( // We are assigning from a tuple to a destination // that is not a tuple. We will perform assignment // element-by-element. - auto rightTupleVal = right.impl.As<ScalarizedTupleValImpl>(); + auto rightTupleVal = as<ScalarizedTupleValImpl>(right.impl); UInt elementCount = rightTupleVal->elements.Count(); for( UInt ee = 0; ee < elementCount; ++ee ) @@ -847,7 +847,7 @@ void assign( { // We have a tuple, so we are going to need to try and assign // to each of its constituent fields. - auto leftTupleVal = left.impl.As<ScalarizedTupleValImpl>(); + auto leftTupleVal = as<ScalarizedTupleValImpl>(left.impl); UInt elementCount = leftTupleVal->elements.Count(); for( UInt ee = 0; ee < elementCount; ++ee ) @@ -869,7 +869,7 @@ void assign( // // In this case we are converting to the actual type of the GLSL variable, // from the "pretend" type that it had in the IR before. - auto typeAdapter = left.impl.As<ScalarizedTypeAdapterValImpl>(); + auto typeAdapter = as<ScalarizedTypeAdapterValImpl>(left.impl); auto adaptedRight = adaptType(builder, right, typeAdapter->actualType, typeAdapter->pretendType); assign(builder, typeAdapter->val, adaptedRight); } @@ -905,7 +905,7 @@ ScalarizedVal getSubscriptVal( case ScalarizedVal::Flavor::tuple: { - auto inputTuple = val.impl.As<ScalarizedTupleValImpl>(); + auto inputTuple = val.impl.as<ScalarizedTupleValImpl>(); RefPtr<ScalarizedTupleValImpl> resultTuple = new ScalarizedTupleValImpl(); resultTuple->type = elementType; @@ -967,7 +967,7 @@ IRInst* materializeTupleValue( IRBuilder* builder, ScalarizedVal val) { - auto tupleVal = val.impl.As<ScalarizedTupleValImpl>(); + auto tupleVal = val.impl.as<ScalarizedTupleValImpl>(); SLANG_ASSERT(tupleVal); UInt elementCount = tupleVal->elements.Count(); @@ -1044,7 +1044,7 @@ IRInst* materializeValue( case ScalarizedVal::Flavor::tuple: { - auto tupleVal = val.impl.As<ScalarizedTupleValImpl>(); + //auto tupleVal = as<ScalarizedTupleValImpl>(val.impl); return materializeTupleValue(builder, val); } break; @@ -1055,7 +1055,7 @@ IRInst* materializeValue( // doesn't match the type it pretends to have. To make this // work we need to adapt the type from its actual type over // to its pretend type. - auto typeAdapter = val.impl.As<ScalarizedTypeAdapterValImpl>(); + auto typeAdapter = as<ScalarizedTypeAdapterValImpl>(val.impl); auto adapted = adaptType(builder, typeAdapter->val, typeAdapter->pretendType, typeAdapter->actualType); return materializeValue(builder, adapted); } @@ -1516,18 +1516,16 @@ void legalizeEntryPointForGLSL( // to be at the start of the "ordinary" instructions in the block: builder.setInsertBefore(firstBlock->getFirstOrdinaryInst()); - UInt paramCounter = 0; for( auto pp = firstBlock->getFirstParam(); pp; pp = pp->getNextParam() ) { - UInt paramIndex = paramCounter++; - - // We assume that the entry-point layout includes information - // on each parameter, and that these arrays are kept aligned. - // Note that this means that any transformations that mess - // with function signatures will need to also update layout info... + // We assume that the entry-point parameters will all have + // layout information attached to them, which is kept up-to-date + // by any transformations affecting the parameter list. // - SLANG_ASSERT(entryPointLayout->fields.Count() > paramIndex); - auto paramLayout = entryPointLayout->fields[paramIndex]; + auto paramLayoutDecoration = pp->findDecoration<IRLayoutDecoration>(); + SLANG_ASSERT(paramLayoutDecoration); + auto paramLayout = dynamic_cast<VarLayout*>(paramLayoutDecoration->getLayout()); + SLANG_ASSERT(paramLayout); legalizeEntryPointParameterForGLSL( &context, diff --git a/source/slang/ir-insts.h b/source/slang/ir-insts.h index 8662569ba..6b12612ef 100644 --- a/source/slang/ir-insts.h +++ b/source/slang/ir-insts.h @@ -759,6 +759,9 @@ struct IRBuilder return getFuncType(paramTypes.Count(), paramTypes.Buffer(), resultType); } + IRConstantBufferType* getConstantBufferType( + IRType* elementType); + IRConstExprRate* getConstExprRate(); IRGroupSharedRate* getGroupSharedRate(); diff --git a/source/slang/ir-legalize-types.cpp b/source/slang/ir-legalize-types.cpp index 141799a06..eb22da967 100644 --- a/source/slang/ir-legalize-types.cpp +++ b/source/slang/ir-legalize-types.cpp @@ -71,7 +71,7 @@ LegalVal LegalVal::implicitDeref(LegalVal const& val) LegalVal LegalVal::getImplicitDeref() { SLANG_ASSERT(flavor == Flavor::implicitDeref); - return obj.As<ImplicitDerefVal>()->val; + return as<ImplicitDerefVal>(obj)->val; } @@ -1017,7 +1017,7 @@ static LegalVal legalizeInst( RefPtr<VarLayout> findVarLayout(IRInst* value) { if (auto layoutDecoration = value->findDecoration<IRLayoutDecoration>()) - return layoutDecoration->getLayout()->dynamicCast<VarLayout>(); + return dynamicCast<VarLayout>(layoutDecoration->getLayout()); return nullptr; } diff --git a/source/slang/ir-link.cpp b/source/slang/ir-link.cpp index 610658d51..dba4fc2d1 100644 --- a/source/slang/ir-link.cpp +++ b/source/slang/ir-link.cpp @@ -8,9 +8,6 @@ namespace Slang { -StructTypeLayout* getGlobalStructLayout( - ProgramLayout* programLayout); - // Needed for lookup up entry-point layouts. // // TODO: maybe arrange so that codegen is driven from the layout layer @@ -721,14 +718,15 @@ IRFunc* specializeIRForEntryPoint( // than having to look it up on the original entry-point layout. if( auto firstBlock = clonedFunc->getFirstBlock() ) { - UInt paramLayoutCount = entryPointLayout->fields.Count(); + auto paramsStructLayout = getScopeStructLayout(entryPointLayout); + UInt paramLayoutCount = paramsStructLayout->fields.Count(); UInt paramCounter = 0; for( auto pp = firstBlock->getFirstParam(); pp; pp = pp->getNextParam() ) { UInt paramIndex = paramCounter++; if( paramIndex < paramLayoutCount ) { - auto paramLayout = entryPointLayout->fields[paramIndex]; + auto paramLayout = paramsStructLayout->fields[paramIndex]; context->builder->addLayoutDecoration( pp, paramLayout); @@ -1227,7 +1225,7 @@ IRSpecializationState* createIRSpecializationState( // Next, we want to optimize lookup for layout infromation // associated with global declarations, so that we can // look things up based on the IR values (using mangled names) - auto globalStructLayout = getGlobalStructLayout(newProgramLayout); + auto globalStructLayout = getScopeStructLayout(newProgramLayout); for (auto globalVarLayout : globalStructLayout->fields) { auto mangledName = getMangledName(globalVarLayout->varDecl); @@ -1235,6 +1233,10 @@ IRSpecializationState* createIRSpecializationState( } // for now, clone all unreferenced witness tables + // + // TODO: This step should *not* be needed with the current IR + // specialization approach, so we should consider removing it. + // for (auto sym :context->getSymbols()) { if (sym.Value->irGlobalValue->op == kIROp_WitnessTable) diff --git a/source/slang/ir.cpp b/source/slang/ir.cpp index f622804b2..0a5b8491c 100644 --- a/source/slang/ir.cpp +++ b/source/slang/ir.cpp @@ -1710,6 +1710,15 @@ namespace Slang (IRInst* const*) paramTypes); } + IRConstantBufferType* IRBuilder::getConstantBufferType(IRType* elementType) + { + IRInst* operands[] = { elementType }; + return (IRConstantBufferType*) getType( + kIROp_ConstantBufferType, + 1, + operands); + } + IRConstExprRate* IRBuilder::getConstExprRate() { return (IRConstExprRate*)getType(kIROp_ConstExprRate); diff --git a/source/slang/ir.h b/source/slang/ir.h index 343f5b79b..bbb68cdda 100644 --- a/source/slang/ir.h +++ b/source/slang/ir.h @@ -407,15 +407,35 @@ struct IRInst void _insertAt(IRInst* inPrev, IRInst* inNext, IRInst* inParent); }; -// `dynamic_cast` equivalent template<typename T> -T* as(IRInst* inst, T* /* */ = nullptr) +T* dynamicCast(IRInst* inst) { if (inst && T::isaImpl(inst->op)) - return (T*) inst; + return static_cast<T*>(inst); return nullptr; } +template<typename T> +const T* dynamicCast(const IRInst* inst) +{ + if (inst && T::isaImpl(inst->op)) + return static_cast<const T*>(inst); + return nullptr; +} + +// `dynamic_cast` equivalent (we just use dynamicCast) +template<typename T> +T* as(IRInst* inst) +{ + return dynamicCast<T>(inst); +} + +template<typename T> +const T* as(const IRInst* inst) +{ + return dynamicCast<T>(inst); +} + // `static_cast` equivalent, with debug validation template<typename T> T* cast(IRInst* inst, T* /* */ = nullptr) diff --git a/source/slang/legalize-types.h b/source/slang/legalize-types.h index 014df123f..d45642927 100644 --- a/source/slang/legalize-types.h +++ b/source/slang/legalize-types.h @@ -90,7 +90,7 @@ struct LegalType RefPtr<ImplicitDerefType> getImplicitDeref() const { SLANG_ASSERT(flavor == Flavor::implicitDeref); - return obj.As<ImplicitDerefType>(); + return obj.dynamicCast<ImplicitDerefType>(); } static LegalType tuple( @@ -99,7 +99,7 @@ struct LegalType RefPtr<TuplePseudoType> getTuple() const { SLANG_ASSERT(flavor == Flavor::tuple); - return obj.As<TuplePseudoType>(); + return obj.dynamicCast<TuplePseudoType>(); } static LegalType pair( @@ -113,7 +113,7 @@ struct LegalType RefPtr<PairPseudoType> getPair() const { SLANG_ASSERT(flavor == Flavor::pair); - return obj.As<PairPseudoType>(); + return obj.dynamicCast<PairPseudoType>(); } }; @@ -301,7 +301,7 @@ struct LegalVal RefPtr<TuplePseudoVal> getTuple() const { SLANG_ASSERT(flavor == Flavor::tuple); - return obj.As<TuplePseudoVal>(); + return obj.as<TuplePseudoVal>(); } static LegalVal implicitDeref(LegalVal const& val); @@ -316,7 +316,7 @@ struct LegalVal RefPtr<PairPseudoVal> getPair() const { SLANG_ASSERT(flavor == Flavor::pair); - return obj.As<PairPseudoVal>(); + return obj.as<PairPseudoVal>(); } }; diff --git a/source/slang/lookup.cpp b/source/slang/lookup.cpp index f74e11016..d9e63ff09 100644 --- a/source/slang/lookup.cpp +++ b/source/slang/lookup.cpp @@ -184,7 +184,7 @@ void DoMemberLookupImpl( // If the type was pointer-like, then dereference it // automatically here. - if (auto pointerLikeType = baseType->As<PointerLikeType>()) + if (auto pointerLikeType = as<PointerLikeType>(baseType)) { // Need to leave a breadcrumb to indicate that we // did an implicit dereference here @@ -200,9 +200,9 @@ void DoMemberLookupImpl( // Default case: no dereference needed - if (auto baseDeclRefType = baseType->As<DeclRefType>()) + if (auto baseDeclRefType = as<DeclRefType>(baseType)) { - if (auto baseAggTypeDeclRef = baseDeclRefType->declRef.As<AggTypeDecl>()) + if (auto baseAggTypeDeclRef = baseDeclRefType->declRef.as<AggTypeDecl>()) { DoLocalLookupImpl( session, @@ -239,14 +239,14 @@ DeclRef<Decl> maybeSpecializeInterfaceDeclRef( DeclRef<Decl> superTypeDeclRef, // The decl-ref we are going to perform lookup in DeclRef<TypeConstraintDecl> constraintDeclRef) // The type constraint that told us our type is a subtype { - if (auto superInterfaceDeclRef = superTypeDeclRef.As<InterfaceDecl>()) + if (auto superInterfaceDeclRef = superTypeDeclRef.as<InterfaceDecl>()) { // Create a subtype witness value to note the subtype relationship // that makes this specialization valid. // // Note: this is to ensure that we can specialize the subtype witness // later (e.g., by replacing a subtype witness that represents a generic - // constraint paraqmeter with the concrete generic arguments that + // constraint parameter with the concrete generic arguments that // are used at a particular call site to the generic). RefPtr<DeclaredSubtypeWitness> subtypeWitness = new DeclaredSubtypeWitness(); subtypeWitness->declRef = constraintDeclRef; @@ -272,9 +272,9 @@ RefPtr<Type> maybeSpecializeInterfaceDeclRef( RefPtr<Type> superType, // The type we are going to perform lookup in DeclRef<TypeConstraintDecl> constraintDeclRef) // The type constraint that told us our type is a subtype { - if (auto superDeclRefType = superType->As<DeclRefType>()) + if (auto superDeclRefType = as<DeclRefType>(superType)) { - if (auto superInterfaceDeclRef = superDeclRefType->declRef.As<InterfaceDecl>()) + if (auto superInterfaceDeclRef = superDeclRefType->declRef.as<InterfaceDecl>()) { auto specializedInterfaceDeclRef = maybeSpecializeInterfaceDeclRef( subType, @@ -356,7 +356,7 @@ void DoLocalLookupImpl( } // Consider lookup via extension - if( auto aggTypeDeclRef = containerDeclRef.As<AggTypeDecl>() ) + if( auto aggTypeDeclRef = containerDeclRef.as<AggTypeDecl>() ) { RefPtr<Type> type = DeclRefType::Create( session, @@ -388,23 +388,23 @@ void DoLocalLookupImpl( // // This code should be converted to do a type-based lookup // through declared bases for *any* aggregate type declaration. - // I think that logic is present in the type-bsed lookup path, but + // I think that logic is present in the type-based lookup path, but // it would be needed here for when doing lookup from inside an // aggregate declaration. // if we are looking at an extension, find the target decl that we are extending DeclRef<Decl> targetDeclRef = containerDeclRef; RefPtr<DeclRefType> targetDeclRefType; - if (auto extDeclRef = containerDeclRef.As<ExtensionDecl>()) + if (auto extDeclRef = containerDeclRef.as<ExtensionDecl>()) { - targetDeclRefType = extDeclRef.getDecl()->targetType->AsDeclRefType(); + targetDeclRefType = as<DeclRefType>(extDeclRef.getDecl()->targetType); SLANG_ASSERT(targetDeclRefType); int diff = 0; - targetDeclRef = targetDeclRefType->declRef.As<ContainerDecl>().SubstituteImpl(containerDeclRef.substitutions, &diff); + targetDeclRef = targetDeclRefType->declRef.as<ContainerDecl>().SubstituteImpl(containerDeclRef.substitutions, &diff); } // if we are looking inside an interface decl, try find in the interfaces it inherits from - bool isInterface = targetDeclRef.As<InterfaceDecl>() ? true : false; + bool isInterface = targetDeclRef.as<InterfaceDecl>() ? true : false; if (isInterface) { if(!targetDeclRefType) @@ -417,7 +417,7 @@ void DoLocalLookupImpl( { checkDecl(request.semantics, inheritanceDeclRef.decl); - auto baseType = inheritanceDeclRef.getDecl()->base.type.As<DeclRefType>(); + auto baseType = inheritanceDeclRef.getDecl()->base.type.dynamicCast<DeclRefType>(); SLANG_ASSERT(baseType); int diff = 0; auto baseInterfaceDeclRef = baseType->declRef.SubstituteImpl(containerDeclRef.substitutions, &diff); @@ -428,7 +428,7 @@ void DoLocalLookupImpl( baseInterfaceDeclRef, inheritanceDeclRef); - DoLocalLookupImpl(session, name, baseInterfaceDeclRef.As<ContainerDecl>(), request, result, inBreadcrumbs); + DoLocalLookupImpl(session, name, baseInterfaceDeclRef.as<ContainerDecl>(), request, result, inBreadcrumbs); } } } @@ -447,7 +447,7 @@ void DoLookupImpl( for (;scope != endScope; scope = scope->parent) { // Note that we consider all "peer" scopes together, - // so that a hit in one of them does not proclude + // so that a hit in one of them does not preclude // also finding a hit in another for(auto link = scope; link; link = link->nextSibling) { @@ -457,7 +457,7 @@ void DoLookupImpl( continue; DeclRef<ContainerDecl> containerDeclRef = - DeclRef<Decl>(containerDecl, createDefaultSubstitutions(session, containerDecl)).As<ContainerDecl>(); + DeclRef<Decl>(containerDecl, createDefaultSubstitutions(session, containerDecl)).as<ContainerDecl>(); BreadcrumbInfo breadcrumb; BreadcrumbInfo* breadcrumbs = nullptr; @@ -470,7 +470,7 @@ void DoLookupImpl( // just `AggTypeDecl`, because we want to catch `extension` // declarations as well. // - if (auto aggTypeDeclRef = containerDeclRef.As<AggTypeDeclBase>()) + if (auto aggTypeDeclRef = containerDeclRef.as<AggTypeDeclBase>()) { breadcrumb.kind = LookupResultItem::Breadcrumb::Kind::This; breadcrumb.thisParameterMode = thisParameterMode; @@ -485,13 +485,13 @@ void DoLookupImpl( // if we are currently in an extension decl, perform local lookup // in the target decl we are extending - if (auto extDeclRef = containerDeclRef.As<ExtensionDecl>()) + if (auto extDeclRef = containerDeclRef.as<ExtensionDecl>()) { if (extDeclRef.getDecl()->targetType) { - if (auto targetDeclRef = extDeclRef.getDecl()->targetType->AsDeclRefType()) + if (auto targetDeclRef = as<DeclRefType>(extDeclRef.getDecl()->targetType)) { - if (auto aggDeclRef = targetDeclRef->declRef.As<AggTypeDecl>()) + if (auto aggDeclRef = targetDeclRef->declRef.as<AggTypeDecl>()) { containerDeclRef = extDeclRef.Substitute(aggDeclRef); } @@ -502,7 +502,7 @@ void DoLookupImpl( session, name, containerDeclRef, request, result, breadcrumbs); - if( auto funcDeclRef = containerDeclRef.As<FunctionDeclBase>() ) + if( auto funcDeclRef = containerDeclRef.as<FunctionDeclBase>() ) { if( funcDeclRef.getDecl()->HasModifier<MutatingAttribute>() ) { @@ -611,7 +611,7 @@ void lookUpThroughConstraint( constraintDeclRef); // We need to track the indirection we took in lookup, - // so that we can construct an approrpiate AST on the other + // so that we can construct an appropriate AST on the other // side that includes the "upcase" from sub-type to super-type. // BreadcrumbInfo breadcrumb; @@ -624,7 +624,7 @@ void lookUpThroughConstraint( // // TODO: The even simpler thing we need to worry about here is that if // there is ever a "diamond" relationship in the inheritance hierarchy, - // we might end up seeing the same interface via diffrent "paths" and + // we might end up seeing the same interface via different "paths" and // we wouldn't want that to lead to overload-resolution failure. // lookUpMemberImpl(session, semantics, name, superType, ioResult, &breadcrumb, mask); @@ -639,12 +639,12 @@ void lookUpMemberImpl( BreadcrumbInfo* inBreadcrumbs, LookupMask mask) { - if (auto declRefType = type->As<DeclRefType>()) + if (auto declRefType = as<DeclRefType>(type)) { auto declRef = declRefType->declRef; - if (declRef.As<AssocTypeDecl>() || declRef.As<GlobalGenericParamDecl>()) + if (declRef.as<AssocTypeDecl>() || declRef.as<GlobalGenericParamDecl>()) { - for (auto constraintDeclRef : getMembersOfType<TypeConstraintDecl>(declRef.As<ContainerDecl>())) + for (auto constraintDeclRef : getMembersOfType<TypeConstraintDecl>(declRef.as<ContainerDecl>())) { lookUpThroughConstraint( session, @@ -657,16 +657,16 @@ void lookUpMemberImpl( mask); } } - else if (auto aggTypeDeclRef = declRef.As<AggTypeDecl>()) + else if (auto aggTypeDeclRef = declRef.as<AggTypeDecl>()) { LookupRequest request; request.semantics = semantics; DoLocalLookupImpl(session, name, aggTypeDeclRef, request, ioResult, inBreadcrumbs); } - else if (auto genericTypeParamDeclRef = declRef.As<GenericTypeParamDecl>()) + else if (auto genericTypeParamDeclRef = declRef.as<GenericTypeParamDecl>()) { - auto genericDeclRef = genericTypeParamDeclRef.GetParent().As<GenericDecl>(); + auto genericDeclRef = genericTypeParamDeclRef.GetParent().as<GenericDecl>(); assert(genericDeclRef); for(auto constraintDeclRef : getMembersOfType<GenericTypeConstraintDecl>(genericDeclRef)) @@ -677,7 +677,7 @@ void lookUpMemberImpl( // generic parameter in question, and `Foo` is whatever we are // constraining it to. auto subType = GetSub(constraintDeclRef); - auto subDeclRefType = subType->As<DeclRefType>(); + auto subDeclRefType = as<DeclRefType>(subType); if(!subDeclRefType) continue; if(!subDeclRefType->declRef.Equals(genericTypeParamDeclRef)) diff --git a/source/slang/lower-to-ir.cpp b/source/slang/lower-to-ir.cpp index 3f77de446..f88a44e53 100644 --- a/source/slang/lower-to-ir.cpp +++ b/source/slang/lower-to-ir.cpp @@ -610,7 +610,7 @@ LoweredValInfo emitCallToDeclRef( auto builder = context->irBuilder; - if (auto subscriptDeclRef = funcDeclRef.As<SubscriptDecl>()) + if (auto subscriptDeclRef = funcDeclRef.as<SubscriptDecl>()) { // A reference to a subscript declaration is a special case, // because it is not possible to call a subscript directly; @@ -627,7 +627,7 @@ LoweredValInfo emitCallToDeclRef( // We want to track whether this subscript has any accessors other than // `get` (assuming that everything except `get` can be used for setting...). - if (auto foundGetterDeclRef = accessorDeclRef.As<GetterDecl>()) + if (auto foundGetterDeclRef = accessorDeclRef.as<GetterDecl>()) { // We found a getter. getterDeclRef = foundGetterDeclRef; @@ -731,7 +731,7 @@ LoweredValInfo emitCallToDeclRef( } // TODO: handle target intrinsic modifier too... - if( auto ctorDeclRef = funcDeclRef.As<ConstructorDecl>() ) + if( auto ctorDeclRef = funcDeclRef.as<ConstructorDecl>() ) { // HACK: we know all constructors are builtins for now, // so we need to emit them as a call to the corresponding @@ -900,7 +900,7 @@ top: auto base = materialize(context, boundMemberInfo->base); auto declRef = boundMemberInfo->declRef; - if( auto fieldDeclRef = declRef.As<VarDecl>() ) + if( auto fieldDeclRef = declRef.as<VarDecl>() ) { lowered = extractField(context, boundMemberInfo->type, base, fieldDeclRef); goto top; @@ -1152,13 +1152,13 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower // We will assume here that the super-type is an interface, and it // will be left to the front-end to ensure this property. // - auto supDeclRefType = val->sup->As<DeclRefType>(); + auto supDeclRefType = as<DeclRefType>(val->sup); if(!supDeclRefType) { SLANG_UNEXPECTED("super-type not a decl-ref type when generating tagged union witness table"); UNREACHABLE_RETURN(LoweredValInfo()); } - auto supInterfaceDeclRef = supDeclRefType->declRef.As<InterfaceDecl>(); + auto supInterfaceDeclRef = supDeclRefType->declRef.as<InterfaceDecl>(); if( !supInterfaceDeclRef ) { SLANG_UNEXPECTED("super-type not an interface type when generating tagged union witness table"); @@ -1197,7 +1197,7 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower - if(auto callableDeclRef = reqDeclRef.As<CallableDecl>()) + if(auto callableDeclRef = reqDeclRef.as<CallableDecl>()) { // We have something callable, so we need to synthesize // a function to satisfy it. @@ -1315,10 +1315,10 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower caseArgs.Add(caseThisArg); // The remaining arguments to the call will just be forwarded from - // the parameters of the wrapper functon. + // the parameters of the wrapper function. // // TODO: This would need to change if/when we started allowing `This` type - // or assocaited-type parameters to be used at call sites where a tagged + // or associated-type parameters to be used at call sites where a tagged // union is used. // for( auto param : irParams ) @@ -1624,39 +1624,39 @@ void addVarDecorations( auto builder = context->irBuilder; for(RefPtr<Modifier> mod : decl->modifiers) { - if(mod.As<HLSLNoInterpolationModifier>()) + if(as<HLSLNoInterpolationModifier>(mod)) { builder->addInterpolationModeDecoration(inst, IRInterpolationMode::NoInterpolation); } - else if(mod.As<HLSLNoPerspectiveModifier>()) + else if(as<HLSLNoPerspectiveModifier>(mod)) { builder->addInterpolationModeDecoration(inst, IRInterpolationMode::NoPerspective); } - else if(mod.As<HLSLLinearModifier>()) + else if(as<HLSLLinearModifier>(mod)) { builder->addInterpolationModeDecoration(inst, IRInterpolationMode::Linear); } - else if(mod.As<HLSLSampleModifier>()) + else if(as<HLSLSampleModifier>(mod)) { builder->addInterpolationModeDecoration(inst, IRInterpolationMode::Sample); } - else if(mod.As<HLSLCentroidModifier>()) + else if(as<HLSLCentroidModifier>(mod)) { builder->addInterpolationModeDecoration(inst, IRInterpolationMode::Centroid); } - else if(mod.As<VulkanRayPayloadAttribute>()) + else if(as<VulkanRayPayloadAttribute>(mod)) { builder->addSimpleDecoration<IRVulkanRayPayloadDecoration>(inst); } - else if(mod.As<VulkanCallablePayloadAttribute>()) + else if(as<VulkanCallablePayloadAttribute>(mod)) { builder->addSimpleDecoration<IRVulkanCallablePayloadDecoration>(inst); } - else if(mod.As<VulkanHitAttributesAttribute>()) + else if(as<VulkanHitAttributesAttribute>(mod)) { builder->addSimpleDecoration<IRVulkanHitAttributesDecoration>(inst); } - else if(mod.As<GloballyCoherentModifier>()) + else if(as<GloballyCoherentModifier>(mod)) { builder->addSimpleDecoration<IRGloballyCoherentDecoration>(inst); } @@ -1871,12 +1871,12 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo> auto loweredBase = lowerRValueExpr(context, expr->BaseExpression); auto declRef = expr->declRef; - if (auto fieldDeclRef = declRef.As<VarDecl>()) + if (auto fieldDeclRef = declRef.as<VarDecl>()) { // Okay, easy enough: we have a reference to a field of a struct type... return extractField(loweredType, loweredBase, fieldDeclRef); } - else if (auto callableDeclRef = declRef.As<CallableDecl>()) + else if (auto callableDeclRef = declRef.as<CallableDecl>()) { RefPtr<BoundMemberInfo> boundMemberInfo = new BoundMemberInfo(); boundMemberInfo->type = nullptr; @@ -1884,7 +1884,7 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo> boundMemberInfo->declRef = callableDeclRef; return LoweredValInfo::boundMember(boundMemberInfo); } - else if(auto constraintDeclRef = declRef.As<TypeConstraintDecl>()) + else if(auto constraintDeclRef = declRef.as<TypeConstraintDecl>()) { // The code is making use of a "witness" that a value of // some generic type conforms to an interface. @@ -1977,11 +1977,11 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo> LoweredValInfo getDefaultVal(Type* type) { auto irType = lowerType(context, type); - if (auto basicType = type->As<BasicExpressionType>()) + if (auto basicType = as<BasicExpressionType>(type)) { return getSimpleDefaultVal(irType); } - else if (auto vectorType = type->As<VectorExpressionType>()) + else if (auto vectorType = as<VectorExpressionType>(type)) { UInt elementCount = (UInt) GetIntVal(vectorType->elementCount); @@ -1995,7 +1995,7 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo> return LoweredValInfo::simple( getBuilder()->emitMakeVector(irType, args.Count(), args.Buffer())); } - else if (auto matrixType = type->As<MatrixExpressionType>()) + else if (auto matrixType = as<MatrixExpressionType>(type)) { UInt rowCount = (UInt) GetIntVal(matrixType->getRowCount()); @@ -2011,7 +2011,7 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo> return LoweredValInfo::simple( getBuilder()->emitMakeMatrix(irType, args.Count(), args.Buffer())); } - else if (auto arrayType = type->As<ArrayExpressionType>()) + else if (auto arrayType = as<ArrayExpressionType>(type)) { UInt elementCount = (UInt) GetIntVal(arrayType->ArrayLength); @@ -2026,10 +2026,10 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo> return LoweredValInfo::simple( getBuilder()->emitMakeArray(irType, args.Count(), args.Buffer())); } - else if (auto declRefType = type->As<DeclRefType>()) + else if (auto declRefType = as<DeclRefType>(type)) { DeclRef<Decl> declRef = declRefType->declRef; - if (auto aggTypeDeclRef = declRef.As<AggTypeDecl>()) + if (auto aggTypeDeclRef = declRef.as<AggTypeDecl>()) { List<IRInst*> args; for (auto ff : getMembersOfType<VarDecl>(aggTypeDeclRef)) @@ -2082,7 +2082,7 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo> // Now for each argument in the initializer list, // fill in the appropriate field of the result - if (auto arrayType = type->As<ArrayExpressionType>()) + if (auto arrayType = as<ArrayExpressionType>(type)) { UInt elementCount = (UInt) GetIntVal(arrayType->ArrayLength); @@ -2104,7 +2104,7 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo> return LoweredValInfo::simple( getBuilder()->emitMakeArray(irType, args.Count(), args.Buffer())); } - else if (auto vectorType = type->As<VectorExpressionType>()) + else if (auto vectorType = as<VectorExpressionType>(type)) { UInt elementCount = (UInt) GetIntVal(vectorType->elementCount); @@ -2126,7 +2126,7 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo> return LoweredValInfo::simple( getBuilder()->emitMakeVector(irType, args.Count(), args.Buffer())); } - else if (auto matrixType = type->As<MatrixExpressionType>()) + else if (auto matrixType = as<MatrixExpressionType>(type)) { UInt rowCount = (UInt) GetIntVal(matrixType->getRowCount()); @@ -2150,10 +2150,10 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo> return LoweredValInfo::simple( getBuilder()->emitMakeMatrix(irType, args.Count(), args.Buffer())); } - else if (auto declRefType = type->As<DeclRefType>()) + else if (auto declRefType = as<DeclRefType>(type)) { DeclRef<Decl> declRef = declRefType->declRef; - if (auto aggTypeDeclRef = declRef.As<AggTypeDecl>()) + if (auto aggTypeDeclRef = declRef.as<AggTypeDecl>()) { UInt argCounter = 0; for (auto ff : getMembersOfType<VarDecl>(aggTypeDeclRef)) @@ -2181,7 +2181,7 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo> } // If none of the above cases matched, then we had better - // have zero arguments in the initailizer list, in which + // have zero arguments in the initializer list, in which // case we are just looking for default initialization. // SLANG_UNEXPECTED("unhandled case for initializer list codegen"); @@ -2261,7 +2261,7 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo> // TODO: The approach we are taking here to default arguments // is simplistic, and has consequences for the front-end as - // well as binary serializatiojn of modules. + // well as binary serialization of modules. // // We could consider some more refined approaches where, e.g., // functions with default arguments generate multiple IR-level @@ -2364,7 +2364,7 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo> List<IRInst*>* ioArgs, List<OutArgumentFixup>* ioFixups) { - if (auto callableDeclRef = funcDeclRef.As<CallableDecl>()) + if (auto callableDeclRef = funcDeclRef.as<CallableDecl>()) { addDirectCallArgs(expr, callableDeclRef, ioArgs, ioFixups); } @@ -2412,7 +2412,7 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo> // First look to see if the expression references a // declaration at all. - auto declRefExpr = funcExpr.As<DeclRefExpr>(); + auto declRefExpr = as<DeclRefExpr>(funcExpr); if(!declRefExpr) return false; @@ -2430,24 +2430,24 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo> else { // The callee declaration isn't itself a callable (it must have - // a funciton type, though). + // a function type, though). return false; } // Now we can look at the specific kinds of declaration references, // and try to tease them apart. - if (auto memberFuncExpr = funcExpr.As<MemberExpr>()) + if (auto memberFuncExpr = as<MemberExpr>(funcExpr)) { outInfo->funcDeclRef = memberFuncExpr->declRef; outInfo->baseExpr = memberFuncExpr->BaseExpression; return true; } - else if (auto staticMemberFuncExpr = funcExpr.As<StaticMemberExpr>()) + else if (auto staticMemberFuncExpr = as<StaticMemberExpr>(funcExpr)) { outInfo->funcDeclRef = staticMemberFuncExpr->declRef; return true; } - else if (auto varExpr = funcExpr.As<VarExpr>()) + else if (auto varExpr = as<VarExpr>(funcExpr)) { outInfo->funcDeclRef = varExpr->declRef; return true; @@ -2484,7 +2484,7 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo> List<IRInst*> irArgs; // We will also collect "fixup" actions that need - // to be performed after teh call, in order to + // to be performed after the call, in order to // copy the final values for `out` parameters // back to their arguments. List<OutArgumentFixup> argFixups; @@ -2493,7 +2493,7 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo> ResolvedCallInfo resolvedInfo; if( tryResolveDeclRefForCall(funcExpr, &resolvedInfo) ) { - // In this case we know exaclty what declaration we + // In this case we know exactly what declaration we // are going to call, and so we can resolve things // appropriately. auto funcDeclRef = resolvedInfo.funcDeclRef; @@ -2525,9 +2525,9 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo> // TODO: In this case we should be emitting code for the callee as // an ordinary expression, then emitting the arguments according - // to the type information on the callee (e.g., which paameters + // to the type information on the callee (e.g., which parameters // are `out` or `inout`, and then finally emitting the `call` - // instruciton. + // instruction. // // We don't currently have the case of emitting arguments according // to function type info (instead of declaration info), and really @@ -2632,7 +2632,7 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo> // Because our representation of lowered "values" // can encompass l-values explicitly, we can // lower assignment easily. We just lower the left- - // and right-hand sides, and then peform an assignment + // and right-hand sides, and then perform an assignment // based on the resulting values. // auto leftVal = lowerLValueExpr(context, expr->left); @@ -2693,7 +2693,7 @@ struct LValueExprLoweringVisitor : ExprLoweringVisitorBase<LValueExprLoweringVis { auto baseSwizzleInfo = loweredBase.getSwizzledLValueInfo(); - // Our new swizzle witll use the same base expression (e.g., + // Our new swizzle will use the same base expression (e.g., // `foo[i]` in our example above), but will need to remap // the swizzle indices it uses. // @@ -2740,7 +2740,7 @@ struct LValueExprLoweringVisitor : ExprLoweringVisitorBase<LValueExprLoweringVis struct RValueExprLoweringVisitor : ExprLoweringVisitorBase<RValueExprLoweringVisitor> { // A swizzle in an r-value context can save time by just - // emitting the swizzle instuctions directly. + // emitting the swizzle instructions directly. LoweredValInfo visitSwizzleExpr(SwizzleExpr* expr) { auto irType = lowerType(context, expr->type); @@ -3713,7 +3713,7 @@ LoweredValInfo tryGetAddress( // we care about, and then write it back. auto declRef = boundMemberInfo->declRef; - if( auto fieldDeclRef = declRef.As<VarDecl>() ) + if( auto fieldDeclRef = declRef.as<VarDecl>() ) { auto baseVal = boundMemberInfo->base; auto basePtr = tryGetAddress(context, baseVal, TryGetAddressMode::Aggressive); @@ -3955,7 +3955,7 @@ top: // we care about, and then write it back. auto declRef = boundMemberInfo->declRef; - if( auto fieldDeclRef = declRef.As<VarDecl>() ) + if( auto fieldDeclRef = declRef.as<VarDecl>() ) { // materialize the base value and move it into // a mutable temporary if needed @@ -4071,13 +4071,13 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> // This might be a type constraint on an associated type, // in which case it should lower as the key for that // interface requirement. - if(auto assocTypeDecl = decl->ParentDecl->As<AssocTypeDecl>()) + if(auto assocTypeDecl = as<AssocTypeDecl>(decl->ParentDecl)) { // TODO: might need extra steps if we ever allow // generic associated types. - if(auto interfaceDecl = assocTypeDecl->ParentDecl->As<InterfaceDecl>()) + if(auto interfaceDecl = as<InterfaceDecl>(assocTypeDecl->ParentDecl)) { // Okay, this seems to be an interface rquirement, and // we should lower it as such. @@ -4085,7 +4085,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> } } - if(auto globalGenericParamDecl = decl->ParentDecl->As<GlobalGenericParamDecl>()) + if(auto globalGenericParamDecl = as<GlobalGenericParamDecl>(decl->ParentDecl)) { // This is a constraint on a global generic type parameters, // and so it should lower as a parameter of its own. @@ -4189,7 +4189,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> // interface requires, and not what it provides. // auto parentDecl = inheritanceDecl->ParentDecl; - if (auto parentInterfaceDecl = parentDecl->As<InterfaceDecl>()) + if (auto parentInterfaceDecl = as<InterfaceDecl>(parentDecl)) { return LoweredValInfo::simple(getInterfaceRequirementKey(inheritanceDecl)); } @@ -4198,12 +4198,12 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> // declaration is being used to add a conformance to // an existing `interface`: // - if(auto parentExtensionDecl = parentDecl->As<ExtensionDecl>()) + if(auto parentExtensionDecl = as<ExtensionDecl>(parentDecl)) { auto targetType = parentExtensionDecl->targetType; - if(auto targetDeclRefType = targetType->As<DeclRefType>()) + if(auto targetDeclRefType = as<DeclRefType>(targetType)) { - if(auto targetInterfaceDeclRef = targetDeclRefType->declRef.As<InterfaceDecl>()) + if(auto targetInterfaceDeclRef = targetDeclRefType->declRef.as<InterfaceDecl>()) { return LoweredValInfo::simple(getInterfaceRequirementKey(inheritanceDecl)); } @@ -4278,7 +4278,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> LoweredValInfo visitDeclGroup(DeclGroup* declGroup) { - // To lowere a group of declarations, we just + // To lower a group of declarations, we just // lower each one individually. // for (auto decl : declGroup->decls) @@ -4497,11 +4497,11 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> // in the order they were declared. for (auto member : genericDecl->Members) { - if (auto typeParamDecl = member.As<GenericTypeParamDecl>()) + if (auto typeParamDecl = as<GenericTypeParamDecl>(member)) { genericArgs.Add(getSimpleVal(context, ensureDecl(context, typeParamDecl))); } - else if (auto valDecl = member.As<GenericValueParamDecl>()) + else if (auto valDecl = as<GenericValueParamDecl>(member)) { genericArgs.Add(getSimpleVal(context, ensureDecl(context, valDecl))); } @@ -4510,7 +4510,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> // declaration order. for (auto member : genericDecl->Members) { - if (auto constraintDecl = member.As<GenericTypeConstraintDecl>()) + if (auto constraintDecl = as<GenericTypeConstraintDecl>(member)) { genericArgs.Add(getSimpleVal(context, ensureDecl(context, constraintDecl))); } @@ -4815,7 +4815,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> // As a special case, any type constraints placed // on an associated type will *also* need to be turned // into requirement keys for this interface. - if (auto associatedTypeDecl = requirementDecl.As<AssocTypeDecl>()) + if (auto associatedTypeDecl = as<AssocTypeDecl>(requirementDecl)) { for (auto constraintDecl : associatedTypeDecl->getMembersOfType<TypeConstraintDecl>()) { @@ -5018,7 +5018,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> DeclRef<D> createDefaultSpecializedDeclRef(D* decl) { DeclRef<Decl> declRef = createDefaultSpecializedDeclRefImpl(decl); - return declRef.As<D>(); + return declRef.as<D>(); } @@ -5318,7 +5318,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> // in the order they were declared. for (auto member : genericDecl->Members) { - if (auto typeParamDecl = member.As<GenericTypeParamDecl>()) + if (auto typeParamDecl = as<GenericTypeParamDecl>(member)) { // TODO: use a `TypeKind` to represent the // classifier of the parameter. @@ -5326,7 +5326,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> addNameHint(context, param, typeParamDecl); setValue(subContext, typeParamDecl, LoweredValInfo::simple(param)); } - else if (auto valDecl = member.As<GenericValueParamDecl>()) + else if (auto valDecl = as<GenericValueParamDecl>(member)) { auto paramType = lowerType(subContext, valDecl->getType()); auto param = subBuilder->emitParam(paramType); @@ -5338,7 +5338,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> // declaration order. for (auto member : genericDecl->Members) { - if (auto constraintDecl = member.As<GenericTypeConstraintDecl>()) + if (auto constraintDecl = as<GenericTypeConstraintDecl>(member)) { // TODO: use a `WitnessTableKind` to represent the // classifier of the parameter. @@ -5796,14 +5796,15 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> LoweredValInfo visitGenericDecl(GenericDecl * genDecl) { // TODO: Should this just always visit/lower the inner decl? - if (auto innerFuncDecl = genDecl->inner->As<FunctionDeclBase>()) + + if (auto innerFuncDecl = as<FunctionDeclBase>(genDecl->inner)) return ensureDecl(context, innerFuncDecl); - else if (auto innerStructDecl = genDecl->inner->As<StructDecl>()) + else if (auto innerStructDecl = as<StructDecl>(genDecl->inner)) { ensureDecl(context, innerStructDecl); return LoweredValInfo(); } - else if( auto extensionDecl = genDecl->inner->As<ExtensionDecl>() ) + else if( auto extensionDecl = as<ExtensionDecl>(genDecl->inner) ) { return ensureDecl(context, extensionDecl); } @@ -5816,7 +5817,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> // A function declaration may have multiple, target-specific // overloads, and we need to emit an IR version of each of these. - // The front end will form a linked list of declaratiosn with + // The front end will form a linked list of declarations with // the same signature, whenever there is any kind of redeclaration. // We will look to see if that linked list has been formed. auto primaryDecl = decl->primaryDecl; @@ -5940,17 +5941,17 @@ IRInst* lowerSubstitutionArg( bool canDeclLowerToAGeneric(RefPtr<Decl> decl) { // A callable decl lowers to an `IRFunc`, and can be generic - if(decl.As<CallableDecl>()) return true; + if(as<CallableDecl>(decl)) return true; // An aggregate type decl lowers to an `IRStruct`, and can be generic - if(decl.As<AggTypeDecl>()) return true; + if(as<AggTypeDecl>(decl)) return true; // An inheritance decl lowers to an `IRWitnessTable`, and can be generic - if(decl.As<InheritanceDecl>()) return true; + if(as<InheritanceDecl>(decl)) return true; // A `typedef` declaration nested under a generic will turn into // a generic that returns a type (a simple type-level function). - if(decl.As<TypeDefDecl>()) return true; + if(as<TypeDefDecl>(decl)) return true; return false; } @@ -5966,15 +5967,15 @@ LoweredValInfo emitDeclRef( // Ignore any global generic type substitutions during lowering. // Really, we don't even expect these to appear. - while(auto globalGenericSubst = subst.As<GlobalGenericParamSubstitution>()) + while(auto globalGenericSubst = as<GlobalGenericParamSubstitution>(subst)) subst = globalGenericSubst->outer; // If the declaration would not get wrapped in a `IRGeneric`, // even if it is nested inside of an AST `GenericDecl`, then - // we should also ignore any generic substiuttions. + // we should also ignore any generic substitutions. if(!canDeclLowerToAGeneric(decl)) { - while(auto genericSubst = subst.As<GenericSubstitution>()) + while(auto genericSubst = as<GenericSubstitution>(subst)) subst = genericSubst->outer; } @@ -5988,7 +5989,7 @@ LoweredValInfo emitDeclRef( } // Otherwise, we look at the kind of substitution, and let it guide us. - if(auto genericSubst = subst.As<GenericSubstitution>()) + if(auto genericSubst = subst.as<GenericSubstitution>()) { // A generic substitution means we will need to output // a `specialize` instruction to specialize the generic. @@ -6037,7 +6038,7 @@ LoweredValInfo emitDeclRef( return LoweredValInfo::simple(irSpecializedVal); } - else if(auto thisTypeSubst = subst.As<ThisTypeSubstitution>()) + else if(auto thisTypeSubst = subst.as<ThisTypeSubstitution>()) { if(decl.Ptr() == thisTypeSubst->interfaceDecl) { @@ -6057,7 +6058,7 @@ LoweredValInfo emitDeclRef( // Note: unlike the case for generics above, in the interface-lookup // case, we don't end up caring about any further outer substitutions. // That is because even if we are naming `ISomething<Foo>.doIt()`, - // a method insided a generic interface, we don't actually care + // a method inside a generic interface, we don't actually care // about the substitution of `Foo` for the parameter `T` of // `ISomething<T>`. That is because we really care about the // witness table for the concrete type that conforms to `ISomething<Foo>`. @@ -6112,7 +6113,7 @@ static void lowerEntryPointToIR( } auto loweredEntryPointFunc = ensureDecl(context, entryPointFuncDecl); - // Attach a marker decoraton so that we recognize + // Attach a marker decoration so that we recognize // this as an entry point. auto builder = context->irBuilder; builder->addEntryPointDecoration(getSimpleVal(context, loweredEntryPointFunc)); @@ -6123,7 +6124,7 @@ static void lowerEntryPointToIR( builder->setInsertInto(builder->getModule()->getModuleInst()); for (RefPtr<Substitutions> subst = entryPointRequest->globalGenericSubst; subst; subst = subst->outer) { - auto gSubst = subst.As<GlobalGenericParamSubstitution>(); + auto gSubst = subst.as<GlobalGenericParamSubstitution>(); if(!gSubst) continue; diff --git a/source/slang/mangle.cpp b/source/slang/mangle.cpp index 8ad0bc9f5..b153cb8dd 100644 --- a/source/slang/mangle.cpp +++ b/source/slang/mangle.cpp @@ -213,14 +213,14 @@ namespace Slang DeclRef<Decl> declRef) { auto parentDeclRef = declRef.GetParent(); - auto parentGenericDeclRef = parentDeclRef.As<GenericDecl>(); + auto parentGenericDeclRef = parentDeclRef.as<GenericDecl>(); if( parentDeclRef ) { // In certain cases we want to skip emitting the parent if(parentGenericDeclRef && (parentGenericDeclRef.getDecl()->inner.Ptr() != declRef.getDecl())) { } - else if(parentDeclRef.As<FunctionDeclBase>()) + else if(parentDeclRef.as<FunctionDeclBase>()) { } else @@ -232,7 +232,7 @@ namespace Slang // A generic declaration is kind of a pseudo-declaration // as far as the user is concerned; so we don't want // to emit its name. - if(auto genericDeclRef = declRef.As<GenericDecl>()) + if(auto genericDeclRef = declRef.as<GenericDecl>()) { return; } @@ -240,7 +240,7 @@ namespace Slang // Inheritance declarations don't have meaningful names, // and so we should emit them based on the type // that is doing the inheriting. - if(auto inheritanceDeclRef = declRef.As<InheritanceDecl>()) + if(auto inheritanceDeclRef = declRef.as<InheritanceDecl>()) { emit(context, "I"); emitType(context, GetSup(inheritanceDeclRef)); @@ -250,7 +250,7 @@ namespace Slang // Similarly, an extension doesn't have a name worth // emitting, and we should base things on its target // type instead. - if(auto extensionDeclRef = declRef.As<ExtensionDecl>()) + if(auto extensionDeclRef = declRef.as<ExtensionDecl>()) { // TODO: as a special case, an "unconditional" extension // that is in the same module as the type it extends should @@ -264,9 +264,9 @@ namespace Slang // Special case: accessors need some way to distinguish themselves // so that a getter/setter/ref-er don't all compile to the same name. - if(declRef.As<GetterDecl>()) emitRaw(context, "Ag"); - if(declRef.As<SetterDecl>()) emitRaw(context, "As"); - if(declRef.As<RefAccessorDecl>()) emitRaw(context, "Ar"); + if(declRef.as<GetterDecl>()) emitRaw(context, "Ag"); + if(declRef.as<SetterDecl>()) emitRaw(context, "As"); + if(declRef.as<RefAccessorDecl>()) emitRaw(context, "Ar"); // Are we the "inner" declaration beneath a generic decl? if(parentGenericDeclRef && (parentGenericDeclRef.getDecl()->inner.Ptr() == declRef.getDecl())) @@ -294,15 +294,15 @@ namespace Slang UInt genericParameterCount = 0; for( auto mm : getMembers(parentGenericDeclRef) ) { - if(mm.As<GenericTypeParamDecl>()) + if(mm.as<GenericTypeParamDecl>()) { genericParameterCount++; } - else if(mm.As<GenericValueParamDecl>()) + else if(mm.as<GenericValueParamDecl>()) { genericParameterCount++; } - else if(mm.As<GenericTypeConstraintDecl>()) + else if(mm.as<GenericTypeConstraintDecl>()) { genericParameterCount++; } @@ -314,16 +314,16 @@ namespace Slang emit(context, genericParameterCount); for( auto mm : getMembers(parentGenericDeclRef) ) { - if(auto genericTypeParamDecl = mm.As<GenericTypeParamDecl>()) + if(auto genericTypeParamDecl = mm.as<GenericTypeParamDecl>()) { emitRaw(context, "T"); } - else if(auto genericValueParamDecl = mm.As<GenericValueParamDecl>()) + else if(auto genericValueParamDecl = mm.as<GenericValueParamDecl>()) { emitRaw(context, "v"); emitType(context, GetType(genericValueParamDecl)); } - else if(mm.As<GenericTypeConstraintDecl>()) + else if(mm.as<GenericTypeConstraintDecl>()) { emitRaw(context, "C"); // TODO: actually emit info about the constraint @@ -342,7 +342,7 @@ namespace Slang // We'll also go ahead and emit the result type as well, // just for completeness. // - if( auto callableDeclRef = declRef.As<CallableDecl>()) + if( auto callableDeclRef = declRef.as<CallableDecl>()) { auto parameters = GetParameters(callableDeclRef); UInt parameterCount = parameters.Count(); @@ -358,7 +358,7 @@ namespace Slang // Don't print result type for an initializer/constructor, // since it is implicit in the qualified name. - if (!callableDeclRef.As<ConstructorDecl>()) + if (!callableDeclRef.as<ConstructorDecl>()) { emitType(context, GetResultType(callableDeclRef)); } diff --git a/source/slang/parameter-binding.cpp b/source/slang/parameter-binding.cpp index 6bb8749dd..904ec3129 100644 --- a/source/slang/parameter-binding.cpp +++ b/source/slang/parameter-binding.cpp @@ -15,7 +15,7 @@ struct ParameterInfo; struct UsedRange { // What parameter has claimed this range? - ParameterInfo* parameter = nullptr; + VarLayout* parameter; // Begin/end of the range (half-open interval) UInt begin; @@ -69,7 +69,7 @@ struct UsedRanges // then we return that parameter so that the // caller can issue an error. // - ParameterInfo* Add(UsedRange range) + VarLayout* Add(UsedRange range) { // The invariant on entry to this // function is that the `ranges` array @@ -86,8 +86,8 @@ struct UsedRanges // match the parameter on `range`, so that // the compiler can issue useful diagnostics. // - ParameterInfo* newParam = range.parameter; - ParameterInfo* existingParam = nullptr; + VarLayout* newParam = range.parameter; + VarLayout* existingParam = nullptr; // A clever algorithm might use a binary // search to identify the first entry in `ranges` @@ -210,7 +210,7 @@ struct UsedRanges return existingParam; } - ParameterInfo* Add(ParameterInfo* param, UInt begin, UInt end) + VarLayout* Add(VarLayout* param, UInt begin, UInt end) { UsedRange range; range.parameter = param; @@ -219,7 +219,7 @@ struct UsedRanges return Add(range); } - ParameterInfo* Add(ParameterInfo* param, UInt begin, LayoutSize end) + VarLayout* Add(VarLayout* param, UInt begin, LayoutSize end) { UsedRange range; range.parameter = param; @@ -246,7 +246,7 @@ struct UsedRanges // Try to find space for `count` entries - UInt Allocate(ParameterInfo* param, UInt count) + UInt Allocate(VarLayout* param, UInt count) { UInt begin = 0; @@ -279,11 +279,16 @@ struct UsedRanges struct ParameterBindingInfo { - size_t space; - size_t index; + size_t space = 0; + size_t index = 0; LayoutSize count; }; +struct ParameterBindingAndKindInfo : ParameterBindingInfo +{ + LayoutResourceKind kind = LayoutResourceKind::None; +}; + enum { kLayoutResourceKindCount = SLANG_PARAMETER_CATEGORY_COUNT, @@ -353,11 +358,6 @@ struct SharedParameterBindingContext // Dictionary<UInt, RefPtr<UsedRangeSet>> globalSpaceUsedRangeSets; - // What ranges of resource bindings are claimed for particular translation unit? - // This is only used for varying input/output. - // - Dictionary<TranslationUnitRequest*, RefPtr<UsedRangeSet>> translationUnitUsedRangeSets; - // Which register spaces have been claimed so far? UsedRanges usedSpaces; @@ -785,12 +785,12 @@ static bool validateSpecializationsMatch( for(;;) { // Skip any global generic substitutions. - if(auto leftGlobalGeneric = ll.As<GlobalGenericParamSubstitution>()) + if(auto leftGlobalGeneric = ll.as<GlobalGenericParamSubstitution>()) { ll = leftGlobalGeneric->outer; continue; } - if(auto rightGlobalGeneric = rr.As<GlobalGenericParamSubstitution>()) + if(auto rightGlobalGeneric = rr.as<GlobalGenericParamSubstitution>()) { rr = rightGlobalGeneric->outer; continue; @@ -806,9 +806,9 @@ static bool validateSpecializationsMatch( ll = ll->outer; rr = rr->outer; - if(auto leftGeneric = leftSubst.As<GenericSubstitution>()) + if(auto leftGeneric = leftSubst.as<GenericSubstitution>()) { - if(auto rightGeneric = rightSubst.As<GenericSubstitution>()) + if(auto rightGeneric = as<GenericSubstitution>(rightSubst)) { if(validateGenericSubstitutionsMatch(context, leftGeneric, rightGeneric, stack)) { @@ -816,9 +816,9 @@ static bool validateSpecializationsMatch( } } } - else if(auto leftThisType = leftSubst.As<ThisTypeSubstitution>()) + else if(auto leftThisType = leftSubst.as<ThisTypeSubstitution>()) { - if(auto rightThisType = rightSubst.As<ThisTypeSubstitution>()) + if(auto rightThisType = rightSubst.as<ThisTypeSubstitution>()) { if(validateThisTypeSubstitutionsMatch(context, leftThisType, rightThisType, stack)) { @@ -851,9 +851,9 @@ static bool validateTypesMatch( // are ever recursive types. We'd need a more refined system to // cache the matches we've already found. - if( auto leftDeclRefType = left->As<DeclRefType>() ) + if( auto leftDeclRefType = as<DeclRefType>(left) ) { - if( auto rightDeclRefType = right->As<DeclRefType>() ) + if( auto rightDeclRefType = as<DeclRefType>(right) ) { // Are they references to matching decl refs? auto leftDeclRef = leftDeclRefType->declRef; @@ -879,9 +879,9 @@ static bool validateTypesMatch( } // Check that any declared fields match too. - if( auto leftStructDeclRef = leftDeclRef.As<AggTypeDecl>() ) + if( auto leftStructDeclRef = leftDeclRef.as<AggTypeDecl>() ) { - if( auto rightStructDeclRef = rightDeclRef.As<AggTypeDecl>() ) + if( auto rightStructDeclRef = rightDeclRef.as<AggTypeDecl>() ) { List<DeclRef<VarDecl>> leftFields; List<DeclRef<VarDecl>> rightFields; @@ -931,9 +931,9 @@ static bool validateTypesMatch( // If we are looking at `T[N]` and `U[M]` we want to check that // `T` is structurally equivalent to `U` and `N` is the same as `M`. - else if( auto leftArrayType = left->As<ArrayExpressionType>() ) + else if( auto leftArrayType = as<ArrayExpressionType>(left) ) { - if( auto rightArrayType = right->As<ArrayExpressionType>() ) + if( auto rightArrayType = as<ArrayExpressionType>(right) ) { if(!validateTypesMatch(context, leftArrayType->baseType, rightArrayType->baseType, stack) ) return false; @@ -1029,7 +1029,7 @@ RefPtr<Type> tryGetEffectiveTypeForGLSLVaryingInput( return nullptr; auto type = varDecl->getType(); - if( varDecl->HasModifier<InModifier>() || type->As<GLSLInputParameterGroupType>()) + if( varDecl->HasModifier<InModifier>() || as<GLSLInputParameterGroupType>(type)) { // Special case to handle "arrayed" shader inputs, as used // for Geometry and Hull input @@ -1041,8 +1041,8 @@ RefPtr<Type> tryGetEffectiveTypeForGLSLVaryingInput( // Tessellation `patch` variables should stay as written if( !varDecl->HasModifier<GLSLPatchModifier>() ) { - // Unwrap array type, if prsent - if( auto arrayType = type->As<ArrayExpressionType>() ) + // Unwrap array type, if present + if( auto arrayType = as<ArrayExpressionType>(type) ) { type = arrayType->baseType.Ptr(); } @@ -1067,7 +1067,7 @@ RefPtr<Type> tryGetEffectiveTypeForGLSLVaryingOutput( return nullptr; auto type = varDecl->getType(); - if( varDecl->HasModifier<OutModifier>() || type->As<GLSLOutputParameterGroupType>()) + if( varDecl->HasModifier<OutModifier>() || as<GLSLOutputParameterGroupType>(type)) { // Special case to handle "arrayed" shader outputs, as used // for Hull Shader output @@ -1080,8 +1080,8 @@ RefPtr<Type> tryGetEffectiveTypeForGLSLVaryingOutput( // Tessellation `patch` variables should stay as written if( !varDecl->HasModifier<GLSLPatchModifier>() ) { - // Unwrap array type, if prsent - if( auto arrayType = type->As<ArrayExpressionType>() ) + // Unwrap array type, if present + if( auto arrayType = as<ArrayExpressionType>(type) ) { type = arrayType->baseType.Ptr(); } @@ -1098,93 +1098,18 @@ RefPtr<Type> tryGetEffectiveTypeForGLSLVaryingOutput( return nullptr; } -RefPtr<TypeLayout> -getTypeLayoutForGlobalShaderParameter_GLSL( + /// Determine how to lay out a global variable that might be a shader parameter. + /// + /// Returns `nullptr` if the declaration does not represent a shader parameter. +RefPtr<TypeLayout> getTypeLayoutForGlobalShaderParameter( ParameterBindingContext* context, - VarDeclBase* varDecl) -{ - auto layoutContext = context->layoutContext; - auto rules = layoutContext.getRulesFamily(); - auto type = varDecl->getType(); - - // A GLSL shader parameter will be marked with - // a qualifier to match the boundary it uses - // - // In the case of a parameter block, we will have - // consumed this qualifier as part of parsing, - // so that it won't be present on the declaration - // any more. As such we also inspect the type - // of the variable. - - // We want to check for a constant-buffer type with a `push_constant` layout - // qualifier before we move on to anything else. - if( varDecl->HasModifier<PushConstantAttribute>() && type->As<ConstantBufferType>() ) - { - return CreateTypeLayout( - layoutContext.with(rules->getPushConstantBufferRules()), - type); - } - - // TODO(tfoley): We have multiple variations of - // the `uniform` modifier right now, and that - // needs to get fixed... - if( varDecl->HasModifier<HLSLUniformModifier>() || type->As<ConstantBufferType>() ) - { - return CreateTypeLayout( - layoutContext.with(rules->getConstantBufferRules()), - type); - } - - if( varDecl->HasModifier<GLSLBufferModifier>() || type->As<GLSLShaderStorageBufferType>() ) - { - return CreateTypeLayout( - layoutContext.with(rules->getShaderStorageBufferRules()), - type); - } - - if (auto effectiveVaryingInputType = tryGetEffectiveTypeForGLSLVaryingInput(context, varDecl)) - { - // We expect to handle these elsewhere - SLANG_DIAGNOSE_UNEXPECTED(getSink(context), varDecl, "GLSL varying input"); - return CreateTypeLayout( - layoutContext.with(rules->getVaryingInputRules()), - effectiveVaryingInputType); - } - - if (auto effectiveVaryingOutputType = tryGetEffectiveTypeForGLSLVaryingOutput(context, varDecl)) - { - // We expect to handle these elsewhere - SLANG_DIAGNOSE_UNEXPECTED(getSink(context), varDecl, "GLSL varying output"); - return CreateTypeLayout( - layoutContext.with(rules->getVaryingOutputRules()), - effectiveVaryingOutputType); - } - - // A `const` global with a `layout(constant_id = ...)` modifier - // is a declaration of a specialization constant. - if( varDecl->HasModifier<GLSLConstantIDLayoutModifier>() ) - { - return CreateTypeLayout( - layoutContext.with(rules->getSpecializationConstantRules()), - type); - } - - // GLSL says that an "ordinary" global variable - // is just a (thread local) global and not a - // parameter - return nullptr; -} - -RefPtr<TypeLayout> -getTypeLayoutForGlobalShaderParameter_HLSL( - ParameterBindingContext* context, - VarDeclBase* varDecl) + VarDeclBase* varDecl, + Type* type) { auto layoutContext = context->layoutContext; auto rules = layoutContext.getRulesFamily(); - auto type = varDecl->getType(); - if( varDecl->HasModifier<ShaderRecordNVLayoutModifier>() && type->As<ConstantBufferType>() ) + if( varDecl->HasModifier<ShaderRecordNVLayoutModifier>() && as<ConstantBufferType>(type) ) { return CreateTypeLayout( layoutContext.with(rules->getShaderRecordConstantBufferRules()), @@ -1193,7 +1118,7 @@ getTypeLayoutForGlobalShaderParameter_HLSL( // We want to check for a constant-buffer type with a `push_constant` layout // qualifier before we move on to anything else. - if (varDecl->HasModifier<PushConstantAttribute>() && type->As<ConstantBufferType>()) + if (varDecl->HasModifier<PushConstantAttribute>() && as<ConstantBufferType>(type)) { return CreateTypeLayout( layoutContext.with(rules->getPushConstantBufferRules()), @@ -1217,32 +1142,13 @@ getTypeLayoutForGlobalShaderParameter_HLSL( type); } -// Determine how to lay out a global variable that might be -// a shader parameter. -// Returns `nullptr` if the declaration does not represent -// a shader parameter. - -RefPtr<TypeLayout> -getTypeLayoutForGlobalShaderParameter( +RefPtr<TypeLayout> getTypeLayoutForGlobalShaderParameter( ParameterBindingContext* context, VarDeclBase* varDecl) { - switch( context->sourceLanguage ) - { - case SourceLanguage::Slang: - case SourceLanguage::HLSL: - return getTypeLayoutForGlobalShaderParameter_HLSL(context, varDecl); - - case SourceLanguage::GLSL: - return getTypeLayoutForGlobalShaderParameter_GLSL(context, varDecl); - - default: - SLANG_UNEXPECTED("unhandled source language"); - UNREACHABLE_RETURN(nullptr); - } + return getTypeLayoutForGlobalShaderParameter(context, varDecl, varDecl->getType()); } - // enum EntryPointParameterDirection @@ -1264,44 +1170,12 @@ struct EntryPointParameterState }; -static RefPtr<TypeLayout> processEntryPointParameter( +static RefPtr<TypeLayout> processEntryPointVaryingParameter( ParameterBindingContext* context, RefPtr<Type> type, EntryPointParameterState const& state, RefPtr<VarLayout> varLayout); -static void collectGlobalScopeGLSLVaryingParameter( - ParameterBindingContext* context, - RefPtr<VarDeclBase> varDecl, - RefPtr<Type> effectiveType, - EntryPointParameterDirection direction) -{ - int defaultSemanticIndex = 0; - - EntryPointParameterState state; - state.directionMask = direction; - state.ioSemanticIndex = &defaultSemanticIndex; - state.stage = context->stage; - state.loc = varDecl->loc; - - RefPtr<VarLayout> varLayout = new VarLayout(); - varLayout->varDecl = makeDeclRef(varDecl.Ptr()); - - varLayout->typeLayout = processEntryPointParameter( - context, - effectiveType, - state, - varLayout); - - // Now add it to our list of reflection parameters, so - // that it can get a location assigned later... - - ParameterInfo* parameterInfo = new ParameterInfo(); - parameterInfo->translationUnit = context->translationUnit; - context->shared->parameters.Add(parameterInfo); - parameterInfo->varLayouts.Add(varLayout); -} - // Collect a single declaration into our set of parameters static void collectGlobalGenericParameter( ParameterBindingContext* context, @@ -1319,25 +1193,6 @@ static void collectGlobalScopeParameter( ParameterBindingContext* context, RefPtr<VarDeclBase> varDecl) { - // HACK: We need to intercept GLSL varying `in` and `out` here, way earlier - // in the process, so that we can avoid all kinds of nastiness that would - // otherwise be applied to them. - if (context->sourceLanguage == SourceLanguage::GLSL) - { - if (auto effectiveVaryingInputType = tryGetEffectiveTypeForGLSLVaryingInput(context, varDecl)) - { - collectGlobalScopeGLSLVaryingParameter(context, varDecl, effectiveVaryingInputType, kEntryPointParameterDirection_Input); - return; - } - - if (auto effectiveVaryingOutputType = tryGetEffectiveTypeForGLSLVaryingOutput(context, varDecl)) - { - collectGlobalScopeGLSLVaryingParameter(context, varDecl, effectiveVaryingOutputType, kEntryPointParameterDirection_Output); - return; - } - } - - // We use a single operation to both check whether the // variable represents a shader parameter, and to compute // the layout for that parameter's type. @@ -1354,7 +1209,7 @@ static void collectGlobalScopeParameter( // Now create a variable layout that we can use RefPtr<VarLayout> varLayout = new VarLayout(); varLayout->typeLayout = typeLayout; - varLayout->varDecl = DeclRef<Decl>(varDecl.Ptr(), nullptr).As<VarDeclBase>(); + varLayout->varDecl = DeclRef<Decl>(varDecl.Ptr(), nullptr).as<VarDeclBase>(); // This declaration may represent the same logical parameter // as a declaration that came from a different translation unit. @@ -1422,22 +1277,6 @@ static UInt allocateUnusedSpaces( return context->shared->usedSpaces.Allocate(nullptr, count); } -static RefPtr<UsedRangeSet> findUsedRangeSetForTranslationUnit( - ParameterBindingContext* context, - TranslationUnitRequest* translationUnit) -{ - if (!translationUnit) - return findUsedRangeSetForSpace(context, 0); - - RefPtr<UsedRangeSet> usedRangeSet; - if (context->shared->translationUnitUsedRangeSets.TryGetValue(translationUnit, usedRangeSet)) - return usedRangeSet; - - usedRangeSet = new UsedRangeSet(); - context->shared->translationUnitUsedRangeSets.Add(translationUnit, usedRangeSet); - return usedRangeSet; -} - static void addExplicitParameterBinding( ParameterBindingContext* context, RefPtr<ParameterInfo> parameterInfo, @@ -1486,15 +1325,15 @@ static void addExplicitParameterBinding( // need to grab a full space markSpaceUsed(context, semanticInfo.space); } - auto overlappedParameterInfo = usedRangeSet->usedResourceRanges[(int)semanticInfo.kind].Add( - parameterInfo, + auto overlappedVarLayout = usedRangeSet->usedResourceRanges[(int)semanticInfo.kind].Add( + parameterInfo->varLayouts[0], semanticInfo.index, semanticInfo.index + count); - if (overlappedParameterInfo) + if (overlappedVarLayout) { auto paramA = parameterInfo->varLayouts[0]->varDecl.getDecl(); - auto paramB = overlappedParameterInfo->varLayouts[0]->varDecl.getDecl(); + auto paramB = overlappedVarLayout->varDecl.getDecl(); getSink(context)->diagnose(paramA, Diagnostics::parameterBindingsOverlap, getReflectionName(paramA), @@ -1641,22 +1480,6 @@ static void addExplicitParameterBindings_GLSL( semanticInfo.index = attr->set; semanticInfo.space = 0; } - else if( (resInfo = typeLayout->FindResourceInfo(LayoutResourceKind::VertexInput)) != nullptr ) - { - // Try to find `location` binding - if(!findLayoutArg<GLSLLocationLayoutModifier>(varDecl, &semanticInfo.index)) - return; - - usedRangeSet = findUsedRangeSetForTranslationUnit(context, parameterInfo->translationUnit); - } - else if( (resInfo = typeLayout->FindResourceInfo(LayoutResourceKind::FragmentOutput)) != nullptr ) - { - // Try to find `location` binding - if(!findLayoutArg<GLSLLocationLayoutModifier>(varDecl, &semanticInfo.index)) - return; - - usedRangeSet = findUsedRangeSetForTranslationUnit(context, parameterInfo->translationUnit); - } else if( (resInfo = typeLayout->FindResourceInfo(LayoutResourceKind::SpecializationConstant)) != nullptr ) { // Try to find `constant_id` binding @@ -1697,21 +1520,16 @@ void generateParameterBindings( } // Generate the binding information for a shader parameter. -static void completeBindingsForParameter( +static void completeBindingsForParameterImpl( ParameterBindingContext* context, + RefPtr<VarLayout> firstVarLayout, + ParameterBindingInfo bindingInfos[kLayoutResourceKindCount], RefPtr<ParameterInfo> parameterInfo) { // For any resource kind used by the parameter // we need to update its layout information // to include a binding for that resource kind. // - // We will use the first declaration of the parameter as - // a stand-in for all the declarations, so it is important - // that earlier code has validated that the declarations - // "match". - - SLANG_RELEASE_ASSERT(parameterInfo->varLayouts.Count() != 0); - auto firstVarLayout = parameterInfo->varLayouts.First(); auto firstTypeLayout = firstVarLayout->typeLayout; // We need to deal with allocation of full register spaces first, @@ -1731,7 +1549,7 @@ static void completeBindingsForParameter( // has specified an explicit binding, since those won't // go into our contiguously allocated range. // - auto& bindingInfo = parameterInfo->bindingInfo[(int)kind]; + auto& bindingInfo = bindingInfos[(int)kind]; if( bindingInfo.count != 0 ) { continue; @@ -1798,7 +1616,7 @@ static void completeBindingsForParameter( // Did we already apply some explicit binding information // for this resource kind? auto kind = typeRes.kind; - auto& bindingInfo = parameterInfo->bindingInfo[(int)kind]; + auto& bindingInfo = bindingInfos[(int)kind]; if( bindingInfo.count != 0 ) { // If things have already been bound, our work is done. @@ -1892,26 +1710,53 @@ static void completeBindingsForParameter( // space. UInt space = context->shared->defaultSpace; - - RefPtr<UsedRangeSet> usedRangeSet; - switch (kind) - { - default: - usedRangeSet = findUsedRangeSetForSpace(context, space); - break; - - case LayoutResourceKind::VertexInput: - case LayoutResourceKind::FragmentOutput: - usedRangeSet = findUsedRangeSetForTranslationUnit(context, parameterInfo->translationUnit); - break; - } + RefPtr<UsedRangeSet> usedRangeSet = findUsedRangeSetForSpace(context, space); bindingInfo.count = count; - bindingInfo.index = usedRangeSet->usedResourceRanges[(int)kind].Allocate(parameterInfo, count.getFiniteValue()); - + bindingInfo.index = usedRangeSet->usedResourceRanges[(int)kind].Allocate(firstVarLayout, count.getFiniteValue()); bindingInfo.space = space; } } +} + +static void applyBindingInfoToParameter( + RefPtr<VarLayout> varLayout, + ParameterBindingInfo bindingInfos[kLayoutResourceKindCount]) +{ + for(auto k = 0; k < kLayoutResourceKindCount; ++k) + { + auto kind = LayoutResourceKind(k); + auto& bindingInfo = bindingInfos[k]; + + // skip resources we aren't consuming + if(bindingInfo.count == 0) + continue; + + // Add a record to the variable layout + auto varRes = varLayout->AddResourceInfo(kind); + varRes->space = (int) bindingInfo.space; + varRes->index = (int) bindingInfo.index; + } +} + +// Generate the binding information for a shader parameter. +static void completeBindingsForParameter( + ParameterBindingContext* context, + RefPtr<ParameterInfo> parameterInfo) +{ + // We will use the first declaration of the parameter as + // a stand-in for all the declarations, so it is important + // that earlier code has validated that the declarations + // "match". + + SLANG_RELEASE_ASSERT(parameterInfo->varLayouts.Count() != 0); + auto firstVarLayout = parameterInfo->varLayouts.First(); + + completeBindingsForParameterImpl( + context, + firstVarLayout, + parameterInfo->bindingInfo, + parameterInfo); // At this point we should have explicit binding locations chosen for // all the relevant resource kinds, so we can apply these to the @@ -1919,23 +1764,25 @@ static void completeBindingsForParameter( for(auto& varLayout : parameterInfo->varLayouts) { - for(auto k = 0; k < kLayoutResourceKindCount; ++k) - { - auto kind = LayoutResourceKind(k); - auto& bindingInfo = parameterInfo->bindingInfo[k]; - - // skip resources we aren't consuming - if(bindingInfo.count == 0) - continue; - - // Add a record to the variable layout - auto varRes = varLayout->AddResourceInfo(kind); - varRes->space = (int) bindingInfo.space; - varRes->index = (int) bindingInfo.index; - } + applyBindingInfoToParameter(varLayout, parameterInfo->bindingInfo); } } +static void completeBindingsForParameter( + ParameterBindingContext* context, + RefPtr<VarLayout> varLayout) +{ + ParameterBindingInfo bindingInfos[kLayoutResourceKindCount]; + completeBindingsForParameterImpl( + context, + varLayout, + bindingInfos, + nullptr); + applyBindingInfoToParameter(varLayout, bindingInfos); +} + + + static void collectGlobalScopeParameters( ParameterBindingContext* context, ModuleDecl* program) @@ -1950,12 +1797,12 @@ static void collectGlobalScopeParameters( // for generic types in the second pass. for (auto decl : program->Members) { - if (auto genParamDecl = decl.As<GlobalGenericParamDecl>()) + if (auto genParamDecl = as<GlobalGenericParamDecl>(decl)) collectGlobalGenericParameter(context, genParamDecl); } for (auto decl : program->Members) { - if (auto varDecl = decl.As<VarDeclBase>()) + if (auto varDecl = as<VarDeclBase>(decl)) collectGlobalScopeParameter(context, varDecl); } @@ -2121,7 +1968,7 @@ static RefPtr<TypeLayout> processSimpleEntryPointParameter( return typeLayout; } -static RefPtr<TypeLayout> processEntryPointParameterDecl( +static RefPtr<TypeLayout> processEntryPointVaryingParameterDecl( ParameterBindingContext* context, Decl* decl, RefPtr<Type> type, @@ -2159,23 +2006,18 @@ static RefPtr<TypeLayout> processEntryPointParameterDecl( // *or* we couldn't find an explicit semantic to apply on the given // declaration, so we will just recursive with whatever we have at // the moment. - return processEntryPointParameter(context, type, state, varLayout); + return processEntryPointVaryingParameter(context, type, state, varLayout); } -static RefPtr<TypeLayout> processEntryPointParameter( +static RefPtr<TypeLayout> processEntryPointVaryingParameter( ParameterBindingContext* context, RefPtr<Type> type, EntryPointParameterState const& state, RefPtr<VarLayout> varLayout) { - if (varLayout) - { - varLayout->stage = state.stage; - } - // The default handling of varying parameters should not apply // to geometry shader output streams; they have their own special rules. - if( auto gsStreamType = type->As<HLSLStreamOutputType>() ) + if( auto gsStreamType = as<HLSLStreamOutputType>(type) ) { // @@ -2192,7 +2034,7 @@ static RefPtr<TypeLayout> processEntryPointParameter( elementState.stage = state.stage; elementState.loc = state.loc; - auto elementTypeLayout = processEntryPointParameter(context, elementType, elementState, nullptr); + auto elementTypeLayout = processEntryPointVaryingParameter(context, elementType, elementState, nullptr); RefPtr<StreamOutputTypeLayout> typeLayout = new StreamOutputTypeLayout(); typeLayout->type = type; @@ -2294,21 +2136,21 @@ static RefPtr<TypeLayout> processEntryPointParameter( } // Scalar and vector types are treated as outputs directly - if(auto basicType = type->As<BasicExpressionType>()) + if(auto basicType = as<BasicExpressionType>(type)) { return processSimpleEntryPointParameter(context, basicType, state, varLayout); } - else if(auto vectorType = type->As<VectorExpressionType>()) + else if(auto vectorType = as<VectorExpressionType>(type)) { return processSimpleEntryPointParameter(context, vectorType, state, varLayout); } // A matrix is processed as if it was an array of rows - else if( auto matrixType = type->As<MatrixExpressionType>() ) + else if( auto matrixType = as<MatrixExpressionType>(type) ) { auto rowCount = GetIntVal(matrixType->getRowCount()); return processSimpleEntryPointParameter(context, matrixType, state, varLayout, (int) rowCount); } - else if( auto arrayType = type->As<ArrayExpressionType>() ) + else if( auto arrayType = as<ArrayExpressionType>(type) ) { // Note: Bad Things will happen if we have an array input // without a semantic already being enforced. @@ -2316,13 +2158,13 @@ static RefPtr<TypeLayout> processEntryPointParameter( auto elementCount = (UInt) GetIntVal(arrayType->ArrayLength); // We use the first element to derive the layout for the element type - auto elementTypeLayout = processEntryPointParameter(context, arrayType->baseType, state, varLayout); + auto elementTypeLayout = processEntryPointVaryingParameter(context, arrayType->baseType, state, varLayout); // We still walk over subsequent elements to make sure they consume resources // as needed for( UInt ii = 1; ii < elementCount; ++ii ) { - processEntryPointParameter(context, arrayType->baseType, state, nullptr); + processEntryPointVaryingParameter(context, arrayType->baseType, state, nullptr); } RefPtr<ArrayTypeLayout> arrayTypeLayout = new ArrayTypeLayout(); @@ -2337,16 +2179,16 @@ static RefPtr<TypeLayout> processEntryPointParameter( return arrayTypeLayout; } // Ignore a bunch of types that don't make sense here... - else if (auto textureType = type->As<TextureType>()) { return nullptr; } - else if(auto samplerStateType = type->As<SamplerStateType>()) { return nullptr; } - else if(auto constantBufferType = type->As<ConstantBufferType>()) { return nullptr; } + else if (auto textureType = as<TextureType>(type)) { return nullptr; } + else if(auto samplerStateType = as<SamplerStateType>(type)) { return nullptr; } + else if(auto constantBufferType = as<ConstantBufferType>(type)) { return nullptr; } // Catch declaration-reference types late in the sequence, since // otherwise they will include all of the above cases... - else if( auto declRefType = type->As<DeclRefType>() ) + else if( auto declRefType = as<DeclRefType>(type) ) { auto declRef = declRefType->declRef; - if (auto structDeclRef = declRef.As<StructDecl>()) + if (auto structDeclRef = declRef.as<StructDecl>()) { RefPtr<StructTypeLayout> structLayout = new StructTypeLayout(); structLayout->type = type; @@ -2357,7 +2199,7 @@ static RefPtr<TypeLayout> processEntryPointParameter( RefPtr<VarLayout> fieldVarLayout = new VarLayout(); fieldVarLayout->varDecl = field; - auto fieldTypeLayout = processEntryPointParameterDecl( + auto fieldTypeLayout = processEntryPointVaryingParameterDecl( context, field.getDecl(), GetType(field), @@ -2384,7 +2226,7 @@ static RefPtr<TypeLayout> processEntryPointParameter( return structLayout; } - else if (auto globalGenericParam = declRef.As<GlobalGenericParamDecl>()) + else if (auto globalGenericParam = declRef.as<GlobalGenericParamDecl>()) { auto genParamTypeLayout = new GenericParamTypeLayout(); // we should have already populated ProgramLayout::genericEntryPointParams list at this point, @@ -2400,7 +2242,7 @@ static RefPtr<TypeLayout> processEntryPointParameter( } } // If we ran into an error in checking the user's code, then skip this parameter - else if( auto errorType = type->As<ErrorType>() ) + else if( auto errorType = as<ErrorType>(type) ) { return nullptr; } @@ -2409,6 +2251,210 @@ static RefPtr<TypeLayout> processEntryPointParameter( UNREACHABLE_RETURN(nullptr); } + /// Compute the type layout for a parameter declared directly on an entry point. +static RefPtr<TypeLayout> computeEntryPointParameterTypeLayout( + ParameterBindingContext* context, + SubstitutionSet typeSubst, + RefPtr<ParamDecl> paramDecl, + RefPtr<VarLayout> paramVarLayout, + EntryPointParameterState& state) +{ + auto paramType = paramDecl->type.type->Substitute(typeSubst).as<Type>(); + + if( paramDecl->HasModifier<HLSLUniformModifier>() ) + { + // An entry-point parameter that is explicitly marked `uniform` represents + // a uniform shader parameter passed via the implicitly-defined + // constant buffer (e.g., the `$Params` constant buffer seen in fxc/dxc output). + // + return CreateTypeLayout( + context->layoutContext.with(context->getRulesFamily()->getConstantBufferRules()), + paramType); + } + else + { + // The default case is a varying shader parameter, which could be used for + // input, output, or both. + // + // The varying case needs to not only compute a layout, but also assocaite + // "semantic" strings/indices with the varying parameters by recursively + // walking their structure. + + state.directionMask = 0; + + // If it appears to be an input, process it as such. + if( paramDecl->HasModifier<InModifier>() || paramDecl->HasModifier<InOutModifier>() || !paramDecl->HasModifier<OutModifier>() ) + { + state.directionMask |= kEntryPointParameterDirection_Input; + } + + // If it appears to be an output, process it as such. + if(paramDecl->HasModifier<OutModifier>() || paramDecl->HasModifier<InOutModifier>()) + { + state.directionMask |= kEntryPointParameterDirection_Output; + } + + return processEntryPointVaryingParameterDecl( + context, + paramDecl.Ptr(), + paramDecl->type.type->Substitute(typeSubst).as<Type>(), + state, + paramVarLayout); + } +} + +// There are multiple places where we need to compute the layout +// for a "scope" such as the global scope or an entry point. +// The `ScopeLayoutBuilder` encapsulates the logic around: +// +// * Doing layout for the ordinary/uniform fields, which involves +// using the `struct` layout rules for constant buffers on +// the target. +// +// * Creating a final type/var layout that reflects whether the +// scope needs a constant buffer to be allocated to it. +// +struct ScopeLayoutBuilder +{ + ParameterBindingContext* m_context = nullptr; + LayoutRulesImpl* m_rules = nullptr; + RefPtr<StructTypeLayout> m_structLayout; + UniformLayoutInfo m_structLayoutInfo; + bool m_needConstantBuffer = false; + + void beginLayout( + ParameterBindingContext* context) + { + m_context = context; + m_rules = context->getRulesFamily()->getConstantBufferRules(); + m_structLayout = new StructTypeLayout(); + m_structLayout->rules = m_rules; + + m_structLayoutInfo = m_rules->BeginStructLayout(); + } + + void _addParameter( + RefPtr<VarLayout> firstVarLayout, + ParameterInfo* parameterInfo) + { + // Does the parameter have any uniform data? + auto layoutInfo = firstVarLayout->typeLayout->FindResourceInfo(LayoutResourceKind::Uniform); + LayoutSize uniformSize = layoutInfo ? layoutInfo->count : 0; + if( uniformSize != 0 ) + { + m_needConstantBuffer = true; + + // Make sure uniform fields get laid out properly... + + UniformLayoutInfo fieldInfo( + uniformSize, + firstVarLayout->typeLayout->uniformAlignment); + + LayoutSize uniformOffset = m_rules->AddStructField( + &m_structLayoutInfo, + fieldInfo); + + if( parameterInfo ) + { + for( auto& varLayout : parameterInfo->varLayouts ) + { + varLayout->findOrAddResourceInfo(LayoutResourceKind::Uniform)->index = uniformOffset.getFiniteValue(); + } + } + else + { + firstVarLayout->findOrAddResourceInfo(LayoutResourceKind::Uniform)->index = uniformOffset.getFiniteValue(); + } + } + + m_structLayout->fields.Add(firstVarLayout); + + if( parameterInfo ) + { + for( auto& varLayout : parameterInfo->varLayouts ) + { + m_structLayout->mapVarToLayout.Add(varLayout->varDecl.getDecl(), varLayout); + } + } + else + { + m_structLayout->mapVarToLayout.Add(firstVarLayout->varDecl.getDecl(), firstVarLayout); + } + } + + void addParameter( + RefPtr<VarLayout> varLayout) + { + _addParameter(varLayout, nullptr); + } + + void addParameter( + ParameterInfo* parameterInfo) + { + SLANG_RELEASE_ASSERT(parameterInfo->varLayouts.Count() != 0); + auto firstVarLayout = parameterInfo->varLayouts.First(); + + _addParameter(firstVarLayout, parameterInfo); + } + + RefPtr<VarLayout> endLayout() + { + m_rules->EndStructLayout(&m_structLayoutInfo); + + RefPtr<TypeLayout> scopeTypeLayout = m_structLayout; + + // If the caller decided to allocate a constant buffer for + // the ordinary data, then we need to wrap up the structure + // type (layout) in a constant buffer type (layout). + // + if( m_needConstantBuffer ) + { + auto constantBufferLayout = createParameterGroupTypeLayout( + m_context->layoutContext, + nullptr, + m_rules, + m_rules->GetObjectLayout(ShaderParameterKind::ConstantBuffer), + m_structLayout); + + scopeTypeLayout = constantBufferLayout; + } + + // We now have a bunch of layout information, which we should + // record into a suitable object that represents the scope + RefPtr<VarLayout> scopeVarLayout = new VarLayout(); + scopeVarLayout->typeLayout = scopeTypeLayout; + return scopeVarLayout; + } +}; + + /// Helper routine to allocate a constant buffer binding if one is needed. + /// + /// This function primarily exists to encapsulate the logic for allocating + /// the resources required for a constant buffer in the appropriate + /// target-specific fashion. + /// +static ParameterBindingAndKindInfo maybeAllocateConstantBufferBinding( + ParameterBindingContext* context, + bool needConstantBuffer) +{ + if( !needConstantBuffer ) return ParameterBindingAndKindInfo(); + + UInt space = context->shared->defaultSpace; + auto usedRangeSet = findUsedRangeSetForSpace(context, space); + + auto layoutInfo = context->getRulesFamily()->getConstantBufferRules()->GetObjectLayout( + ShaderParameterKind::ConstantBuffer); + + ParameterBindingAndKindInfo info; + info.kind = layoutInfo.kind; + info.count = layoutInfo.size; + info.index = usedRangeSet->usedResourceRanges[(int)layoutInfo.kind].Allocate(nullptr, layoutInfo.size.getFiniteValue()); + info.space = space; + return info; +} + + /// Iterate over the parameters of an entry point to compute its requirements. + /// static void collectEntryPointParameters( ParameterBindingContext* context, EntryPointRequest* entryPoint, @@ -2420,100 +2466,138 @@ static void collectEntryPointParameters( // Something must have failed earlier, so that // we didn't find a declaration to match this // entry point request. + // return; } - // Create the layout object here - auto entryPointLayout = new EntryPointLayout(); + // We will take responsibility for creating and filling in + // the `EntryPointLayout` object here. + // + RefPtr<EntryPointLayout> entryPointLayout = new EntryPointLayout(); entryPointLayout->profile = entryPoint->profile; entryPointLayout->entryPoint = entryPointFuncDecl; - context->entryPointLayout = entryPointLayout; + // The entry point layout must be added to the output + // program layout so that it can be accessed by reflection. + // context->shared->programLayout->entryPoints.Add(entryPointLayout); + // For the duration of our parameter collection work we will + // establish this entry point as the current one in the context. + // + context->entryPointLayout = entryPointLayout; + // Note: this isn't really the best place for this logic to sit, - // but it is the simplest place where we have a direct correspondance + // but it is the simplest place where we have a direct correspondence // between a single `EntryPointRequest` and its matching `EntryPointLayout`, // so we'll use it. // for( auto taggedUnionType : entryPoint->taggedUnionTypes ) { - auto substType = taggedUnionType->Substitute(typeSubst).As<Type>(); + auto substType = taggedUnionType->Substitute(typeSubst).dynamicCast<Type>(); auto typeLayout = CreateTypeLayout(context->layoutContext, substType); entryPointLayout->taggedUnionTypeLayouts.Add(typeLayout); } - // Okay, we seemingly have an entry-point function, and now we need to collect info on its parameters too + // We are going to iterate over the entry-point parameters, + // and while we do so we will go ahead and perform layout/binding + // assignment for two cases: // - // TODO: Long-term we probably want complete information on all inputs/outputs of an entry point, - // but for now we are really just trying to scrape information on fragment outputs, so lets do that: + // First, the varying parameters of the entry point will have + // their semantics and locations assigned, so we set up state + // for tracking that layout. // - // TODO: check whether we should enumerate the parameters before the return type, or vice versa - int defaultSemanticIndex = 0; - EntryPointParameterState state; state.ioSemanticIndex = &defaultSemanticIndex; state.optSemanticName = nullptr; state.semanticSlotCount = 0; state.stage = entryPoint->getStage(); - for( auto m : entryPointFuncDecl->Members ) - { - auto paramDecl = m.As<VarDeclBase>(); - if(!paramDecl) - continue; + // Second, we will compute offsets for any "ordinary" data + // in the parameter list (e.g., a `uniform float4x4 mvp` parameter), + // which is what the `ScopeLayoutBuilder` is designed to help with. + // + ScopeLayoutBuilder scopeBuilder; + scopeBuilder.beginLayout(context); + auto paramsStructLayout = scopeBuilder.m_structLayout; - // We have an entry-point parameter, and need to figure out what to do with it. + for( auto paramDecl : entryPointFuncDecl->getMembersOfType<ParamDecl>() ) + { + // Any error messages we emit during the process should + // refer to the location of this parameter. + // state.loc = paramDecl->loc; - // TODO: need to handle `uniform`-qualified parameters here - if (paramDecl->HasModifier<HLSLUniformModifier>()) - continue; - - state.directionMask = 0; - - // If it appears to be an input, process it as such. - if( paramDecl->HasModifier<InModifier>() || paramDecl->HasModifier<InOutModifier>() || !paramDecl->HasModifier<OutModifier>() ) - { - state.directionMask |= kEntryPointParameterDirection_Input; - } - - // If it appears to be an output, process it as such. - if(paramDecl->HasModifier<OutModifier>() || paramDecl->HasModifier<InOutModifier>()) - { - state.directionMask |= kEntryPointParameterDirection_Output; - } - + // We are going to construct the variable layout for this + // parameter *before* computing the type layout, because + // the type layout computation is also determining the effective + // semantic of the parameter, which needs to be stored + // back onto the `VarLayout`. + // RefPtr<VarLayout> paramVarLayout = new VarLayout(); paramVarLayout->varDecl = makeDeclRef(paramDecl.Ptr()); + paramVarLayout->stage = state.stage; - auto paramTypeLayout = processEntryPointParameterDecl( + auto paramTypeLayout = computeEntryPointParameterTypeLayout( context, - paramDecl.Ptr(), - paramDecl->type.type->Substitute(typeSubst).As<Type>(), - state, - paramVarLayout); + typeSubst, + paramDecl, + paramVarLayout, + state); + paramVarLayout->typeLayout = paramTypeLayout; - // Skip parameters for which we could not compute a layout + // We expect to always be able to compute a layout for + // entry-point parameters, but to be defensive we will + // skip parameters that couldn't have a layout computed + // when assertions are disabled. + // + SLANG_ASSERT(paramTypeLayout); if(!paramTypeLayout) continue; - paramVarLayout->typeLayout = paramTypeLayout; + // Now that we've computed the layout to use for the parameter, + // we need to add its resource usage to that of the entry + // point as a whole. + // + // Any "ordinary" data (e.g., a `float4x4`) needs to be accounted + // for using the `ScopeLayoutBuilder`, since it will handle + // the details of target-specific `struct` type layout. + // + scopeBuilder.addParameter(paramVarLayout); - for (auto rr : paramTypeLayout->resourceInfos) + // All of the other resources types will be handled in a + // simpler loop that just increments the relevant counters. + // + for (auto paramTypeResInfo : paramTypeLayout->resourceInfos) { - auto entryPointRes = entryPointLayout->findOrAddResourceInfo(rr.kind); - paramVarLayout->findOrAddResourceInfo(rr.kind)->index = entryPointRes->count.getFiniteValue(); - entryPointRes->count += rr.count; - } + // We need to skip ordinary data because it is being + // handled by the `scopeBuilder`. + // + if(paramTypeResInfo.kind == LayoutResourceKind::Uniform) + continue; - entryPointLayout->fields.Add(paramVarLayout); - entryPointLayout->mapVarToLayout.Add(paramDecl, paramVarLayout); + // Whatever resources the parameter uses, we need to + // assign the parameter's location/register/binding offset to + // be the sum of everything added so far. + // + auto entryPointResInfo = paramsStructLayout->findOrAddResourceInfo(paramTypeResInfo.kind); + paramVarLayout->findOrAddResourceInfo(paramTypeResInfo.kind)->index = entryPointResInfo->count.getFiniteValue(); + + // We then need to add the resources consumed by the parameter + // to those consumed by the entry point. + // + entryPointResInfo->count += paramTypeResInfo.count; + } } + entryPointLayout->parametersLayout = scopeBuilder.endLayout(); - // If we have a non-`void` output type for the entry point, then process it as - // an output parameter. + // For an entry point with a non-`void` return type, we need to process the + // return type as a varying output parameter. + // + // TODO: Ideally we should make the layout process more robust to empty/void + // types and apply this logic unconditionally. + // auto resultType = entryPointFuncDecl->ReturnType.type; if( !resultType->Equals(resultType->getSession()->getVoidType()) ) { @@ -2521,11 +2605,12 @@ static void collectEntryPointParameters( state.directionMask = kEntryPointParameterDirection_Output; RefPtr<VarLayout> resultLayout = new VarLayout(); + resultLayout->stage = state.stage; - auto resultTypeLayout = processEntryPointParameterDecl( + auto resultTypeLayout = processEntryPointVaryingParameterDecl( context, entryPointFuncDecl, - resultType->Substitute(typeSubst).As<Type>(), + resultType->Substitute(typeSubst).dynamicCast<Type>(), state, resultLayout); @@ -2535,7 +2620,7 @@ static void collectEntryPointParameters( for (auto rr : resultTypeLayout->resourceInfos) { - auto entryPointRes = entryPointLayout->findOrAddResourceInfo(rr.kind); + auto entryPointRes = paramsStructLayout->findOrAddResourceInfo(rr.kind); resultLayout->findOrAddResourceInfo(rr.kind)->index = entryPointRes->count.getFiniteValue(); entryPointRes->count += rr.count; } @@ -2687,6 +2772,7 @@ void generateParameterBindings( context.shared = &sharedContext; context.translationUnit = nullptr; context.layoutContext = layoutContext; + // Walk through AST to discover all the parameters collectParameters(&context, compileReq); @@ -2787,24 +2873,10 @@ void generateParameterBindings( // If there are any global-scope uniforms, then we need to // allocate a constant-buffer binding for them here. - ParameterBindingInfo globalConstantBufferBinding; - globalConstantBufferBinding.index = 0; - globalConstantBufferBinding.space = 0; - if( needDefaultConstantBuffer ) - { - // TODO: this logic is only correct for D3D targets, where - // global-scope uniforms get wrapped into a constant buffer. - - UInt space = sharedContext.defaultSpace; - auto usedRangeSet = findUsedRangeSetForSpace(&context, space); - - globalConstantBufferBinding.index = - usedRangeSet->usedResourceRanges[ - (int)LayoutResourceKind::ConstantBuffer].Allocate(nullptr, 1); - - globalConstantBufferBinding.space = space; - } - + // + ParameterBindingAndKindInfo globalConstantBufferBinding = maybeAllocateConstantBufferBinding( + &context, + needDefaultConstantBuffer); // Now walk through again to actually give everything // ranges of registers... @@ -2813,135 +2885,200 @@ void generateParameterBindings( completeBindingsForParameter(&context, parameter); } - // TODO: need to deal with parameters declared inside entry-point - // parameter lists at some point... - - - // Next we need to create a type layout to reflect the information - // we have collected. - - // We will lay out any bare uniforms at the global scope into - // a single constant buffer. This is appropriate for HLSL global-scope - // uniforms, and Vulkan GLSL doesn't allow uniforms at global scope, - // so it should work out. + // After we have allocated registers/bindings to everything + // in the global scope we will process the parameters + // of each entry point in order. // - // For legacy GLSL targets, we'd probably need a distinct resource - // kind and set of rules here, since legacy uniforms are not the - // same as the contents of a constant buffer. - auto globalScopeRules = context.getRulesFamily()->getConstantBufferRules(); - - RefPtr<StructTypeLayout> globalScopeStructLayout = new StructTypeLayout(); - globalScopeStructLayout->rules = globalScopeRules; - - UniformLayoutInfo structLayoutInfo = globalScopeRules->BeginStructLayout(); - for( auto& parameterInfo : sharedContext.parameters ) + // Note: the effect of the current implemetnation is to + // allocate non-overlapping registers/bindings between all + // the entry points in the compile request (e.g., if you + // have a vertex and fragment shader being compiled together, + // we will allocate distinct constant buffer registers for + // their uniform parameters). + // + // TODO: We probably need to provide some more nuanced control + // over whether entry points get overlapping or non-overlapping + // bindings. It seems clear that if we were compiling multiple + // compute kernels in one invocation we'd want them to get + // overlapping bindings, because we cannot ever have them bound + // together in a single pipeline state. + // + // Similarly, entry point parameters of DirectX Raytracing (DXR) + // shaders should probably be allowed to overlap by default, + // since those parameters should really go into the "local root signature." + // (Note: there is a bit more subtlety around ray tracing + // shaders that will be assembled into a "hit group") + // + // For now we are just doing the simplest thing, which will be + // appropriate for: + // + // * Compiling a single compute shader in a compile request. + // * Compiling some number of rasterization shader entry points + // in a single request, to be used together. + // * Compiling a single ray-tracing shader in a compile request. + // + for( auto entryPoint : sharedContext.programLayout->entryPoints ) { - SLANG_RELEASE_ASSERT(parameterInfo->varLayouts.Count() != 0); - auto firstVarLayout = parameterInfo->varLayouts.First(); - - // Does the field have any uniform data? - auto layoutInfo = firstVarLayout->typeLayout->FindResourceInfo(LayoutResourceKind::Uniform); - LayoutSize uniformSize = layoutInfo ? layoutInfo->count : 0; - if( uniformSize != 0 ) - { - // Make sure uniform fields get laid out properly... - - UniformLayoutInfo fieldInfo( - uniformSize, - firstVarLayout->typeLayout->uniformAlignment); - - LayoutSize uniformOffset = globalScopeRules->AddStructField( - &structLayoutInfo, - fieldInfo); - - for( auto& varLayout : parameterInfo->varLayouts ) - { - varLayout->findOrAddResourceInfo(LayoutResourceKind::Uniform)->index = uniformOffset.getFiniteValue(); - } - } - - globalScopeStructLayout->fields.Add(firstVarLayout); - - for( auto& varLayout : parameterInfo->varLayouts ) - { - globalScopeStructLayout->mapVarToLayout.Add(varLayout->varDecl.getDecl(), varLayout); - } + auto entryPointParamsLayout = entryPoint->parametersLayout; + completeBindingsForParameter(&context, entryPointParamsLayout); } - globalScopeRules->EndStructLayout(&structLayoutInfo); - RefPtr<TypeLayout> globalScopeLayout = globalScopeStructLayout; - - // If there are global-scope uniforms, then we need to wrap - // up a global constant buffer type layout to hold them - if( needDefaultConstantBuffer ) + // Next we need to create a type layout to reflect the information + // we have collected, and we will use the `ScopeLayoutBuilder` + // to encapsulate the logic that can be shared with the entry-point + // case. + // + ScopeLayoutBuilder globalScopeLayoutBuilder; + globalScopeLayoutBuilder.beginLayout(&context); + for( auto& parameterInfo : sharedContext.parameters ) { - auto globalConstantBufferLayout = createParameterGroupTypeLayout( - layoutContext, - nullptr, - globalScopeRules, - globalScopeRules->GetObjectLayout(ShaderParameterKind::ConstantBuffer), - globalScopeStructLayout); - - globalScopeLayout = globalConstantBufferLayout; + globalScopeLayoutBuilder.addParameter(parameterInfo); } - // We now have a bunch of layout information, which we should - // record into a suitable object that represents the program - RefPtr<VarLayout> globalVarLayout = new VarLayout(); - globalVarLayout->typeLayout = globalScopeLayout; - if (needDefaultConstantBuffer) + auto globalScopeVarLayout = globalScopeLayoutBuilder.endLayout(); + if( globalConstantBufferBinding.count != 0 ) { - auto cbInfo = globalVarLayout->findOrAddResourceInfo(LayoutResourceKind::ConstantBuffer); + auto cbInfo = globalScopeVarLayout->findOrAddResourceInfo(globalConstantBufferBinding.kind); cbInfo->space = globalConstantBufferBinding.space; cbInfo->index = globalConstantBufferBinding.index; } - programLayout->globalScopeLayout = globalVarLayout; + programLayout->parametersLayout = globalScopeVarLayout; } -StructTypeLayout* getGlobalStructLayout( - ProgramLayout* programLayout); - RefPtr<ProgramLayout> specializeProgramLayout( - TargetRequest * targetReq, - ProgramLayout* programLayout, + TargetRequest* targetReq, + ProgramLayout* oldProgramLayout, SubstitutionSet typeSubst) { + // The goal of the layout specialization step is to take an existing `ProgramLayout`, + // and add a layout to any parameter(s) that could not be laid out previously, because + // they had a dependence on generic type parameters that made layout impossible at + // the time. + // + // TODO: It would be far simpler to just "re-do" the entire layout process, just + // with knowledge of what the global type substitution is, but that would mean that + // global parameters that come after a generic-dependent parameter might change + // their location/binding/register depending on what types are plugged in. + // Our current design preserves the layout for any global parameter that was placed during + // the initial layout of a program (before the generic arguments were know). + // It isn't clear that this design choice pays off in practice, since there is lot + // of complexity in this function. + RefPtr<ProgramLayout> newProgramLayout; newProgramLayout = new ProgramLayout(); newProgramLayout->targetRequest = targetReq; - newProgramLayout->globalGenericParams = programLayout->globalGenericParams; - - List<RefPtr<TypeLayout>> paramTypeLayouts; - auto globalStructLayout = getGlobalStructLayout(programLayout); - SLANG_ASSERT(globalStructLayout); - RefPtr<StructTypeLayout> structLayout = new StructTypeLayout(); - RefPtr<TypeLayout> globalScopeLayout = structLayout; - structLayout->uniformAlignment = globalStructLayout->uniformAlignment; - - // Try to find rules based on the selected code-generation target - auto layoutContext = getInitialLayoutContextForTarget(targetReq); + newProgramLayout->globalGenericParams = oldProgramLayout->globalGenericParams; - // If there was no target, or there are no rules for the target, - // then bail out here. - if (!layoutContext.rules) - return newProgramLayout; + // The basic idea will be to iterate over the parameters in the old layout, + // and "pick up where we left off" in terms of allocating registers to things. + // + // That means we will look at the existing parameters (that were laid out already) + // and mark any registers/bytes/bindings/etc. that they occupy as "used" so + // that the subsequent layout of the generic-dependency parameters will not + // collide with them. + // + // We will use the same kind of context type as the original parameter binding + // step did, so we initialize its state here: + + auto layoutContext = getInitialLayoutContextForTarget(targetReq); + SLANG_ASSERT(layoutContext.rules); - - // we need to initialize a layout context to mark used registers SharedParameterBindingContext sharedContext; sharedContext.compileRequest = targetReq->compileRequest; sharedContext.defaultLayoutRules = layoutContext.getRulesFamily(); sharedContext.programLayout = newProgramLayout; sharedContext.targetRequest = targetReq; - // Create a sub-context to collect parameters that get - // declared into the global scope ParameterBindingContext context; context.shared = &sharedContext; context.translationUnit = nullptr; context.layoutContext = layoutContext; - - + + // We will also need state for laying out any global-scope parameters + // that include ordinary/uniform data. + // + auto oldGlobalStructLayout = getGlobalStructLayout(oldProgramLayout); + SLANG_ASSERT(oldGlobalStructLayout); + + ScopeLayoutBuilder newGlobalScopeLayoutBuilder; + newGlobalScopeLayoutBuilder.beginLayout(&context); + auto& newGlobalStructLayoutInfo = newGlobalScopeLayoutBuilder.m_structLayoutInfo; + auto newGlobalStructLayout = newGlobalScopeLayoutBuilder.m_structLayout; + + // The initial state for uniform layout will be based on whatever + // global-scope ordinary/uniform parameters were laid out before. + // The alignment can be read directly from the old global layout. + // + newGlobalStructLayoutInfo.alignment = oldGlobalStructLayout->uniformAlignment; + newGlobalStructLayoutInfo.size = 0; + + // The remaining information needs to be collected by looking at + // the individual parameters in the existing layout. + // + bool oldAnyUniforms = false; + for(auto oldVarLayout : oldGlobalStructLayout->fields) + { + // If a parameter made use of a global generic parameter, then we would + // have skipped applying layout to it in the original layout process, + // and so we should skip it for the process of recovering the existing + // layout information. + // + if (oldVarLayout->FindResourceInfo(LayoutResourceKind::GenericResource)) + continue; + + // Otherwise, we will "reserve" any resources that the parameter was + // determined to consume. + // + // The easy case is any registers/bindings used for textures/sampler/etc. + // We iterate over the kinds of resources consumed by teh parameter. + // + for( auto varResInfo : oldVarLayout->resourceInfos ) + { + // For each kind of resource consumed the `varResInfo` will tell us + // the start of the consumed range, whle the type will be needed + // to tell us the amount of resources consumed. + // + if( auto typeResInfo = oldVarLayout->typeLayout->FindResourceInfo(varResInfo.kind) ) + { + // We will mark the range of resources consumed by theis parameter + // as "used" so that it cannot be claimed by later parameters. + // + auto usedRangeSet = findUsedRangeSetForSpace(&context, varResInfo.space); + markSpaceUsed(&context, varResInfo.space); + usedRangeSet->usedResourceRanges[(int)varResInfo.kind].Add( + nullptr, // we don't need to track parameter info here + varResInfo.index, + varResInfo.index + typeResInfo->count); + } + } + + // The more subtle case is when the parameter consumes ordinary bytes + // of uniform (constant buffer) memory, because we do not use the + // same "used range" model to allocate space for ordinary data. + // + // Instead, we simply track the highest byte offset covered by any parameter. + // + if (auto varUniformInfo = oldVarLayout->FindResourceInfo(LayoutResourceKind::Uniform)) + { + oldAnyUniforms = true; + + if( auto typeUniformInfo = oldVarLayout->typeLayout->FindResourceInfo(LayoutResourceKind::Uniform) ) + { + newGlobalStructLayoutInfo.size = maximum( + newGlobalStructLayoutInfo.size, + varUniformInfo->index + typeUniformInfo->count); + } + } + } + + // Rather than attempt to re-use the entry-point layout information + // that was collected in the first pass, we will re-collect the + // information for entry points from scratch. + // + // This ensures that when an entry point makes use of a generic type + // parameter, the layout of its parameter list strictly follows + // the declaration order. + // for (auto & translationUnit : targetReq->compileRequest->translationUnits) { for (auto & entryPoint : translationUnit->entryPoints) @@ -2951,137 +3088,145 @@ RefPtr<ProgramLayout> specializeProgramLayout( context.entryPointLayout = nullptr; } - auto constantBufferRules = context.getRulesFamily()->getConstantBufferRules(); - structLayout->rules = constantBufferRules; - structLayout->fields.SetSize(globalStructLayout->fields.Count()); - UniformLayoutInfo structLayoutInfo; - structLayoutInfo.alignment = globalStructLayout->uniformAlignment; - structLayoutInfo.size = 0; - bool anyUniforms = false; - Dictionary<RefPtr<VarLayout>, RefPtr<VarLayout>> varLayoutMapping; - for (uint32_t varId = 0; varId < globalStructLayout->fields.Count(); varId++) + // Now that we've marked thing as being used, we can make a second + // sweep to compute the requirements of any generic-dependent parameters. + // + // Along the way we will build up the new layout for the global-scope + // structure type, including the offsets of all ordinary/uniform fields. + // + + bool newAnyUniforms = oldAnyUniforms; + List<RefPtr<VarLayout>> newVarLayouts; + Dictionary<RefPtr<VarLayout>, RefPtr<VarLayout>> mapOldLayoutToNew; + for(auto oldVarLayout : oldGlobalStructLayout->fields) { - auto &varLayout = globalStructLayout->fields[varId]; - // To recover layout context, we skip generic resources in the first pass - if (varLayout->FindResourceInfo(LayoutResourceKind::GenericResource)) + // In this pass, the variables that *don't* depend on generic parameters + // are the easy ones to handle. We can just copy them over to the new layout. + // + if(!oldVarLayout->FindResourceInfo(LayoutResourceKind::GenericResource)) + { + newGlobalStructLayout->fields.Add(oldVarLayout); continue; + } - if (auto uniformInfo = varLayout->FindResourceInfo(LayoutResourceKind::Uniform)) - { - anyUniforms = true; + // In the case where things are generic-dependent, we need to re-do + // the type layout process on the type that results from doing + // substutition with the global generic arguments. + // + RefPtr<Type> oldType = oldVarLayout->getTypeLayout()->getType(); + RefPtr<Type> newType = oldType->Substitute(typeSubst).as<Type>(); - if( auto tUniformInfo = varLayout->typeLayout->FindResourceInfo(LayoutResourceKind::Uniform) ) - { - structLayoutInfo.size = maximum(structLayoutInfo.size, uniformInfo->index + tUniformInfo->count); - } - } - for( auto resInfo : varLayout->resourceInfos ) + RefPtr<TypeLayout> newTypeLayout = getTypeLayoutForGlobalShaderParameter( + &context, + oldVarLayout->varDecl, + newType); + + RefPtr<VarLayout> newVarLayout = new VarLayout(); + newVarLayout->varDecl = oldVarLayout->varDecl; + newVarLayout->stage = oldVarLayout->stage; + newVarLayout->typeLayout = newTypeLayout; + + newGlobalScopeLayoutBuilder.addParameter(newVarLayout); + newVarLayouts.Add(newVarLayout); + mapOldLayoutToNew.Add(oldVarLayout, newVarLayout); + + if(auto uniformInfo = newTypeLayout->FindResourceInfo(LayoutResourceKind::Uniform)) { - if( auto tresInfo = varLayout->typeLayout->FindResourceInfo(resInfo.kind) ) + if(uniformInfo->count != 0) { - auto usedRangeSet = findUsedRangeSetForSpace(&context, resInfo.space); - markSpaceUsed(&context, resInfo.space); - usedRangeSet->usedResourceRanges[(int)resInfo.kind].Add( - nullptr, // we don't need to track parameter info here - resInfo.index, - resInfo.index + tresInfo->count); + newAnyUniforms = true; + diagnoseGlobalUniform(&sharedContext, newVarLayout->varDecl); } } - structLayout->fields[varId] = varLayout; - varLayoutMapping[varLayout] = varLayout; } - auto originalGlobalCBufferInfo = programLayout->globalScopeLayout->FindResourceInfo(LayoutResourceKind::ConstantBuffer); - VarLayout::ResourceInfo globalCBufferInfo; - globalCBufferInfo.kind = LayoutResourceKind::None; - globalCBufferInfo.space = 0; - globalCBufferInfo.index = 0; - if (originalGlobalCBufferInfo) + auto newGlobalScopeVarLayout = newGlobalScopeLayoutBuilder.endLayout(); + + // We had better have made a copy of every field in the original layout. + // + SLANG_ASSERT(oldGlobalStructLayout->fields.Count() == newGlobalStructLayout->fields.Count()); + + // If there were no global-scope uniforms before, but there + // are now that we've done global substitution, then we + // need to allocate a global constant buffer to hold them. + // + auto newGlobalConstantBufferBinding = maybeAllocateConstantBufferBinding(&context, newAnyUniforms && !oldAnyUniforms); + + // Now we need to "complete" finding for each of the new parameters, + // which is the step that actually allocates resource to them. + // + // Note: we don't support generic-dependent parameters with explicit bindings, + // so we should probably emit an error message about that in the original + // layout step. + // + for(auto newVarLayout : newVarLayouts) { - globalCBufferInfo.kind = LayoutResourceKind::ConstantBuffer; - globalCBufferInfo.space = originalGlobalCBufferInfo->space; - globalCBufferInfo.index = originalGlobalCBufferInfo->index; + completeBindingsForParameter(&context, newVarLayout); } - // we have the context restored, can continue to layout the generic variables now - for (uint32_t varId = 0; varId < globalStructLayout->fields.Count(); varId++) + + // One remaining missing step is that the `StructLayout` type maintains + // a map from variable declarations to their layouts, and in some cases + // multiple declarations will map to the same layout (because, e.g., the + // same `cbuffer` was declared in both a vertex and fragment shader file). + // + // We need to clone that remapping information over from the old program + // layout. This is why we created the `mapOldLayoutToNew` mapping in + // the preceding loop. + // + // TODO: This step would be easier if the `StructLayout::mapVarToLayout` + // dictionary were instead a mapping from variable declaration to the + // *index* of the corresponding layout in the `fields` array. + // + for(auto entry : oldGlobalStructLayout->mapVarToLayout) { - auto &varLayout = globalStructLayout->fields[varId]; - if (varLayout->typeLayout->FindResourceInfo(LayoutResourceKind::GenericResource)) - { - RefPtr<Type> newType = varLayout->typeLayout->type->Substitute(typeSubst).As<Type>(); - RefPtr<TypeLayout> newTypeLayout = CreateTypeLayout( - layoutContext.with(constantBufferRules), - newType); - auto layoutInfo = newTypeLayout->FindResourceInfo(LayoutResourceKind::Uniform); - LayoutSize uniformSize = layoutInfo ? layoutInfo->count : 0; - if (uniformSize != 0) - { - if (globalCBufferInfo.kind == LayoutResourceKind::None) - { - // user defined a uniform via a global generic type argument - // but we have not reserved a binding for the global uniform buffer - UInt space = 0; - auto usedRangeSet = findUsedRangeSetForSpace(&context, space); - globalCBufferInfo.kind = LayoutResourceKind::ConstantBuffer; - globalCBufferInfo.index = - usedRangeSet->usedResourceRanges[ - (int)LayoutResourceKind::ConstantBuffer].Allocate(nullptr, 1); - globalCBufferInfo.space = space; - } - } - RefPtr<VarLayout> newVarLayout = new VarLayout(); - RefPtr<ParameterInfo> paramInfo = new ParameterInfo(); - newVarLayout->varDecl = varLayout->varDecl; - newVarLayout->stage = varLayout->stage; - newVarLayout->typeLayout = newTypeLayout; - paramInfo->varLayouts.Add(newVarLayout); - completeBindingsForParameter(&context, paramInfo); - // update uniform layout - - if (uniformSize != 0) - { - // Make sure uniform fields get laid out properly... - UniformLayoutInfo fieldInfo( - uniformSize, - newTypeLayout->uniformAlignment); - LayoutSize uniformOffset = layoutContext.getRulesFamily()->getConstantBufferRules()->AddStructField( - &structLayoutInfo, - fieldInfo); - newVarLayout->findOrAddResourceInfo(LayoutResourceKind::Uniform)->index = uniformOffset.getFiniteValue(); - anyUniforms = true; - - diagnoseGlobalUniform(&sharedContext, varLayout->varDecl); - } - structLayout->fields[varId] = newVarLayout; - varLayoutMapping[varLayout] = newVarLayout; - } + RefPtr<VarLayout> varLayout = entry.Value; + mapOldLayoutToNew.TryGetValue(varLayout, varLayout); + newGlobalStructLayout->mapVarToLayout[entry.Key] = varLayout; } - for (auto mapping : globalStructLayout->mapVarToLayout) + + // Just as for the initial computation of layout, we will complete + // binding for entry-point parameters *after* we have laid out + // all the global-scope parameters. + // + // Note that this includes layout of generic-dependent global scope + // parameters, so it is possible for entry point uniform parameters + // to end up with a different register/binding after generic specialization. + // (There really isn't a great way around that) + // + for( auto entryPoint : sharedContext.programLayout->entryPoints ) { - RefPtr<VarLayout> updatedVarLayout = mapping.Value; - varLayoutMapping.TryGetValue(updatedVarLayout, updatedVarLayout); - structLayout->mapVarToLayout[mapping.Key] = updatedVarLayout; + auto entryPointParamsLayout = entryPoint->parametersLayout; + completeBindingsForParameter(&context, entryPointParamsLayout); } - // If there are global-scope uniforms, then we need to wrap - // up a global constant buffer type layout to hold them - RefPtr<VarLayout> globalVarLayout = new VarLayout(); - if (anyUniforms) + // As a last step we need to set up the binding/offset information + // for the global scope itself. + // + // We will start by copying whatever information was in the old layout. + // { - auto globalConstantBufferLayout = createParameterGroupTypeLayout( - layoutContext, - nullptr, - constantBufferRules, - constantBufferRules->GetObjectLayout(ShaderParameterKind::ConstantBuffer), - structLayout); + auto oldGlobalScopeVarLayout = oldProgramLayout->parametersLayout; + for( auto oldResInfo : oldGlobalScopeVarLayout->resourceInfos ) + { + auto newResInfo = newGlobalScopeVarLayout->findOrAddResourceInfo(oldResInfo.kind); + newResInfo->space = oldResInfo.space; + newResInfo->kind = oldResInfo.kind; + } + } - globalScopeLayout = globalConstantBufferLayout; - auto cbInfo = globalVarLayout->findOrAddResourceInfo(LayoutResourceKind::ConstantBuffer); - *cbInfo = globalCBufferInfo; + // If we had to create a constant buffer to house the global-scope + // ordinary/uniform data, then we need to make sure to set that + // information on the global scope. + // + if(newGlobalConstantBufferBinding.kind != LayoutResourceKind::None ) + { + auto resInfo = newGlobalScopeVarLayout->findOrAddResourceInfo(newGlobalConstantBufferBinding.kind); + resInfo->space = newGlobalConstantBufferBinding.space; + resInfo->index = newGlobalConstantBufferBinding.index; } - globalVarLayout->typeLayout = globalScopeLayout; - programLayout->globalScopeLayout = globalVarLayout; - newProgramLayout->globalScopeLayout = globalVarLayout; + + newProgramLayout->parametersLayout = newGlobalScopeVarLayout; return newProgramLayout; } -} + +} // namespace Slang diff --git a/source/slang/parser.cpp b/source/slang/parser.cpp index 914ec0a23..4c96929d8 100644 --- a/source/slang/parser.cpp +++ b/source/slang/parser.cpp @@ -21,7 +21,7 @@ namespace Slang void add(Modifier* modifier) { // Doesn't handle SharedModifiers - SLANG_ASSERT(modifier->As<SharedModifiers>() == nullptr); + SLANG_ASSERT(as<SharedModifiers>(modifier) == nullptr); // Splice at end *m_next = modifier; @@ -33,7 +33,7 @@ namespace Slang Modifier* cur = m_result; while (cur) { - T* castCur = cur->As<T>(); + T* castCur = as<T>(cur); if (castCur) { return castCur; @@ -619,7 +619,7 @@ namespace Slang // About to look at shared modifiers? Done. RefPtr<Modifier> linkMod = *modifierLink; - if(linkMod.As<SharedModifiers>()) + if(as<SharedModifiers>(linkMod)) { break; } @@ -821,7 +821,7 @@ namespace Slang auto keywordToken = advanceToken(parser); RefPtr<RefObject> parsedObject = syntaxDecl->parseCallback(parser, syntaxDecl->parseUserData); - auto syntax = parsedObject.As<T>(); + auto syntax = dynamicCast<T>(parsedObject); if (syntax) { @@ -1701,7 +1701,7 @@ namespace Slang RefPtr<Expr> base) { Name * baseName = nullptr; - if (auto varExpr = base->As<VarExpr>()) + if (auto varExpr = as<VarExpr>(base)) baseName = varExpr->name; // if base is a known generics, parse as generics if (baseName && isGenericName(parser, baseName)) @@ -1800,14 +1800,14 @@ namespace Slang // // TODO: We should really make these keywords be registered like any other // syntax category, rather than be special-cased here. The main issue here - // is that we need to allow them to be used as type specififers, as in: + // is that we need to allow them to be used as type specifiers, as in: // // struct Foo { int x } foo; // // The ideal answer would be to register certain keywords as being able - // to parse a type specififer, and look for those keywords here. + // to parse a type specifier, and look for those keywords here. // We should ideally add special case logic that bails out of declarator - // parsing iff we have one of these kinds of type specififers and the + // parsing iff we have one of these kinds of type specifiers and the // closing `}` is at the end of its line, as a bit of a special case // to allow the common idiom. // @@ -2997,7 +2997,7 @@ namespace Slang // then we really want the modifiers to apply to the inner declaration. // RefPtr<Decl> declToModify = decl; - if(auto genericDecl = decl.As<GenericDecl>()) + if(auto genericDecl = decl.as<GenericDecl>()) declToModify = genericDecl->inner; AddModifiers(declToModify.Ptr(), modifiers.first); @@ -3024,7 +3024,7 @@ namespace Slang // A declaration that starts with an identifier might be: // // - A keyword-based declaration (e.g., `cbuffer ...`) - // - The begining of a type in a declarator-based declaration (e.g., `int ...`) + // - The beginning of a type in a declarator-based declaration (e.g., `int ...`) // - A GLSL block declaration (e.g., `uniform Foo { ... }`) // Let's deal with the GLSL block case first. This is something like: @@ -3084,11 +3084,11 @@ namespace Slang if (decl) { - if( auto dd = decl.As<Decl>() ) + if( auto dd = as<Decl>(decl) ) { CompleteDecl(parser, dd, containerDecl, modifiers); } - else if(auto declGroup = decl.As<DeclGroup>()) + else if(auto declGroup = as<DeclGroup>(decl)) { // We are going to add the same modifiers to *all* of these declarations, // so we want to give later passes a way to detect which modifiers @@ -3122,11 +3122,11 @@ namespace Slang auto declBase = ParseDecl(parser, containerDecl); if(!declBase) return nullptr; - if( auto decl = declBase.As<Decl>() ) + if( auto decl = as<Decl>(declBase) ) { return decl; } - else if( auto declGroup = declBase.As<DeclGroup>() ) + else if( auto declGroup = as<DeclGroup>(declBase) ) { if( declGroup->decls.Count() == 1 ) { @@ -3496,7 +3496,7 @@ namespace Slang statement = ParseExpressionStatement(); } - if (statement && !statement->As<DeclStmt>()) + if (statement && !as<DeclStmt>(statement)) { // Install any modifiers onto the statement. // Note: this path is bypassed in the case of a @@ -3531,7 +3531,7 @@ namespace Slang { body = stmt; } - else if (auto seqStmt = body.As<SeqStmt>()) + else if (auto seqStmt = as<SeqStmt>(body)) { seqStmt->stmts.Add(stmt); } diff --git a/source/slang/reflection.cpp b/source/slang/reflection.cpp index 9a5a5faf9..ce42cf10d 100644 --- a/source/slang/reflection.cpp +++ b/source/slang/reflection.cpp @@ -159,7 +159,7 @@ SLANG_API SlangResult spReflectionUserAttribute_GetArgumentValueInt(SlangReflect RefPtr<RefObject> val; if (userAttr->intArgVals.TryGetValue(index, val)) { - *rs = (int)val.As<ConstantIntVal>()->value; + *rs = (int)as<ConstantIntVal>(val)->value; return 0; } return SLANG_ERROR_INVALID_PARAMETER; @@ -169,7 +169,7 @@ SLANG_API SlangResult spReflectionUserAttribute_GetArgumentValueFloat(SlangRefle auto userAttr = convert(attrib); if (!userAttr) return SLANG_ERROR_INVALID_PARAMETER; if (index >= userAttr->args.Count()) return SLANG_ERROR_INVALID_PARAMETER; - if (auto cexpr = userAttr->args[index].As<FloatingPointLiteralExpr>()) + if (auto cexpr = as<FloatingPointLiteralExpr>(userAttr->args[index])) { *rs = (float)cexpr->value; return 0; @@ -181,7 +181,7 @@ SLANG_API const char* spReflectionUserAttribute_GetArgumentValueString(SlangRefl auto userAttr = convert(attrib); if (!userAttr) return nullptr; if (index >= userAttr->args.Count()) return nullptr; - if (auto cexpr = userAttr->args[index].As<StringLiteralExpr>()) + if (auto cexpr = as<StringLiteralExpr>(userAttr->args[index])) { if (bufLen) *bufLen = cexpr->token.Content.size(); @@ -202,49 +202,49 @@ SLANG_API SlangTypeKind spReflectionType_GetKind(SlangReflectionType* inType) // TODO(tfoley: Don't emit the same type more than once... - if (auto basicType = type->As<BasicExpressionType>()) + if (auto basicType = as<BasicExpressionType>(type)) { return SLANG_TYPE_KIND_SCALAR; } - else if (auto vectorType = type->As<VectorExpressionType>()) + else if (auto vectorType = as<VectorExpressionType>(type)) { return SLANG_TYPE_KIND_VECTOR; } - else if (auto matrixType = type->As<MatrixExpressionType>()) + else if (auto matrixType = as<MatrixExpressionType>(type)) { return SLANG_TYPE_KIND_MATRIX; } - else if (auto parameterBlockType = type->As<ParameterBlockType>()) + else if (auto parameterBlockType = as<ParameterBlockType>(type)) { return SLANG_TYPE_KIND_PARAMETER_BLOCK; } - else if (auto constantBufferType = type->As<ConstantBufferType>()) + else if (auto constantBufferType = as<ConstantBufferType>(type)) { return SLANG_TYPE_KIND_CONSTANT_BUFFER; } - else if( auto streamOutputType = type->As<HLSLStreamOutputType>() ) + else if( auto streamOutputType = as<HLSLStreamOutputType>(type) ) { return SLANG_TYPE_KIND_OUTPUT_STREAM; } - else if (type->As<TextureBufferType>()) + else if (as<TextureBufferType>(type)) { return SLANG_TYPE_KIND_TEXTURE_BUFFER; } - else if (type->As<GLSLShaderStorageBufferType>()) + else if (as<GLSLShaderStorageBufferType>(type)) { return SLANG_TYPE_KIND_SHADER_STORAGE_BUFFER; } - else if (auto samplerStateType = type->As<SamplerStateType>()) + else if (auto samplerStateType = as<SamplerStateType>(type)) { return SLANG_TYPE_KIND_SAMPLER_STATE; } - else if (auto textureType = type->As<TextureTypeBase>()) + else if (auto textureType = as<TextureTypeBase>(type)) { return SLANG_TYPE_KIND_RESOURCE; } // TODO: need a better way to handle this stuff... #define CASE(TYPE) \ - else if(type->As<TYPE>()) do { \ + else if(as<TYPE>(type)) do { \ return SLANG_TYPE_KIND_RESOURCE; \ } while(0) @@ -259,27 +259,27 @@ SLANG_API SlangTypeKind spReflectionType_GetKind(SlangReflectionType* inType) CASE(UntypedBufferResourceType); #undef CASE - else if (auto arrayType = type->As<ArrayExpressionType>()) + else if (auto arrayType = as<ArrayExpressionType>(type)) { return SLANG_TYPE_KIND_ARRAY; } - else if( auto declRefType = type->As<DeclRefType>() ) + else if( auto declRefType = as<DeclRefType>(type) ) { auto declRef = declRefType->declRef; - if( auto structDeclRef = declRef.As<StructDecl>() ) + if( auto structDeclRef = declRef.as<StructDecl>() ) { return SLANG_TYPE_KIND_STRUCT; } - else if (auto genericParamType = declRef.As<GlobalGenericParamDecl>()) + else if (auto genericParamType = declRef.as<GlobalGenericParamDecl>()) { return SLANG_TYPE_KIND_GENERIC_TYPE_PARAMETER; } - else if (auto interfaceType = declRef.As<InterfaceDecl>()) + else if (auto interfaceType = declRef.as<InterfaceDecl>()) { return SLANG_TYPE_KIND_INTERFACE; } } - else if (auto errorType = type->As<ErrorType>()) + else if (auto errorType = as<ErrorType>(type)) { // This means we saw a type we didn't understand in the user's code return SLANG_TYPE_KIND_NONE; @@ -296,10 +296,10 @@ SLANG_API unsigned int spReflectionType_GetFieldCount(SlangReflectionType* inTyp // TODO: maybe filter based on kind - if(auto declRefType = type->As<DeclRefType>()) + if(auto declRefType = as<DeclRefType>(type)) { auto declRef = declRefType->declRef; - if( auto structDeclRef = declRef.As<StructDecl>()) + if( auto structDeclRef = declRef.as<StructDecl>()) { return GetFields(structDeclRef).Count(); } @@ -315,10 +315,10 @@ SLANG_API SlangReflectionVariable* spReflectionType_GetFieldByIndex(SlangReflect // TODO: maybe filter based on kind - if(auto declRefType = type->As<DeclRefType>()) + if(auto declRefType = as<DeclRefType>(type)) { auto declRef = declRefType->declRef; - if( auto structDeclRef = declRef.As<StructDecl>()) + if( auto structDeclRef = declRef.as<StructDecl>()) { auto fieldDeclRef = GetFields(structDeclRef).ToArray()[index]; return (SlangReflectionVariable*) fieldDeclRef.getDecl(); @@ -333,11 +333,11 @@ SLANG_API size_t spReflectionType_GetElementCount(SlangReflectionType* inType) auto type = convert(inType); if(!type) return 0; - if(auto arrayType = type->As<ArrayExpressionType>()) + if(auto arrayType = as<ArrayExpressionType>(type)) { return arrayType->ArrayLength ? (size_t) GetIntVal(arrayType->ArrayLength) : 0; } - else if( auto vectorType = type->As<VectorExpressionType>()) + else if( auto vectorType = as<VectorExpressionType>(type)) { return (size_t) GetIntVal(vectorType->elementCount); } @@ -350,19 +350,19 @@ SLANG_API SlangReflectionType* spReflectionType_GetElementType(SlangReflectionTy auto type = convert(inType); if(!type) return nullptr; - if(auto arrayType = type->As<ArrayExpressionType>()) + if(auto arrayType = as<ArrayExpressionType>(type)) { return (SlangReflectionType*) arrayType->baseType.Ptr(); } - else if( auto constantBufferType = type->As<ConstantBufferType>()) + else if( auto constantBufferType = as<ConstantBufferType>(type)) { return convert(constantBufferType->elementType.Ptr()); } - else if( auto vectorType = type->As<VectorExpressionType>()) + else if( auto vectorType = as<VectorExpressionType>(type)) { return convert(vectorType->elementType.Ptr()); } - else if( auto matrixType = type->As<MatrixExpressionType>()) + else if( auto matrixType = as<MatrixExpressionType>(type)) { return convert(matrixType->getElementType()); } @@ -375,15 +375,15 @@ SLANG_API unsigned int spReflectionType_GetRowCount(SlangReflectionType* inType) auto type = convert(inType); if(!type) return 0; - if(auto matrixType = type->As<MatrixExpressionType>()) + if(auto matrixType = as<MatrixExpressionType>(type)) { return (unsigned int) GetIntVal(matrixType->getRowCount()); } - else if(auto vectorType = type->As<VectorExpressionType>()) + else if(auto vectorType = as<VectorExpressionType>(type)) { return 1; } - else if( auto basicType = type->As<BasicExpressionType>() ) + else if( auto basicType = as<BasicExpressionType>(type) ) { return 1; } @@ -396,15 +396,15 @@ SLANG_API unsigned int spReflectionType_GetColumnCount(SlangReflectionType* inTy auto type = convert(inType); if(!type) return 0; - if(auto matrixType = type->As<MatrixExpressionType>()) + if(auto matrixType = as<MatrixExpressionType>(type)) { return (unsigned int) GetIntVal(matrixType->getColumnCount()); } - else if(auto vectorType = type->As<VectorExpressionType>()) + else if(auto vectorType = as<VectorExpressionType>(type)) { return (unsigned int) GetIntVal(vectorType->elementCount); } - else if( auto basicType = type->As<BasicExpressionType>() ) + else if( auto basicType = as<BasicExpressionType>(type) ) { return 1; } @@ -417,16 +417,16 @@ SLANG_API SlangScalarType spReflectionType_GetScalarType(SlangReflectionType* in auto type = convert(inType); if(!type) return 0; - if(auto matrixType = type->As<MatrixExpressionType>()) + if(auto matrixType = as<MatrixExpressionType>(type)) { type = matrixType->getElementType(); } - else if(auto vectorType = type->As<VectorExpressionType>()) + else if(auto vectorType = as<VectorExpressionType>(type)) { type = vectorType->elementType.Ptr(); } - if(auto basicType = type->As<BasicExpressionType>()) + if(auto basicType = as<BasicExpressionType>(type)) { switch (basicType->baseType) { @@ -463,7 +463,7 @@ SLANG_API unsigned int spReflectionType_GetUserAttributeCount(SlangReflectionTyp { auto type = convert(inType); if (!type) return 0; - if (auto declRefType = type->AsDeclRefType()) + if (auto declRefType = as<DeclRefType>(type)) { return getUserAttributeCount(declRefType->declRef.getDecl()); } @@ -473,7 +473,7 @@ SLANG_API SlangReflectionUserAttribute* spReflectionType_GetUserAttribute(SlangR { auto type = convert(inType); if (!type) return 0; - if (auto declRefType = type->AsDeclRefType()) + if (auto declRefType = as<DeclRefType>(type)) { return getUserAttributeByIndex(declRefType->declRef.getDecl(), index); } @@ -483,7 +483,7 @@ SLANG_API SlangReflectionUserAttribute* spReflectionType_FindUserAttributeByName { auto type = convert(inType); if (!type) return 0; - if (auto declRefType = type->AsDeclRefType()) + if (auto declRefType = as<DeclRefType>(type)) { return findUserAttributeByName(declRefType->getSession(), declRefType->declRef.getDecl(), name); } @@ -495,19 +495,19 @@ SLANG_API SlangResourceShape spReflectionType_GetResourceShape(SlangReflectionTy auto type = convert(inType); if(!type) return 0; - while(auto arrayType = type->As<ArrayExpressionType>()) + while(auto arrayType = as<ArrayExpressionType>(type)) { type = arrayType->baseType.Ptr(); } - if(auto textureType = type->As<TextureTypeBase>()) + if(auto textureType = as<TextureTypeBase>(type)) { return textureType->getShape(); } // TODO: need a better way to handle this stuff... #define CASE(TYPE, SHAPE, ACCESS) \ - else if(type->As<TYPE>()) do { \ + else if(as<TYPE>(type)) do { \ return SHAPE; \ } while(0) @@ -530,19 +530,19 @@ SLANG_API SlangResourceAccess spReflectionType_GetResourceAccess(SlangReflection auto type = convert(inType); if(!type) return 0; - while(auto arrayType = type->As<ArrayExpressionType>()) + while(auto arrayType = as<ArrayExpressionType>(type)) { type = arrayType->baseType.Ptr(); } - if(auto textureType = type->As<TextureTypeBase>()) + if(auto textureType = as<TextureTypeBase>(type)) { return textureType->getAccess(); } // TODO: need a better way to handle this stuff... #define CASE(TYPE, SHAPE, ACCESS) \ - else if(type->As<TYPE>()) do { \ + else if(as<TYPE>(type)) do { \ return ACCESS; \ } while(0) @@ -567,7 +567,7 @@ SLANG_API char const* spReflectionType_GetName(SlangReflectionType* inType) { auto type = convert(inType); - if( auto declRefType = type->As<DeclRefType>() ) + if( auto declRefType = as<DeclRefType>(type) ) { auto declRef = declRefType->declRef; @@ -613,20 +613,20 @@ SLANG_API SlangReflectionType* spReflectionType_GetResourceResultType(SlangRefle auto type = convert(inType); if(!type) return nullptr; - while(auto arrayType = type->As<ArrayExpressionType>()) + while(auto arrayType = as<ArrayExpressionType>(type)) { type = arrayType->baseType.Ptr(); } - if (auto textureType = type->As<TextureTypeBase>()) + if (auto textureType = as<TextureTypeBase>(type)) { return convert(textureType->elementType.Ptr()); } // TODO: need a better way to handle this stuff... #define CASE(TYPE, SHAPE, ACCESS) \ - else if(type->As<TYPE>()) do { \ - return convert(type->As<TYPE>()->elementType.Ptr()); \ + else if(as<TYPE>(type)) do { \ + return convert(as<TYPE>(type)->elementType.Ptr()); \ } while(0) // TODO: structured buffer needs to expose type layout! @@ -695,7 +695,7 @@ SLANG_API size_t spReflectionTypeLayout_GetElementStride(SlangReflectionTypeLayo { switch (category) { - // We store the stride explictly for the uniform case + // We store the stride explicitly for the uniform case case SLANG_PARAMETER_CATEGORY_UNIFORM: return arrayTypeLayout->uniformStride; @@ -950,7 +950,7 @@ namespace Slang // Is the category they were asking about one that makes sense for the type // of this variable? Type* type = typeLayout->getType(); - while (auto arrayType = type->As<ArrayExpressionType>()) + while (auto arrayType = as<ArrayExpressionType>(type)) type = arrayType->baseType; switch (spReflectionType_GetKind(convert(type))) { @@ -1109,12 +1109,12 @@ namespace Slang { static unsigned getParameterCount(RefPtr<TypeLayout> typeLayout) { - if(auto parameterGroupLayout = typeLayout.As<ParameterGroupTypeLayout>()) + if(auto parameterGroupLayout = as<ParameterGroupTypeLayout>(typeLayout)) { typeLayout = parameterGroupLayout->offsetElementTypeLayout; } - if(auto structLayout = typeLayout.As<StructTypeLayout>()) + if(auto structLayout = as<StructTypeLayout>(typeLayout)) { return (unsigned) structLayout->fields.Count(); } @@ -1124,12 +1124,12 @@ namespace Slang static VarLayout* getParameterByIndex(RefPtr<TypeLayout> typeLayout, unsigned index) { - if(auto parameterGroupLayout = typeLayout.As<ParameterGroupTypeLayout>()) + if(auto parameterGroupLayout = as<ParameterGroupTypeLayout>(typeLayout)) { typeLayout = parameterGroupLayout->offsetElementTypeLayout; } - if(auto structLayout = typeLayout.As<StructTypeLayout>()) + if(auto structLayout = as<StructTypeLayout>(typeLayout)) { return structLayout->fields[index]; } @@ -1155,7 +1155,7 @@ SLANG_API unsigned spReflectionEntryPoint_getParameterCount( auto entryPointLayout = convert(inEntryPoint); if(!entryPointLayout) return 0; - return getParameterCount(entryPointLayout); + return getParameterCount(entryPointLayout->parametersLayout->typeLayout); } SLANG_API SlangReflectionVariableLayout* spReflectionEntryPoint_getParameterByIndex( @@ -1165,7 +1165,7 @@ SLANG_API SlangReflectionVariableLayout* spReflectionEntryPoint_getParameterByIn auto entryPointLayout = convert(inEntryPoint); if(!entryPointLayout) return 0; - return convert(getParameterByIndex(entryPointLayout, index)); + return convert(getParameterByIndex(entryPointLayout->parametersLayout->typeLayout, index)); } SLANG_API SlangStage spReflectionEntryPoint_getStage(SlangReflectionEntryPoint* inEntryPoint) @@ -1276,12 +1276,6 @@ SLANG_API SlangReflectionType* spReflectionTypeParameter_GetConstraintByIndex(Sl // Shader Reflection -namespace Slang -{ - StructTypeLayout* getGlobalStructLayout( - ProgramLayout* programLayout); -} - SLANG_API unsigned spReflection_GetParameterCount(SlangReflection* inProgram) { auto program = convert(inProgram); @@ -1365,7 +1359,7 @@ SLANG_API SlangUInt spReflection_getGlobalConstantBufferBinding(SlangReflection* { auto program = convert(inProgram); if (!program) return 0; - auto cb = program->globalScopeLayout->FindResourceInfo(LayoutResourceKind::ConstantBuffer); + auto cb = program->parametersLayout->FindResourceInfo(LayoutResourceKind::ConstantBuffer); if (!cb) return 0; return cb->index; } diff --git a/source/slang/slang.vcxproj b/source/slang/slang.vcxproj index 1ad408e73..b0ac37440 100644 --- a/source/slang/slang.vcxproj +++ b/source/slang/slang.vcxproj @@ -185,6 +185,7 @@ <ClInclude Include="ir-constexpr.h" /> <ClInclude Include="ir-dce.h" /> <ClInclude Include="ir-dominators.h" /> + <ClInclude Include="ir-entry-point-uniforms.h" /> <ClInclude Include="ir-glsl-legalize.h" /> <ClInclude Include="ir-inst-defs.h" /> <ClInclude Include="ir-insts.h" /> @@ -240,6 +241,7 @@ <ClCompile Include="ir-constexpr.cpp" /> <ClCompile Include="ir-dce.cpp" /> <ClCompile Include="ir-dominators.cpp" /> + <ClCompile Include="ir-entry-point-uniforms.cpp" /> <ClCompile Include="ir-glsl-legalize.cpp" /> <ClCompile Include="ir-legalize-types.cpp" /> <ClCompile Include="ir-link.cpp" /> diff --git a/source/slang/slang.vcxproj.filters b/source/slang/slang.vcxproj.filters index 9e3de4b93..0a44f9f57 100644 --- a/source/slang/slang.vcxproj.filters +++ b/source/slang/slang.vcxproj.filters @@ -54,6 +54,9 @@ <ClInclude Include="ir-dominators.h"> <Filter>Header Files</Filter> </ClInclude> + <ClInclude Include="ir-entry-point-uniforms.h"> + <Filter>Header Files</Filter> + </ClInclude> <ClInclude Include="ir-glsl-legalize.h"> <Filter>Header Files</Filter> </ClInclude> @@ -215,6 +218,9 @@ <ClCompile Include="ir-dominators.cpp"> <Filter>Source Files</Filter> </ClCompile> + <ClCompile Include="ir-entry-point-uniforms.cpp"> + <Filter>Source Files</Filter> + </ClCompile> <ClCompile Include="ir-glsl-legalize.cpp"> <Filter>Source Files</Filter> </ClCompile> diff --git a/source/slang/syntax-base-defs.h b/source/slang/syntax-base-defs.h index 81f03d43f..7f49bb607 100644 --- a/source/slang/syntax-base-defs.h +++ b/source/slang/syntax-base-defs.h @@ -16,17 +16,6 @@ END_SYNTAX_CLASS() ABSTRACT_SYNTAX_CLASS(SyntaxNodeBase, NodeBase) // The primary source location associated with this AST node FIELD(SourceLoc, loc) - - RAW( - // Allow dynamic casting with a convenient syntax - template<typename T> - T* As() - { - SLANG_ASSERT(this); - return dynamic_cast<T*>(this); - } - ) - END_SYNTAX_CLASS() // Base class for compile-time values (most often a type). @@ -60,6 +49,15 @@ ABSTRACT_SYNTAX_CLASS(Val, NodeBase) ) END_SYNTAX_CLASS() +RAW( + class Type; + + template <typename T> + SLANG_FORCE_INLINE T* as(Type* obj); + template <typename T> + SLANG_FORCE_INLINE const T* as(const Type* obj); + ) + // A type, representing a classifier for some term in the AST. // // Types can include "sugar" in that they may refer to a @@ -68,7 +66,7 @@ END_SYNTAX_CLASS() // // In order to operation on types, though, we often want // to look past any sugar, and operate on an underlying -// "canonical" type. The reprsentation caches a pointer to +// "canonical" type. The representation caches a pointer to // a canonical type on every type, so we can easily // operate on the raw representation when needed. ABSTRACT_SYNTAX_CLASS(Type, Val) @@ -85,31 +83,12 @@ public: bool Equals(Type * type); bool Equals(RefPtr<Type> type); - bool IsVectorType() { return As<VectorExpressionType>() != nullptr; } - bool IsArray() { return As<ArrayExpressionType>() != nullptr; } - - template<typename T> - T* As() - { - return dynamic_cast<T*>(GetCanonicalType()); - } - - // Convenience/legacy wrappers for `As<>` - ArithmeticExpressionType * AsArithmeticType() { return As<ArithmeticExpressionType>(); } - BasicExpressionType * AsBasicType() { return As<BasicExpressionType>(); } - VectorExpressionType * AsVectorType() { return As<VectorExpressionType>(); } - MatrixExpressionType * AsMatrixType() { return As<MatrixExpressionType>(); } - ArrayExpressionType * AsArrayType() { return As<ArrayExpressionType>(); } - - DeclRefType* AsDeclRefType() { return As<DeclRefType>(); } - - NamedExpressionType* AsNamedType(); - bool IsTextureOrSampler(); - bool IsTexture() { return As<TextureType>() != nullptr; } - bool IsSampler() { return As<SamplerStateType>() != nullptr; } + bool IsTexture() { return as<TextureType>(this) != nullptr; } + bool IsSampler() { return as<SamplerStateType>(this) != nullptr; } bool IsStruct(); - bool IsClass(); + //bool IsClass(); + Type* GetCanonicalType(); virtual RefPtr<Val> SubstituteImpl(SubstitutionSet subst, int* ioDiff) override; @@ -125,7 +104,12 @@ protected: Session* session = nullptr; ) END_SYNTAX_CLASS() - +RAW( + template <typename T> + SLANG_FORCE_INLINE T* as(Type* obj) { return obj ? dynamicCast<T>(obj->GetCanonicalType()) : nullptr; } + template <typename T> + SLANG_FORCE_INLINE const T* as(const Type* obj) { return obj ? dynamicCast<T>(const_cast<Type*>(obj)->GetCanonicalType()) : nullptr; } +) // A substitution represents a binding of certain // type-level variables to concrete argument values diff --git a/source/slang/syntax.cpp b/source/slang/syntax.cpp index fe8b2c4fe..2be1a79ed 100644 --- a/source/slang/syntax.cpp +++ b/source/slang/syntax.cpp @@ -130,11 +130,6 @@ void Type::accept(IValVisitor* visitor, void* extra) return false; } - NamedExpressionType* Type::AsNamedType() - { - return dynamic_cast<NamedExpressionType*>(this); - } - RefPtr<Val> Type::SubstituteImpl(SubstitutionSet subst, int* ioDiff) { int diff = 0; @@ -174,11 +169,12 @@ void Type::accept(IValVisitor* visitor, void* extra) { return IsTexture() || IsSampler(); } + bool Type::IsStruct() { - auto declRefType = AsDeclRefType(); + auto declRefType = as<DeclRefType>(this); if (!declRefType) return false; - auto structDeclRef = declRefType->declRef.As<StructDecl>(); + auto structDeclRef = declRefType->declRef.as<StructDecl>(); if (!structDeclRef) return false; return true; } @@ -276,29 +272,29 @@ void Type::accept(IValVisitor* visitor, void* extra) RefPtr<PtrType> Session::getPtrType( RefPtr<Type> valueType) { - return getPtrType(valueType, "PtrType").As<PtrType>(); + return getPtrType(valueType, "PtrType").dynamicCast<PtrType>(); } // Construct the type `Out<valueType>` RefPtr<OutType> Session::getOutType(RefPtr<Type> valueType) { - return getPtrType(valueType, "OutType").As<OutType>(); + return getPtrType(valueType, "OutType").dynamicCast<OutType>(); } RefPtr<InOutType> Session::getInOutType(RefPtr<Type> valueType) { - return getPtrType(valueType, "InOutType").As<InOutType>(); + return getPtrType(valueType, "InOutType").dynamicCast<InOutType>(); } RefPtr<RefType> Session::getRefType(RefPtr<Type> valueType) { - return getPtrType(valueType, "RefType").As<RefType>(); + return getPtrType(valueType, "RefType").dynamicCast<RefType>(); } RefPtr<PtrTypeBase> Session::getPtrType(RefPtr<Type> valueType, char const* ptrTypeName) { auto genericDecl = findMagicDecl( - this, ptrTypeName).As<GenericDecl>(); + this, ptrTypeName).dynamicCast<GenericDecl>(); return getPtrType(valueType, genericDecl); } @@ -314,7 +310,7 @@ void Type::accept(IValVisitor* visitor, void* extra) auto rsType = DeclRefType::Create( this, declRef); - return rsType->As<PtrTypeBase>(); + return as<PtrTypeBase>( rsType); } RefPtr<ArrayExpressionType> Session::getArrayType( @@ -341,7 +337,7 @@ void Type::accept(IValVisitor* visitor, void* extra) bool ArrayExpressionType::EqualsImpl(Type * type) { - auto arrType = type->AsArrayType(); + auto arrType = as<ArrayExpressionType>(type); if (!arrType) return false; return (areValsEqual(ArrayLength, arrType->ArrayLength) && baseType->Equals(arrType->baseType.Ptr())); @@ -350,8 +346,8 @@ void Type::accept(IValVisitor* visitor, void* extra) RefPtr<Val> ArrayExpressionType::SubstituteImpl(SubstitutionSet subst, int* ioDiff) { int diff = 0; - auto elementType = baseType->SubstituteImpl(subst, &diff).As<Type>(); - auto arrlen = ArrayLength->SubstituteImpl(subst, &diff).As<IntVal>(); + auto elementType = baseType->SubstituteImpl(subst, &diff).dynamicCast<Type>(); + auto arrlen = ArrayLength->SubstituteImpl(subst, &diff).dynamicCast<IntVal>(); SLANG_ASSERT(arrlen); if (diff) { @@ -401,7 +397,7 @@ void Type::accept(IValVisitor* visitor, void* extra) bool DeclRefType::EqualsImpl(Type * type) { - if (auto declRefType = type->AsDeclRefType()) + if (auto declRefType = as<DeclRefType>(type)) { return declRef.Equals(declRefType->declRef); } @@ -432,7 +428,7 @@ void Type::accept(IValVisitor* visitor, void* extra) RefPtr<WitnessTable> RequirementWitness::getWitnessTable() { SLANG_ASSERT(getFlavor() == Flavor::witnessTable); - return m_obj.As<WitnessTable>(); + return m_obj.dynamicCast<WitnessTable>(); } @@ -464,7 +460,7 @@ void Type::accept(IValVisitor* visitor, void* extra) { if(auto declaredSubtypeWitness = dynamic_cast<DeclaredSubtypeWitness*>(subtypeWitness)) { - if(auto inheritanceDeclRef = declaredSubtypeWitness->declRef.As<InheritanceDecl>()) + if(auto inheritanceDeclRef = declaredSubtypeWitness->declRef.as<InheritanceDecl>()) { // A conformance that was declared as part of an inheritance clause // will have built up a dictionary of the satisfying declarations @@ -500,7 +496,7 @@ void Type::accept(IValVisitor* visitor, void* extra) // reference to `ISidekick.Hero` with a this-type substitution that references // the `{S:ISidekick}` declaration as a witness. // - // The front-end will expand the generic appliation `followHero<Sidekick<Batman>>` + // The front-end will expand the generic application `followHero<Sidekick<Batman>>` // to `followHero<Sidekick<Batman>, {Sidekick<H>:ISidekick}[H->Batman]>` // (that is, the hidden second parameter will reference the inheritance // clause on `Sidekick<H>`, with a substitution to map `H` to `Batman`. @@ -541,7 +537,7 @@ void Type::accept(IValVisitor* visitor, void* extra) // search for a substitution that might apply to us for(auto s = subst.substitutions; s; s = s->outer) { - auto genericSubst = s.As<GenericSubstitution>(); + auto genericSubst = s.dynamicCast<GenericSubstitution>(); if(!genericSubst) continue; @@ -560,11 +556,11 @@ void Type::accept(IValVisitor* visitor, void* extra) (*ioDiff)++; return genericSubst->args[index]; } - else if (auto typeParam = m.As<GenericTypeParamDecl>()) + else if (auto typeParam = as<GenericTypeParamDecl>(m)) { index++; } - else if (auto valParam = m.As<GenericValueParamDecl>()) + else if (auto valParam = as<GenericValueParamDecl>(m)) { index++; } @@ -579,7 +575,7 @@ void Type::accept(IValVisitor* visitor, void* extra) // search for a substitution that might apply to us for(auto s = subst.substitutions; s; s = s->outer) { - auto genericSubst = s.As<GlobalGenericParamSubstitution>(); + auto genericSubst = as<GlobalGenericParamSubstitution>(s); if(!genericSubst) continue; @@ -604,15 +600,15 @@ void Type::accept(IValVisitor* visitor, void* extra) // the outer interface, then try to replace the type with the // actual value of the associated type for the given implementation. // - if(auto substAssocTypeDecl = substDeclRef.decl->As<AssocTypeDecl>()) + if(auto substAssocTypeDecl = as<AssocTypeDecl>(substDeclRef.decl)) { for(auto s = substDeclRef.substitutions.substitutions; s; s = s->outer) { - auto thisSubst = s.As<ThisTypeSubstitution>(); + auto thisSubst = s.as<ThisTypeSubstitution>(); if(!thisSubst) continue; - if(auto interfaceDecl = substAssocTypeDecl->ParentDecl->As<InterfaceDecl>()) + if(auto interfaceDecl = as<InterfaceDecl>(substAssocTypeDecl->ParentDecl)) { if(thisSubst->interfaceDecl == interfaceDecl) { @@ -644,14 +640,14 @@ void Type::accept(IValVisitor* visitor, void* extra) static RefPtr<Type> ExtractGenericArgType(RefPtr<Val> val) { - auto type = val.As<Type>(); + auto type = val.dynamicCast<Type>(); SLANG_RELEASE_ASSERT(type.Ptr()); return type; } static RefPtr<IntVal> ExtractGenericArgInteger(RefPtr<Val> val) { - auto intVal = val.As<IntVal>(); + auto intVal = val.as<IntVal>(); SLANG_RELEASE_ASSERT(intVal.Ptr()); return intVal; } @@ -690,7 +686,7 @@ void Type::accept(IValVisitor* visitor, void* extra) dd = parentDecl; - if(auto genericParentDecl = parentDecl.As<GenericDecl>()) + if(auto genericParentDecl = parentDecl.as<GenericDecl>()) { // Don't specialize any parameters of a generic. if(childDecl != genericParentDecl->inner) @@ -700,7 +696,7 @@ void Type::accept(IValVisitor* visitor, void* extra) RefPtr<GenericSubstitution> foundSubst; for(auto s = declRef.substitutions.substitutions; s; s = s->outer) { - auto genSubst = s.As<GenericSubstitution>(); + auto genSubst = s.as<GenericSubstitution>(); if(!genSubst) continue; @@ -753,7 +749,7 @@ void Type::accept(IValVisitor* visitor, void* extra) GenericSubstitution* subst = nullptr; for(auto s = declRef.substitutions.substitutions; s; s = s->outer) { - if(auto genericSubst = s.As<GenericSubstitution>()) + if(auto genericSubst = s.as<GenericSubstitution>()) { subst = genericSubst; break; @@ -967,7 +963,7 @@ void Type::accept(IValVisitor* visitor, void* extra) bool ErrorType::EqualsImpl(Type* type) { - if (auto errorType = type->As<ErrorType>()) + if (auto errorType = as<ErrorType>(type)) return true; return false; } @@ -1041,7 +1037,7 @@ void Type::accept(IValVisitor* visitor, void* extra) bool FuncType::EqualsImpl(Type * type) { - if (auto funcType = type->As<FuncType>()) + if (auto funcType = as<FuncType>(type)) { auto paramCount = getParamCount(); auto otherParamCount = funcType->getParamCount(); @@ -1072,13 +1068,13 @@ void Type::accept(IValVisitor* visitor, void* extra) int diff = 0; // result type - RefPtr<Type> substResultType = resultType->SubstituteImpl(subst, &diff).As<Type>(); + RefPtr<Type> substResultType = resultType->SubstituteImpl(subst, &diff).dynamicCast<Type>(); // parameter types List<RefPtr<Type>> substParamTypes; for( auto pp : paramTypes ) { - substParamTypes.Add(pp->SubstituteImpl(subst, &diff).As<Type>()); + substParamTypes.Add(pp->SubstituteImpl(subst, &diff).dynamicCast<Type>()); } // early exit for no change... @@ -1138,7 +1134,7 @@ void Type::accept(IValVisitor* visitor, void* extra) bool TypeType::EqualsImpl(Type * t) { - if (auto typeType = t->As<TypeType>()) + if (auto typeType = as<TypeType>(t)) { return t->Equals(typeType->type); } @@ -1167,7 +1163,7 @@ void Type::accept(IValVisitor* visitor, void* extra) bool GenericDeclRefType::EqualsImpl(Type * type) { - if (auto genericDeclRefType = type->As<GenericDeclRefType>()) + if (auto genericDeclRefType = as<GenericDeclRefType>(type)) { return declRef.Equals(genericDeclRefType->declRef); } @@ -1197,7 +1193,7 @@ void Type::accept(IValVisitor* visitor, void* extra) BasicExpressionType* VectorExpressionType::GetScalarType() { - return elementType->AsBasicType(); + return as<BasicExpressionType>(elementType); } // @@ -1206,7 +1202,7 @@ void Type::accept(IValVisitor* visitor, void* extra) { for(RefPtr<Substitutions> s = subst; s; s = s->outer) { - if(auto genericSubst = s.As<GenericSubstitution>()) + if(auto genericSubst = as<GenericSubstitution>(s)) return genericSubst; } return nullptr; @@ -1223,22 +1219,22 @@ void Type::accept(IValVisitor* visitor, void* extra) BasicExpressionType* MatrixExpressionType::GetScalarType() { - return getElementType()->AsBasicType(); + return as<BasicExpressionType>(getElementType()); } Type* MatrixExpressionType::getElementType() { - return findInnerMostGenericSubstitution(declRef.substitutions)->args[0].As<Type>().Ptr(); + return dynamicCast<Type>(findInnerMostGenericSubstitution(declRef.substitutions)->args[0]); } IntVal* MatrixExpressionType::getRowCount() { - return findInnerMostGenericSubstitution(declRef.substitutions)->args[1].As<IntVal>().Ptr(); + return dynamicCast<IntVal>(findInnerMostGenericSubstitution(declRef.substitutions)->args[1]); } IntVal* MatrixExpressionType::getColumnCount() { - return findInnerMostGenericSubstitution(declRef.substitutions)->args[2].As<IntVal>().Ptr(); + return dynamicCast<IntVal>(findInnerMostGenericSubstitution(declRef.substitutions)->args[2]); } RefPtr<Type> MatrixExpressionType::getRowType() @@ -1255,7 +1251,7 @@ void Type::accept(IValVisitor* visitor, void* extra) RefPtr<IntVal> elementCount) { auto vectorGenericDecl = findMagicDecl( - this, "Vector").As<GenericDecl>(); + this, "Vector").as<GenericDecl>(); auto vectorTypeDecl = vectorGenericDecl->inner; auto substitutions = new GenericSubstitution(); @@ -1265,9 +1261,9 @@ void Type::accept(IValVisitor* visitor, void* extra) auto declRef = DeclRef<Decl>(vectorTypeDecl.Ptr(), substitutions); - return DeclRefType::Create( + return as<VectorExpressionType>(DeclRefType::Create( this, - declRef)->As<VectorExpressionType>(); + declRef)); } @@ -1275,7 +1271,7 @@ void Type::accept(IValVisitor* visitor, void* extra) Type* PtrTypeBase::getValueType() { - return findInnerMostGenericSubstitution(declRef.substitutions)->args[0].As<Type>().Ptr(); + return dynamicCast<Type>(findInnerMostGenericSubstitution(declRef.substitutions)->args[0]); } // GenericParamIntVal @@ -1304,7 +1300,7 @@ void Type::accept(IValVisitor* visitor, void* extra) // search for a substitution that might apply to us for(auto s = subst.substitutions; s; s = s->outer) { - auto genSubst = s.As<GenericSubstitution>(); + auto genSubst = s.as<GenericSubstitution>(); if(!genSubst) continue; @@ -1323,11 +1319,11 @@ void Type::accept(IValVisitor* visitor, void* extra) (*ioDiff)++; return genSubst->args[index]; } - else if (auto typeParam = m.As<GenericTypeParamDecl>()) + else if (auto typeParam = as<GenericTypeParamDecl>(m)) { index++; } - else if (auto valParam = m.As<GenericValueParamDecl>()) + else if (auto valParam = as<GenericValueParamDecl>(m)) { index++; } @@ -1337,7 +1333,7 @@ void Type::accept(IValVisitor* visitor, void* extra) } } - // Nothing found: don't substittue. + // Nothing found: don't substitute. return this; } @@ -1402,8 +1398,10 @@ void Type::accept(IValVisitor* visitor, void* extra) int diff = 0; if(substOuter != outer) diff++; - auto substWitness = witness->SubstituteImpl(substSet, &diff).As<SubtypeWitness>(); + // NOTE: Must use .as because we must have a smart pointer here to keep in scope. + auto substWitness = witness->SubstituteImpl(substSet, &diff).as<SubtypeWitness>(); + if (!diff) return this; (*ioDiff)++; @@ -1440,7 +1438,7 @@ void Type::accept(IValVisitor* visitor, void* extra) if(substOuter != outer) diff++; - auto substActualType = actualType->SubstituteImpl(substSet, &diff).As<Type>(); + auto substActualType = actualType->SubstituteImpl(substSet, &diff).dynamicCast<Type>(); List<ConstraintArg> substConstraintArgs; for(auto constraintArg : constraintArgs) @@ -1499,7 +1497,7 @@ void Type::accept(IValVisitor* visitor, void* extra) // Otherwise we need to recurse on the type structure // and apply substitutions where it makes sense - return type->Substitute(substitutions).As<Type>(); + return type->Substitute(substitutions).dynamicCast<Type>(); } DeclRefBase DeclRefBase::Substitute(DeclRefBase declRef) const @@ -1529,7 +1527,7 @@ void Type::accept(IValVisitor* visitor, void* extra) Decl* dd = decl; while(dd) { - if(auto interfaceDecl = dd->As<InterfaceDecl>()) + if(auto interfaceDecl = as<InterfaceDecl>(dd)) return interfaceDecl; dd = dd->ParentDecl; @@ -1543,7 +1541,7 @@ void Type::accept(IValVisitor* visitor, void* extra) { for(auto s = substs; s; s = s->outer) { - auto gSubst = s.As<GlobalGenericParamSubstitution>(); + auto gSubst = s.as<GlobalGenericParamSubstitution>(); if(!gSubst) continue; @@ -1576,7 +1574,7 @@ void Type::accept(IValVisitor* visitor, void* extra) // a recursive case that skips the rest of the function. for(auto specSubst = substsToSpecialize; specSubst; specSubst = specSubst->outer) { - auto specGlobalGenericSubst = specSubst.As<GlobalGenericParamSubstitution>(); + auto specGlobalGenericSubst = specSubst.as<GlobalGenericParamSubstitution>(); if(!specGlobalGenericSubst) continue; @@ -1607,7 +1605,7 @@ void Type::accept(IValVisitor* visitor, void* extra) // the end of the list in all cases, so lets advance // until we see them. RefPtr<Substitutions> appGlobalGenericSubsts = substsToApply; - while(appGlobalGenericSubsts && !appGlobalGenericSubsts.As<GlobalGenericParamSubstitution>()) + while(appGlobalGenericSubsts && !appGlobalGenericSubsts.as<GlobalGenericParamSubstitution>()) appGlobalGenericSubsts = appGlobalGenericSubsts->outer; @@ -1627,7 +1625,7 @@ void Type::accept(IValVisitor* visitor, void* extra) RefPtr<Substitutions>* link = &resultSubst; for(auto appSubst = appGlobalGenericSubsts; appSubst; appSubst = appSubst->outer) { - auto appGlobalGenericSubst = appSubst.As<GlobalGenericParamSubstitution>(); + auto appGlobalGenericSubst = appSubst.as<GlobalGenericParamSubstitution>(); if(!appSubst) continue; @@ -1661,7 +1659,7 @@ void Type::accept(IValVisitor* visitor, void* extra) // Construct new substitutions to apply to a declaration, - // based on a provided substituion set to be applied + // based on a provided substitution set to be applied RefPtr<Substitutions> specializeSubstitutions( Decl* declToSpecialize, RefPtr<Substitutions> substsToSpecialize, @@ -1684,11 +1682,11 @@ void Type::accept(IValVisitor* visitor, void* extra) // corresponding to that decl. for(Decl* ancestorDecl = declToSpecialize; ancestorDecl; ancestorDecl = ancestorDecl->ParentDecl) { - if(auto ancestorGenericDecl = ancestorDecl->As<GenericDecl>()) + if(auto ancestorGenericDecl = as<GenericDecl>(ancestorDecl)) { // The declaration is nested inside a generic. // Does it already have a specialization for that generic? - if(auto specGenericSubst = substsToSpecialize.As<GenericSubstitution>()) + if(auto specGenericSubst = substsToSpecialize.as<GenericSubstitution>()) { if(specGenericSubst->genericDecl == ancestorGenericDecl) { @@ -1722,7 +1720,7 @@ void Type::accept(IValVisitor* visitor, void* extra) // for(auto s = substsToApply; s; s = s->outer) { - auto appGenericSubst = s.As<GenericSubstitution>(); + auto appGenericSubst = s.as<GenericSubstitution>(); if(!appGenericSubst) continue; @@ -1750,7 +1748,7 @@ void Type::accept(IValVisitor* visitor, void* extra) return firstSubst; } } - else if(auto ancestorInterfaceDecl = ancestorDecl->As<InterfaceDecl>()) + else if(auto ancestorInterfaceDecl = as<InterfaceDecl>(ancestorDecl)) { // The task is basically the same as for the generic case: // We want to see if there is any existing substitution that @@ -1758,7 +1756,7 @@ void Type::accept(IValVisitor* visitor, void* extra) // The declaration is nested inside a generic. // Does it already have a specialization for that generic? - if(auto specThisTypeSubst = substsToSpecialize.As<ThisTypeSubstitution>()) + if(auto specThisTypeSubst = substsToSpecialize.as<ThisTypeSubstitution>()) { if(specThisTypeSubst->interfaceDecl == ancestorInterfaceDecl) { @@ -1787,7 +1785,7 @@ void Type::accept(IValVisitor* visitor, void* extra) // for(auto s = substsToApply; s; s = s->outer) { - auto appThisTypeSubst = s.As<ThisTypeSubstitution>(); + auto appThisTypeSubst = s.as<ThisTypeSubstitution>(); if(!appThisTypeSubst) continue; @@ -1818,7 +1816,7 @@ void Type::accept(IValVisitor* visitor, void* extra) // in either substitution. // // As an invariant, there should *not* be any generic or this-type - // substitutiosn in `substToSpecialize`, because otherwise they + // substitutions in `substToSpecialize`, because otherwise they // would be specializations that don't actually apply to the given // declaration. // @@ -1865,7 +1863,7 @@ void Type::accept(IValVisitor* visitor, void* extra) // TODO: The old code here used to try to translate a decl-ref // to an associated type in a decl-ref for the concrete type - // in a paarticular implementation. + // in a particular implementation. // // I have only kept that logic in `DeclRefType::SubstituteImpl`, // but it may turn out it is needed here too. @@ -1907,7 +1905,7 @@ void Type::accept(IValVisitor* visitor, void* extra) // and there might be a this-type substitution in place. // A reference to the parent of the interface declaration // should not include that substitution. - if(auto thisTypeSubst = substToApply.As<ThisTypeSubstitution>()) + if(auto thisTypeSubst = substToApply.as<ThisTypeSubstitution>()) { if(thisTypeSubst->interfaceDecl == interfaceDecl) { @@ -1921,11 +1919,11 @@ void Type::accept(IValVisitor* visitor, void* extra) { // The parent of this declaration is a generic, which means // that the decl-ref to the current declaration might include - // substitutiosn that specialize the generic parameters. + // substitutions that specialize the generic parameters. // A decl-ref to the parent generic should *not* include // those substitutions. // - if(auto genericSubst = substToApply.As<GenericSubstitution>()) + if(auto genericSubst = substToApply.as<GenericSubstitution>()) { if(genericSubst->genericDecl == parentGenericDecl) { @@ -1963,7 +1961,7 @@ void Type::accept(IValVisitor* visitor, void* extra) IntegerLiteralValue GetIntVal(RefPtr<IntVal> val) { - if (auto constantVal = val.As<ConstantIntVal>()) + if (auto constantVal = as<ConstantIntVal>(val)) { return constantVal->value; } @@ -1975,7 +1973,7 @@ void Type::accept(IValVisitor* visitor, void* extra) bool ConstantIntVal::EqualsVal(Val* val) { - if (auto intVal = dynamic_cast<ConstantIntVal*>(val)) + if (auto intVal = dynamicCast<ConstantIntVal>(val)) return value == intVal->value; return false; } @@ -2047,12 +2045,12 @@ void Type::accept(IValVisitor* visitor, void* extra) Type* HLSLPatchType::getElementType() { - return findInnerMostGenericSubstitution(declRef.substitutions)->args[0].As<Type>().Ptr(); + return dynamicCast<Type>(findInnerMostGenericSubstitution(declRef.substitutions)->args[0]); } IntVal* HLSLPatchType::getElementCount() { - return findInnerMostGenericSubstitution(declRef.substitutions)->args[1].As<IntVal>().Ptr(); + return dynamicCast<IntVal>(findInnerMostGenericSubstitution(declRef.substitutions)->args[1]); } // Constructors for types @@ -2083,7 +2081,7 @@ void Type::accept(IValVisitor* visitor, void* extra) Session* session, DeclRef<TypeDefDecl> const& declRef) { - DeclRef<TypeDefDecl> specializedDeclRef = createDefaultSubstitutionsIfNeeded(session, declRef).As<TypeDefDecl>(); + DeclRef<TypeDefDecl> specializedDeclRef = createDefaultSubstitutionsIfNeeded(session, declRef).as<TypeDefDecl>(); auto namedType = new NamedExpressionType(specializedDeclRef); namedType->setSession(session); @@ -2162,8 +2160,8 @@ void Type::accept(IValVisitor* visitor, void* extra) RefPtr<Val> TypeEqualityWitness::SubstituteImpl(SubstitutionSet subst, int * ioDiff) { RefPtr<TypeEqualityWitness> rs = new TypeEqualityWitness(); - rs->sub = sub->SubstituteImpl(subst, ioDiff).As<Type>(); - rs->sup = sup->SubstituteImpl(subst, ioDiff).As<Type>(); + rs->sub = sub->SubstituteImpl(subst, ioDiff).dynamicCast<Type>(); + rs->sup = sup->SubstituteImpl(subst, ioDiff).dynamicCast<Type>(); return rs; } @@ -2194,7 +2192,7 @@ void Type::accept(IValVisitor* visitor, void* extra) { for(RefPtr<Substitutions> s = substs; s; s = s->outer) { - auto thisTypeSubst = s.As<ThisTypeSubstitution>(); + auto thisTypeSubst = s.dynamicCast<ThisTypeSubstitution>(); if(!thisTypeSubst) continue; @@ -2209,14 +2207,14 @@ void Type::accept(IValVisitor* visitor, void* extra) RefPtr<Val> DeclaredSubtypeWitness::SubstituteImpl(SubstitutionSet subst, int * ioDiff) { - if (auto genConstraintDeclRef = declRef.As<GenericTypeConstraintDecl>()) + if (auto genConstraintDeclRef = declRef.as<GenericTypeConstraintDecl>()) { auto genConstraintDecl = genConstraintDeclRef.getDecl(); // search for a substitution that might apply to us for(auto s = subst.substitutions; s; s = s->outer) { - if(auto genericSubst = s.As<GenericSubstitution>()) + if(auto genericSubst = s.as<GenericSubstitution>()) { // the generic decl associated with the substitution list must be // the generic decl that declared this parameter @@ -2228,7 +2226,7 @@ void Type::accept(IValVisitor* visitor, void* extra) UInt index = 0; for (auto m : genericDecl->Members) { - if (auto constraintParam = m.As<GenericTypeConstraintDecl>()) + if (auto constraintParam = m.dynamicCast<GenericTypeConstraintDecl>()) { if (constraintParam.Ptr() == declRef.getDecl()) { @@ -2247,7 +2245,7 @@ void Type::accept(IValVisitor* visitor, void* extra) return genericSubst->args[index + ordinaryParamCount]; } } - else if(auto globalGenericSubst = s.As<GlobalGenericParamSubstitution>()) + else if(auto globalGenericSubst = s.as<GlobalGenericParamSubstitution>()) { // check if the substitution is really about this global generic type parameter if (globalGenericSubst->paramDecl != genConstraintDecl->ParentDecl) @@ -2267,8 +2265,8 @@ void Type::accept(IValVisitor* visitor, void* extra) // Perform substitution on the constituent elements. int diff = 0; - auto substSub = sub->SubstituteImpl(subst, &diff).As<Type>(); - auto substSup = sup->SubstituteImpl(subst, &diff).As<Type>(); + auto substSub = sub->SubstituteImpl(subst, &diff).dynamicCast<Type>(); + auto substSup = sup->SubstituteImpl(subst, &diff).dynamicCast<Type>(); auto substDeclRef = declRef.SubstituteImpl(subst, &diff); if (!diff) return this; @@ -2285,11 +2283,11 @@ void Type::accept(IValVisitor* visitor, void* extra) // so we'll need to change this location in the code if we ever clean // up the hierarchy. // - if (auto substTypeConstraintDecl = substDeclRef.decl->As<GenericTypeConstraintDecl>()) + if (auto substTypeConstraintDecl = as<GenericTypeConstraintDecl>(substDeclRef.decl)) { - if (auto substAssocTypeDecl = substTypeConstraintDecl->ParentDecl->As<AssocTypeDecl>()) + if (auto substAssocTypeDecl = as<AssocTypeDecl>(substTypeConstraintDecl->ParentDecl)) { - if (auto interfaceDecl = substAssocTypeDecl->ParentDecl->As<InterfaceDecl>()) + if (auto interfaceDecl = as<InterfaceDecl>(substAssocTypeDecl->ParentDecl)) { // At this point we have a constraint decl for an associated type, // and we nee to see if we are dealing with a concrete substitution @@ -2362,9 +2360,9 @@ void Type::accept(IValVisitor* visitor, void* extra) { int diff = 0; - RefPtr<Type> substSub = sub->SubstituteImpl(subst, &diff).As<Type>(); - RefPtr<Type> substSup = sup->SubstituteImpl(subst, &diff).As<Type>(); - RefPtr<SubtypeWitness> substSubToMid = subToMid->SubstituteImpl(subst, &diff).As<SubtypeWitness>(); + RefPtr<Type> substSub = sub->SubstituteImpl(subst, &diff).dynamicCast<Type>(); + RefPtr<Type> substSup = sup->SubstituteImpl(subst, &diff).dynamicCast<Type>(); + RefPtr<SubtypeWitness> substSubToMid = subToMid->SubstituteImpl(subst, &diff).dynamicCast<SubtypeWitness>(); DeclRef<Decl> substMidToSup = midToSup.SubstituteImpl(subst, &diff); // If nothing changed, then we can bail out early. @@ -2463,7 +2461,7 @@ void Type::accept(IValVisitor* visitor, void* extra) bool ExtractExistentialType::EqualsImpl(Type* type) { - if( auto extractExistential = type->As<ExtractExistentialType>() ) + if( auto extractExistential = as<ExtractExistentialType>(type) ) { return declRef.Equals(extractExistential->declRef); } @@ -2498,7 +2496,7 @@ void Type::accept(IValVisitor* visitor, void* extra) bool ExtractExistentialSubtypeWitness::EqualsVal(Val* val) { - if( auto extractWitness = val->dynamicCast<ExtractExistentialSubtypeWitness>() ) + if( auto extractWitness = dynamicCast<ExtractExistentialSubtypeWitness>(val) ) { return declRef.Equals(extractWitness->declRef); } @@ -2524,8 +2522,8 @@ void Type::accept(IValVisitor* visitor, void* extra) int diff = 0; auto substDeclRef = declRef.SubstituteImpl(subst, &diff); - auto substSub = sub->SubstituteImpl(subst, &diff).As<Type>(); - auto substSup = sup->SubstituteImpl(subst, &diff).As<Type>(); + auto substSub = sub->SubstituteImpl(subst, &diff).dynamicCast<Type>(); + auto substSup = sup->SubstituteImpl(subst, &diff).dynamicCast<Type>(); if(!diff) return this; @@ -2561,7 +2559,7 @@ void Type::accept(IValVisitor* visitor, void* extra) bool TaggedUnionType::EqualsImpl(Type* type) { - auto taggedUnion = type->As<TaggedUnionType>(); + auto taggedUnion = as<TaggedUnionType>(type); if(!taggedUnion) return false; @@ -2608,7 +2606,7 @@ void Type::accept(IValVisitor* visitor, void* extra) List<RefPtr<Type>> substCaseTypes; for( auto caseType : caseTypes ) { - substCaseTypes.Add(caseType->SubstituteImpl(subst, &diff).As<Type>()); + substCaseTypes.Add(caseType->SubstituteImpl(subst, &diff).dynamicCast<Type>()); } if(!diff) return this; @@ -2628,7 +2626,7 @@ void Type::accept(IValVisitor* visitor, void* extra) bool TaggedUnionSubtypeWitness::EqualsVal(Val* val) { - auto taggedUnionWitness = val->dynamicCast<TaggedUnionSubtypeWitness>(); + auto taggedUnionWitness = dynamicCast<TaggedUnionSubtypeWitness>(val); if(!taggedUnionWitness) return false; @@ -2674,8 +2672,8 @@ RefPtr<Val> TaggedUnionSubtypeWitness::SubstituteImpl(SubstitutionSet subst, int { int diff = 0; - auto substSub = sub->SubstituteImpl(subst, &diff).As<Type>(); - auto substSup = sup->SubstituteImpl(subst, &diff).As<Type>(); + auto substSub = sub->SubstituteImpl(subst, &diff).dynamicCast<Type>(); + auto substSup = sup->SubstituteImpl(subst, &diff).dynamicCast<Type>(); List<RefPtr<Val>> substCaseWitnesses; for( auto caseWitness : caseWitnesses ) diff --git a/source/slang/syntax.h b/source/slang/syntax.h index 5db762f11..076d62d71 100644 --- a/source/slang/syntax.h +++ b/source/slang/syntax.h @@ -1,5 +1,5 @@ -#ifndef RASTER_RENDERER_SYNTAX_H -#define RASTER_RENDERER_SYNTAX_H +#ifndef SLANG_SYNTAX_H +#define SLANG_SYNTAX_H #include "../core/basic.h" #include "ir.h" @@ -291,10 +291,7 @@ namespace Slang struct QualType { RefPtr<Type> type; - bool IsLeftValue; - - template <typename T> - T* As(); + bool IsLeftValue; QualType() : IsLeftValue(false) @@ -307,6 +304,7 @@ namespace Slang Type* Ptr() { return type.Ptr(); } + operator Type*() { return type; } operator RefPtr<Type>() { return type; } RefPtr<Type> operator->() { return type; } }; @@ -432,7 +430,7 @@ namespace Slang Decl* decl = nullptr; Decl* getDecl() const { return decl; } - // Optionally, a chain of substititions to perform + // Optionally, a chain of substitutions to perform SubstitutionSet substitutions; DeclRefBase() @@ -452,7 +450,7 @@ namespace Slang , substitutions(subst) {} - // Apply substitutions to a type or ddeclaration + // Apply substitutions to a type or declaration RefPtr<Type> Substitute(RefPtr<Type> type) const; DeclRefBase Substitute(DeclRefBase declRef) const; @@ -506,10 +504,10 @@ namespace Slang // "dynamic cast" to a more specific declaration reference type template<typename U> - DeclRef<U> As() const + DeclRef<U> as() const { DeclRef<U> result; - result.decl = dynamic_cast<U*>(decl); + result.decl = Slang::as<U>(decl); result.substitutions = substitutions; return result; } @@ -618,7 +616,7 @@ namespace Slang { while (cursor != end) { - if ((*cursor).As<T>()) + if (dynamicCast<T>(*cursor)) return cursor; cursor++; } @@ -728,7 +726,7 @@ namespace Slang while (ptr != end) { DeclRef<Decl> declRef(ptr->Ptr(), substitutions); - if (declRef.As<T>()) + if (declRef.as<T>()) return ptr; ptr++; } @@ -1064,7 +1062,7 @@ namespace Slang RefPtr<Val> getVal() { SLANG_ASSERT(getFlavor() == Flavor::val); - return m_obj.As<Val>(); + return m_obj.dynamicCast<Val>(); } RefPtr<WitnessTable> getWitnessTable(); @@ -1116,13 +1114,6 @@ namespace Slang #include "object-meta-end.h" - - template <typename T> - SLANG_FORCE_INLINE T* QualType::As() - { - return type ? type->As<T>() : nullptr; - } - inline RefPtr<Type> GetSub(DeclRef<GenericTypeConstraintDecl> const& declRef) { return declRef.Substitute(declRef.getDecl()->sub.Ptr()); @@ -1166,13 +1157,15 @@ namespace Slang // - inline BaseType GetVectorBaseType(VectorExpressionType* vecType) { - return vecType->elementType->AsBasicType()->baseType; + inline BaseType GetVectorBaseType(VectorExpressionType* vecType) + { + auto basicExprType = as<BasicExpressionType>(vecType->elementType); + return basicExprType->baseType; } inline int GetVectorSize(VectorExpressionType* vecType) { - auto constantVal = vecType->elementCount.As<ConstantIntVal>(); + auto constantVal = vecType->elementCount.dynamicCast<ConstantIntVal>(); if (constantVal) return (int) constantVal->value; // TODO: what to do in this case? @@ -1205,7 +1198,7 @@ namespace Slang List<DeclRef<T>> rs; for (auto d : getMembersOfType<T>(declRef)) rs.Add(d); - if (auto aggDeclRef = declRef.As<AggTypeDecl>()) + if (auto aggDeclRef = declRef.as<AggTypeDecl>()) { for (auto ext = GetCandidateExtensions(aggDeclRef); ext; ext = ext->nextCandidateExtension) { diff --git a/source/slang/type-layout.cpp b/source/slang/type-layout.cpp index 05c61e706..b902db826 100644 --- a/source/slang/type-layout.cpp +++ b/source/slang/type-layout.cpp @@ -845,11 +845,11 @@ static LayoutSize GetElementCount(RefPtr<IntVal> val) if(!val) return LayoutSize::infinite(); - if (auto constantVal = val.As<ConstantIntVal>()) + if (auto constantVal = as<ConstantIntVal>(val)) { return LayoutSize(LayoutSize::RawValue(constantVal->value)); } - else if( auto varRefVal = val.As<GenericParamIntVal>() ) + else if( auto varRefVal = as<GenericParamIntVal>(val) ) { // TODO: We want to treat the case where the number of // elements in an array depends on a generic parameter @@ -905,19 +905,19 @@ static SimpleLayoutInfo getParameterGroupLayoutInfo( RefPtr<ParameterGroupType> type, LayoutRulesImpl* rules) { - if( type->As<ConstantBufferType>() ) + if( as<ConstantBufferType>(type) ) { return rules->GetObjectLayout(ShaderParameterKind::ConstantBuffer); } - else if( type->As<TextureBufferType>() ) + else if( as<TextureBufferType>(type) ) { return rules->GetObjectLayout(ShaderParameterKind::TextureUniformBuffer); } - else if( type->As<GLSLShaderStorageBufferType>() ) + else if( as<GLSLShaderStorageBufferType>(type) ) { return rules->GetObjectLayout(ShaderParameterKind::ShaderStorageBuffer); } - else if (type->As<ParameterBlockType>()) + else if (as<ParameterBlockType>(type)) { // Note: we default to consuming zero register spces here, because // a parameter block might not contain anything (or all it contains @@ -935,11 +935,11 @@ static SimpleLayoutInfo getParameterGroupLayoutInfo( // TODO: the vertex-input and fragment-output cases should // only actually apply when we are at the appropriate stage in // the pipeline... - else if( type->As<GLSLInputParameterGroupType>() ) + else if( as<GLSLInputParameterGroupType>(type) ) { return SimpleLayoutInfo(LayoutResourceKind::VertexInput, 0); } - else if( type->As<GLSLOutputParameterGroupType>() ) + else if( as<GLSLOutputParameterGroupType>(type) ) { return SimpleLayoutInfo(LayoutResourceKind::FragmentOutput, 0); } @@ -1093,7 +1093,7 @@ RefPtr<TypeLayout> applyOffsetToTypeLayout( return oldTypeLayout; RefPtr<TypeLayout> newTypeLayout; - if (auto oldStructTypeLayout = oldTypeLayout.As<StructTypeLayout>()) + if (auto oldStructTypeLayout = oldTypeLayout.as<StructTypeLayout>()) { RefPtr<StructTypeLayout> newStructTypeLayout = new StructTypeLayout(); newStructTypeLayout->type = oldStructTypeLayout->type; @@ -1212,15 +1212,14 @@ createParameterGroupTypeLayout( // in HLSL or not. // Check if we are working with a parameter block... - auto parameterBlockType = parameterGroupType ? parameterGroupType->As<ParameterBlockType>() : nullptr; - - + auto parameterBlockType = as<ParameterBlockType>(parameterGroupType); + // Check if we have a parameter block *and* it should be // allocated into its own register space(s) bool ownRegisterSpace = false; if (parameterBlockType) { - // Should we allocate this block its own regsiter space? + // Should we allocate this block its own register space? if( shouldAllocateRegisterSpaceForParameterBlock(context) ) { ownRegisterSpace = true; @@ -1419,27 +1418,27 @@ LayoutRulesImpl* getParameterBufferElementTypeLayoutRules( RefPtr<ParameterGroupType> parameterGroupType, LayoutRulesImpl* rules) { - if( parameterGroupType->As<ConstantBufferType>() ) + if( as<ConstantBufferType>(parameterGroupType) ) { return rules->getLayoutRulesFamily()->getConstantBufferRules(); } - else if( parameterGroupType->As<TextureBufferType>() ) + else if( as<TextureBufferType>(parameterGroupType) ) { return rules->getLayoutRulesFamily()->getTextureBufferRules(); } - else if( parameterGroupType->As<GLSLInputParameterGroupType>() ) + else if( as<GLSLInputParameterGroupType>(parameterGroupType) ) { return rules->getLayoutRulesFamily()->getVaryingInputRules(); } - else if( parameterGroupType->As<GLSLOutputParameterGroupType>() ) + else if( as<GLSLOutputParameterGroupType>(parameterGroupType) ) { return rules->getLayoutRulesFamily()->getVaryingOutputRules(); } - else if( parameterGroupType->As<GLSLShaderStorageBufferType>() ) + else if( as<GLSLShaderStorageBufferType>(parameterGroupType) ) { return rules->getLayoutRulesFamily()->getShaderStorageBufferRules(); } - else if (parameterGroupType->As<ParameterBlockType>()) + else if (as<ParameterBlockType>(parameterGroupType)) { return rules->getLayoutRulesFamily()->getParameterBlockRules(); } @@ -1668,7 +1667,7 @@ static RefPtr<TypeLayout> maybeAdjustLayoutForArrayElementType( // Let's look at the type layout we have, and see if there is anything // that we need to do with it. // - if( auto originalArrayTypeLayout = originalTypeLayout.As<ArrayTypeLayout>() ) + if( auto originalArrayTypeLayout = originalTypeLayout.as<ArrayTypeLayout>() ) { // The element type is itself an array, so we'll need to adjust // *its* element type accordingly. @@ -1696,7 +1695,7 @@ static RefPtr<TypeLayout> maybeAdjustLayoutForArrayElementType( return adjustedArrayTypeLayout; } - else if(auto originalParameterGroupTypeLayout = originalTypeLayout.As<ParameterGroupTypeLayout>() ) + else if(auto originalParameterGroupTypeLayout = originalTypeLayout.as<ParameterGroupTypeLayout>() ) { auto originalInnerElementTypeLayout = originalParameterGroupTypeLayout->elementVarLayout->typeLayout; auto adjustedInnerElementTypeLayout = maybeAdjustLayoutForArrayElementType( @@ -1715,7 +1714,7 @@ static RefPtr<TypeLayout> maybeAdjustLayoutForArrayElementType( SLANG_UNIMPLEMENTED_X("array of parameter group"); UNREACHABLE_RETURN(originalTypeLayout); } - else if(auto originalStructTypeLayout = originalTypeLayout.As<StructTypeLayout>() ) + else if(auto originalStructTypeLayout = originalTypeLayout.as<StructTypeLayout>() ) { UInt fieldCount = originalStructTypeLayout->fields.Count(); @@ -1782,7 +1781,7 @@ static RefPtr<TypeLayout> maybeAdjustLayoutForArrayElementType( { // If we are making an unbounded array, then a `struct` // field with resource type will turn into its own space, - // and it will start at regsiter zero in that space. + // and it will start at register zero in that space. // resInfo.index = 0; resInfo.space = spaceOffsetForField.getFiniteValue(); @@ -1831,7 +1830,7 @@ SimpleLayoutInfo GetLayoutImpl( { auto rules = context.rules; - if (auto parameterGroupType = type->As<ParameterGroupType>()) + if (auto parameterGroupType = as<ParameterGroupType>(type)) { // If the user is just interested in uniform layout info, // then this is easy: a `ConstantBuffer<T>` is really no @@ -1860,7 +1859,7 @@ SimpleLayoutInfo GetLayoutImpl( return info; } - else if (auto samplerStateType = type->As<SamplerStateType>()) + else if (auto samplerStateType = as<SamplerStateType>(type)) { return GetSimpleLayoutImpl( rules->GetObjectLayout(ShaderParameterKind::SamplerState), @@ -1868,7 +1867,7 @@ SimpleLayoutInfo GetLayoutImpl( rules, outTypeLayout); } - else if (auto textureType = type->As<TextureType>()) + else if (auto textureType = as<TextureType>(type)) { // TODO: the logic here should really be defined by the rules, // and not at this top level... @@ -1890,7 +1889,7 @@ SimpleLayoutInfo GetLayoutImpl( rules, outTypeLayout); } - else if (auto imageType = type->As<GLSLImageType>()) + else if (auto imageType = as<GLSLImageType>(type)) { // TODO: the logic here should really be defined by the rules, // and not at this top level... @@ -1912,7 +1911,7 @@ SimpleLayoutInfo GetLayoutImpl( rules, outTypeLayout); } - else if (auto textureSamplerType = type->As<TextureSamplerType>()) + else if (auto textureSamplerType = as<TextureSamplerType>(type)) { // TODO: the logic here should really be defined by the rules, // and not at this top level... @@ -1937,7 +1936,7 @@ SimpleLayoutInfo GetLayoutImpl( // TODO: need a better way to handle this stuff... #define CASE(TYPE, KIND) \ - else if(auto type_##TYPE = type->As<TYPE>()) do { \ + else if(auto type_##TYPE = as<TYPE>(type)) do { \ auto info = rules->GetObjectLayout(ShaderParameterKind::KIND); \ if (outTypeLayout) \ { \ @@ -1961,7 +1960,7 @@ SimpleLayoutInfo GetLayoutImpl( // TODO: need a better way to handle this stuff... #define CASE(TYPE, KIND) \ - else if(type->As<TYPE>()) do { \ + else if(as<TYPE>(type)) do { \ return GetSimpleLayoutImpl( \ rules->GetObjectLayout(ShaderParameterKind::KIND), \ type, rules, outTypeLayout); \ @@ -1981,7 +1980,7 @@ SimpleLayoutInfo GetLayoutImpl( // // TODO(tfoley): Need to recognize any UAV types here // - else if(auto basicType = type->As<BasicExpressionType>()) + else if(auto basicType = as<BasicExpressionType>(type)) { return GetSimpleLayoutImpl( rules->GetScalarLayout(basicType->baseType), @@ -1989,7 +1988,7 @@ SimpleLayoutInfo GetLayoutImpl( rules, outTypeLayout); } - else if(auto vecType = type->As<VectorExpressionType>()) + else if(auto vecType = as<VectorExpressionType>(type)) { return GetSimpleLayoutImpl( rules->GetVectorLayout( @@ -1999,7 +1998,7 @@ SimpleLayoutInfo GetLayoutImpl( rules, outTypeLayout); } - else if(auto matType = type->As<MatrixExpressionType>()) + else if(auto matType = as<MatrixExpressionType>(type)) { // The `GetMatrixLayout` implementation in the layout rules // currently defaults to assuming column-major layout, @@ -2040,7 +2039,7 @@ SimpleLayoutInfo GetLayoutImpl( return info; } - else if (auto arrayType = type->As<ArrayExpressionType>()) + else if (auto arrayType = as<ArrayExpressionType>(type)) { RefPtr<TypeLayout> elementTypeLayout; auto elementInfo = GetLayoutImpl( @@ -2192,11 +2191,11 @@ SimpleLayoutInfo GetLayoutImpl( } return arrayUniformInfo; } - else if (auto declRefType = type->As<DeclRefType>()) + else if (auto declRefType = as<DeclRefType>(type)) { auto declRef = declRefType->declRef; - if (auto structDeclRef = declRef.As<StructDecl>()) + if (auto structDeclRef = declRef.as<StructDecl>()) { RefPtr<StructTypeLayout> typeLayout; if (outTypeLayout) @@ -2316,7 +2315,7 @@ SimpleLayoutInfo GetLayoutImpl( return info; } - else if (auto globalGenParam = declRef.As<GlobalGenericParamDecl>()) + else if (auto globalGenParam = declRef.as<GlobalGenericParamDecl>()) { SimpleLayoutInfo info; info.alignment = 0; @@ -2336,7 +2335,7 @@ SimpleLayoutInfo GetLayoutImpl( return info; } } - else if (auto errorType = type->As<ErrorType>()) + else if (auto errorType = as<ErrorType>(type)) { // An error type means that we encountered something we don't understand. // @@ -2349,7 +2348,7 @@ SimpleLayoutInfo GetLayoutImpl( rules, outTypeLayout); } - else if( auto taggedUnionType = type->As<TaggedUnionType>() ) + else if( auto taggedUnionType = as<TaggedUnionType>(type) ) { // A tagged union type needs to be laid out as the maximum // size of any constituent type. @@ -2493,9 +2492,9 @@ RefPtr<TypeLayout> TypeLayout::unwrapArray() RefPtr<GlobalGenericParamDecl> GenericParamTypeLayout::getGlobalGenericParamDecl() { - auto declRefType = type->AsDeclRefType(); + auto declRefType = as<DeclRefType>(type); SLANG_ASSERT(declRefType); - auto rsDeclRef = declRefType->declRef.As<GlobalGenericParamDecl>(); + auto rsDeclRef = declRefType->declRef.as<GlobalGenericParamDecl>(); return rsDeclRef.getDecl(); } diff --git a/source/slang/type-layout.h b/source/slang/type-layout.h index da2f0e4f7..418f4684d 100644 --- a/source/slang/type-layout.h +++ b/source/slang/type-layout.h @@ -575,6 +575,18 @@ public: LayoutSize tagOffset; }; + /// Layout for a scoped entity like a program, module, or entry point +class ScopeLayout : public Layout +{ +public: + // The layout for the parameters of this entity. + // + RefPtr<VarLayout> parametersLayout; +}; + +StructTypeLayout* getScopeStructLayout( + ScopeLayout* programLayout); + // Layout information for a single shader entry point // within a program // @@ -584,7 +596,7 @@ public: // // TODO: where to store layout info for the return // type of the function? -class EntryPointLayout : public StructTypeLayout +class EntryPointLayout : public ScopeLayout { public: // The corresponding function declaration @@ -617,9 +629,10 @@ public: }; // Layout information for the global scope of a program -class ProgramLayout : public Layout +class ProgramLayout : public ScopeLayout { public: + /* // We store a layout for the declarations at the global // scope. Note that this will *either* be a single // `StructTypeLayout` with the fields stored directly, @@ -634,6 +647,7 @@ public: // to store them). // RefPtr<VarLayout> globalScopeLayout; + */ // We catalog the requested entry points here, // and any entry-point-specific parameter data @@ -646,6 +660,9 @@ public: TargetRequest* targetRequest = nullptr; }; +StructTypeLayout* getGlobalStructLayout( + ProgramLayout* programLayout); + struct LayoutRulesFamilyImpl; // A delineation of shader parameter types into fine-grained diff --git a/tests/compute/entry-point-uniform-params.slang b/tests/compute/entry-point-uniform-params.slang new file mode 100644 index 000000000..f91f7d146 --- /dev/null +++ b/tests/compute/entry-point-uniform-params.slang @@ -0,0 +1,51 @@ +// entry-point-uniform-params.slang + +// Confirm that `uniform` parameters on +// entry points are allowed, and work as expected. + +//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute +//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -dx12 +//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute + +struct Signs +{ + int a; +} + +struct Stuff +{ + int b; +} + +struct Things +{ + int c; +} + +// A shader parameter at global scope should be assigned +// a register/binding before any related to the entry point. + +//TEST_INPUT:cbuffer(data=[1 0 0 0]):dxbinding(0),glbinding(0) +ConstantBuffer<Signs> signs; + +[numthreads(4, 1, 1)] +void computeMain( +//TEST_INPUT:cbuffer(data=[2 0 0 0 3 0 0 0]):dxbinding(1),glbinding(1) + uniform Stuff stuff, + uniform Things things, + +//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):dxbinding(0),glbinding(2),out + uniform RWStructuredBuffer<int> outputBuffer, + + uint3 dispatchThreadID : SV_DispatchThreadID) +{ + int tid = dispatchThreadID.x; + + int val = 0; + val = val*16 + signs.a; + val = val*16 + stuff.b; + val = val*16 + things.c; + val = val*16 + tid; + + outputBuffer[tid] = val; +}
\ No newline at end of file diff --git a/tests/compute/entry-point-uniform-params.slang.expected.txt b/tests/compute/entry-point-uniform-params.slang.expected.txt new file mode 100644 index 000000000..ef2c43c16 --- /dev/null +++ b/tests/compute/entry-point-uniform-params.slang.expected.txt @@ -0,0 +1,4 @@ +1230 +1231 +1232 +1233 diff --git a/tests/compute/global-type-param.slang b/tests/compute/global-type-param.slang index 2638852eb..f177dcb1d 100644 --- a/tests/compute/global-type-param.slang +++ b/tests/compute/global-type-param.slang @@ -28,10 +28,10 @@ struct Impl : IBase __generic_param TImpl : IBase; -TImpl impl; - [numthreads(1, 1, 1)] -void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +void computeMain( + uniform TImpl impl, + uint3 dispatchThreadID : SV_DispatchThreadID) { uint tid = dispatchThreadID.x; float outVal = impl.compute(); diff --git a/tests/reflection/sample-index-input.hlsl.expected b/tests/reflection/sample-index-input.hlsl.expected index 5bf5f297e..c799f8f25 100644 --- a/tests/reflection/sample-index-input.hlsl.expected +++ b/tests/reflection/sample-index-input.hlsl.expected @@ -29,7 +29,6 @@ standard output = { "scalarType": "float32" } }, - "stage": "fragment", "binding": {"kind": "varyingInput", "index": 0}, "semanticName": "COLOR" }, diff --git a/tests/reflection/sample-rate-input.hlsl.expected b/tests/reflection/sample-rate-input.hlsl.expected index 0c86ebecb..ec6cfca6e 100644 --- a/tests/reflection/sample-rate-input.hlsl.expected +++ b/tests/reflection/sample-rate-input.hlsl.expected @@ -29,7 +29,6 @@ standard output = { "scalarType": "float32" } }, - "stage": "fragment", "binding": {"kind": "varyingInput", "index": 0}, "semanticName": "EXTRA" }, @@ -43,7 +42,6 @@ standard output = { "scalarType": "float32" } }, - "stage": "fragment", "binding": {"kind": "varyingInput", "index": 1}, "semanticName": "COLOR" } diff --git a/tests/reflection/vertex-input-semantics.hlsl.expected b/tests/reflection/vertex-input-semantics.hlsl.expected index 06b7bc95a..2ff8d7847 100644 --- a/tests/reflection/vertex-input-semantics.hlsl.expected +++ b/tests/reflection/vertex-input-semantics.hlsl.expected @@ -44,7 +44,6 @@ standard output = { "scalarType": "int32" } }, - "stage": "vertex", "binding": {"kind": "varyingInput", "index": 0}, "semanticName": "B" }, @@ -64,7 +63,6 @@ standard output = { "scalarType": "float32" } }, - "stage": "vertex", "binding": {"kind": "varyingInput", "index": 0}, "semanticName": "B", "semanticIndex": 1 @@ -79,14 +77,12 @@ standard output = { "scalarType": "float32" } }, - "stage": "vertex", "binding": {"kind": "varyingInput", "index": 1}, "semanticName": "B", "semanticIndex": 2 } ] }, - "stage": "vertex", "binding": {"kind": "varyingInput", "index": 1, "count": 2}, "semanticName": "B", "semanticIndex": 1 @@ -118,7 +114,6 @@ standard output = { "scalarType": "float32" } }, - "stage": "vertex", "binding": {"kind": "varyingInput", "index": 0}, "semanticName": "CX" }, @@ -132,14 +127,12 @@ standard output = { "scalarType": "float32" } }, - "stage": "vertex", "binding": {"kind": "varyingInput", "index": 1}, "semanticName": "CX", "semanticIndex": 1 } ] }, - "stage": "vertex", "binding": {"kind": "varyingInput", "index": 0, "count": 2}, "semanticName": "CX" }, @@ -153,7 +146,6 @@ standard output = { "scalarType": "int32" } }, - "stage": "vertex", "binding": {"kind": "varyingInput", "index": 2}, "semanticName": "CY" } |
