diff options
| author | jsmall-nvidia <jsmall@nvidia.com> | 2023-07-05 13:23:14 -0400 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-07-05 13:23:14 -0400 |
| commit | 69450a2be7575aa4f984b9ae2824da0e5634c9f0 (patch) | |
| tree | d554404f441af7fd113737cae8e1bde4897a814e /source | |
| parent | f9b73eab7edcedc9dc2c7825fcd4171631d14ac7 (diff) | |
Initial sizeof/alignof implementation. (#2954)
* Initial sizeof implementation.
* Small macro improvement.
* Fix some typos.
* Refactor NaturalSize.
Add more sizeof tests.
* Use _makeParseExpr to add sizeof support.
* Add size-of.slang diagnostic result.
* Fix typo in folding with macro change.
* Add a sizeof test of This.
* Some more NaturalSize coverage.
* Simple alignof support.
* Testing for alignof.
* Added 8 bit enum to check enums values are correctly sized.
* Add alignof to completion.
* Lower sizeof/alignof to IR.
sizeof/alignof IR pass.
Tests for simple generic scenarios.
* Make append handle invalid properly.
Improve comments.
---------
Co-authored-by: Theresa Foley <10618364+tangent-vector@users.noreply.github.com>
Diffstat (limited to 'source')
| -rw-r--r-- | source/slang/slang-ast-expr.h | 21 | ||||
| -rw-r--r-- | source/slang/slang-ast-natural-layout.cpp | 249 | ||||
| -rw-r--r-- | source/slang/slang-ast-natural-layout.h | 103 | ||||
| -rw-r--r-- | source/slang/slang-ast-val.cpp | 105 | ||||
| -rw-r--r-- | source/slang/slang-check-expr.cpp | 83 | ||||
| -rw-r--r-- | source/slang/slang-check-impl.h | 2 | ||||
| -rw-r--r-- | source/slang/slang-diagnostic-defs.h | 5 | ||||
| -rw-r--r-- | source/slang/slang-emit.cpp | 5 | ||||
| -rw-r--r-- | source/slang/slang-ir-inst-defs.h | 3 | ||||
| -rw-r--r-- | source/slang/slang-ir-insts.h | 6 | ||||
| -rw-r--r-- | source/slang/slang-ir-lower-size-of.cpp | 106 | ||||
| -rw-r--r-- | source/slang/slang-ir-lower-size-of.h | 17 | ||||
| -rw-r--r-- | source/slang/slang-ir.cpp | 25 | ||||
| -rw-r--r-- | source/slang/slang-language-server-completion.cpp | 3 | ||||
| -rw-r--r-- | source/slang/slang-lower-to-ir.cpp | 40 | ||||
| -rw-r--r-- | source/slang/slang-parser.cpp | 68 |
16 files changed, 780 insertions, 61 deletions
diff --git a/source/slang/slang-ast-expr.h b/source/slang/slang-ast-expr.h index 36d6546de..c441e1b9b 100644 --- a/source/slang/slang-ast-expr.h +++ b/source/slang/slang-ast-expr.h @@ -376,6 +376,27 @@ class AsTypeExpr : public Expr }; +class SizeOfLikeExpr : public Expr +{ + SLANG_AST_CLASS(SizeOfLikeExpr); + + // Set during the parse, could be an expression, a variable or a type + Expr* value = nullptr; + + // The type the size/alignment needs to operate on. Set during traversal of SemanticsExprVisitor + Type* sizedType = nullptr; +}; + +class SizeOfExpr : public SizeOfLikeExpr +{ + SLANG_AST_CLASS(SizeOfExpr); +}; + +class AlignOfExpr : public SizeOfLikeExpr +{ + SLANG_AST_CLASS(AlignOfExpr); +}; + class MakeOptionalExpr : public Expr { SLANG_AST_CLASS(MakeOptionalExpr) diff --git a/source/slang/slang-ast-natural-layout.cpp b/source/slang/slang-ast-natural-layout.cpp new file mode 100644 index 000000000..1789c5cea --- /dev/null +++ b/source/slang/slang-ast-natural-layout.cpp @@ -0,0 +1,249 @@ +// slang-ast-natural-layout.cpp +#include "slang-ast-natural-layout.h" + +#include "slang-ast-builder.h" + +// For BaseInfo +#include "slang-compiler.h" + +namespace Slang +{ + +/* !!!!!!!!!!!!!!!!!!!!!!!!! NaturalSize !!!!!!!!!!!!!!!!!!!!!!!!!!!! */ + + +NaturalSize NaturalSize::operator*(Count count) const +{ + // If the count is < 0 or the size is invalid, the result is invalid + if (isInvalid() || count < 0) + { + return makeInvalid(); + } + + if (count <= 0) + { + // If the count is 0, in effect the result doesn't take up any space + return makeEmpty(); + } + else + { + // We don't want to produce an aligned size, as we allow the last element to not + // take up a whole stride (only up to size) + return make(size + (getStride() * (count - 1)), alignment); + } +} + +/* static */NaturalSize NaturalSize::makeFromBaseType(BaseType baseType) +{ + // Special case void + if (baseType == BaseType::Void) + { + return makeEmpty(); + } + else + { + // In "natural" layout the alignment of a base type is always the same + // as the size of the type itself + auto info = BaseTypeInfo::getInfo(baseType); + return make(info.sizeInBytes, info.sizeInBytes); + } +} + +/* static */NaturalSize NaturalSize::calcUnion(NaturalSize a, NaturalSize b) +{ + const auto alignment = maxAlignment(a.alignment, b.alignment); + Count size = (alignment == kInvalidAlignment) ? 0 : Math::Max(a.size, b.size); + return make(size, alignment); +} + +/* !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! ASTNaturalLayoutContext !!!!!!!!!!!!!!!!!!!!!!!!!!!! */ + +ASTNaturalLayoutContext::ASTNaturalLayoutContext(ASTBuilder* astBuilder, DiagnosticSink* sink): + m_astBuilder(astBuilder), + m_sink(sink) +{ + // A null type always maps to invalid + m_typeToSize.add(nullptr, NaturalSize::makeInvalid()); +} + +Count ASTNaturalLayoutContext::_getCount(IntVal* intVal) +{ + if (auto constIntVal = as<ConstantIntVal>(intVal)) + { + if (constIntVal->value >= 0) + { + return Count(constIntVal->value); + } + } + + if (m_sink) + { + // Could output an error + } + + return -1; +} + +NaturalSize ASTNaturalLayoutContext::calcSize(Type* type) +{ + if (auto sizePtr = m_typeToSize.tryGetValue(type)) + { + return *sizePtr; + } + + // Calc the size + const NaturalSize size = _calcSizeImpl(type); + + // We want to add to the cache, but we need to special case + // in case there is an aggregate type that `poisoned` the cache entry, to stop infinite recursion. + // + // A requirement is that when the agg type completes it must set the cache entry, and return the same result. + if (auto foundSize = m_typeToSize.tryGetValueOrAdd(type, size)) + { + // If there is a found size, it must match. If not we update the state as invalid. + if (*foundSize != size) + { + *foundSize = NaturalSize::makeInvalid(); + return *foundSize; + } + } + + return size; +} + +NaturalSize ASTNaturalLayoutContext::_calcSizeImpl(Type* type) +{ + if (VectorExpressionType* vecType = as<VectorExpressionType>(type)) + { + const Count elementCount = _getCount(vecType->elementCount); + return (elementCount > 0) ? + calcSize(vecType->elementType) * elementCount : + NaturalSize::makeInvalid(); + } + else if (auto matType = as<MatrixExpressionType>(type)) + { + const Count colCount = _getCount(matType->getColumnCount()); + const Count rowCount = _getCount(matType->getRowCount()); + return (colCount > 0 && rowCount > 0) ? + calcSize(matType->getElementType()) * (colCount * rowCount) : + NaturalSize::makeInvalid(); + } + else if (auto basicType = as<BasicExpressionType>(type)) + { + return NaturalSize::makeFromBaseType(basicType->baseType); + } + else if (as<PtrTypeBase>(type) || as<NullPtrType>(type)) + { + // We assume 64 bits/8 bytes across the board + return NaturalSize::makeFromBaseType(BaseType::UInt64); + } + else if (auto arrayType = as<ArrayExpressionType>(type)) + { + const Count elementCount = _getCount(arrayType->getElementCount()); + return (elementCount > 0) ? + calcSize(arrayType->getElementType()) * elementCount : + NaturalSize::makeInvalid(); + } + else if (auto namedType = as<NamedExpressionType>(type)) + { + return calcSize(namedType->innerType); + } + else if (const auto tupleType = as<TupleType>(type)) + { + // Initialize empty + NaturalSize size = NaturalSize::makeEmpty(); + + // Accumulate over all the member types + for (auto cur : tupleType->memberTypes) + { + const auto curSize = calcSize(cur); + if (!curSize) + { + return NaturalSize::makeInvalid(); + } + size.append(curSize); + } + + return size; + } + else if (const auto taggedUnion = as<TaggedUnionType>(type)) + { + NaturalSize size = NaturalSize::makeInvalid(); + + for( auto caseType : taggedUnion->caseTypes ) + { + const NaturalSize caseSize = calcSize(caseType); + if (!caseSize) + { + return NaturalSize::makeInvalid(); + } + size = NaturalSize::calcUnion(size, caseSize); + } + + // After we've computed the size required to hold all the + // case types, we will allocate space for the tag field. + + // Currently we assume uint32_t on all targets + size.append(NaturalSize::makeFromBaseType(BaseType::UInt)); + + return size; + } + else if( auto declRefType = as<DeclRefType>(type) ) + { + if (const auto enumDeclRef = declRefType->declRef.as<EnumDecl>()) + { + Type* tagType = getTagType(m_astBuilder, enumDeclRef); + return calcSize(tagType); + } + else if(const auto structDeclRef = declRefType->declRef.as<StructDecl>()) + { + // Poison the cache whilst we construct + m_typeToSize.add(type, NaturalSize::makeInvalid()); + + // Initialize empty + NaturalSize size = NaturalSize::makeEmpty(); + + for (auto inherited : structDeclRef.getDecl()->getMembersOfType<InheritanceDecl>()) + { + // Look for a struct type that it inherits from + if (auto inheritedDeclRef = as<DeclRefType>(inherited->base.type)) + { + if (auto parentDecl = inheritedDeclRef->declRef.as<StructDecl>()) + { + // We can only inherit from one thing + size = calcSize(inherited->base.type); + if (!size) + { + return size; + } + break; + } + } + } + + // Accumulate over all of the fields + for (auto field : structDeclRef.getDecl()->getFields()) + { + const auto fieldSize = calcSize(field->getType()); + if (!fieldSize) + { + return NaturalSize::makeInvalid(); + } + size.append(fieldSize); + } + + // Set the cached result to the size. + m_typeToSize.set(type, size); + + return size; + } + else if (const auto typeDef = declRefType->declRef.as<TypeDefDecl>()) + { + return calcSize(typeDef.getDecl()->type); + } + } + + return NaturalSize::makeInvalid(); +} + +} // namespace Slang diff --git a/source/slang/slang-ast-natural-layout.h b/source/slang/slang-ast-natural-layout.h new file mode 100644 index 000000000..4a165973d --- /dev/null +++ b/source/slang/slang-ast-natural-layout.h @@ -0,0 +1,103 @@ +#ifndef SLANG_AST_NATURAL_LAYOUT_H +#define SLANG_AST_NATURAL_LAYOUT_H + +#include "slang-ast-base.h" + +namespace Slang +{ + +struct NaturalSize +{ + typedef NaturalSize ThisType; + + // We are going to use 0 as invalid for alignment. This has a few nice propeties + // + // * Will naturally produce 0 size when used with `calcAligned` operation + // * Is fast to test + // * Is easy to make a fast 'max' such that a max with invalid always returns `invalid` + // + // We also desire that when invalid the `size` member is 0. + // This is so that equality testing doesn't require anything special. + SLANG_FORCE_INLINE static Count calcAligned(Count size, Count alignment) { return (size + alignment - 1) & ~(alignment - 1); } + // Use to get the max of two alignments. Uses some maths such that `invalid` is always max + SLANG_FORCE_INLINE static Count maxAlignment(Count a, Count b) { return (UCount(a) - 1) > (UCount(b) - 1) ? a : b; } + + /// Given two sizes, returns a result that can hold the union. + static NaturalSize calcUnion(NaturalSize a, NaturalSize b); + + /// Value chosen such that normal combining operations produce an invalid result + /// as typically a max. + static const Count kInvalidAlignment = 0; + + /// Get the stride, which is equivalent to the size aligned + SLANG_FORCE_INLINE Count getStride() const { return calcAligned(size, alignment); } + + /// Append rhs to this. + /// If rhs is invalid or this is the result will also be invalid + void append(const ThisType& rhs) + { + const auto newAlignment = maxAlignment(alignment, rhs.alignment); + + // If the new alignment is valid we calculate the size, else it's 0 + size = (newAlignment != kInvalidAlignment) ? + (calcAligned(size, rhs.alignment) + rhs.size) : + 0; + + // Set the new alignment + alignment = newAlignment; + } + + SLANG_FORCE_INLINE bool isInvalid() const { return alignment == kInvalidAlignment; } + SLANG_FORCE_INLINE bool isValid() const { return !isInvalid(); } + + bool operator==(const ThisType& rhs) const { return size == rhs.size && alignment == rhs.alignment; } + bool operator!=(const ThisType& rhs) const { return !(*this == rhs); } + + /// Converts to bool to make testing convenient + operator bool() const { return isValid(); } + + /// An empty size. It consumes 0 bytes and has the lowest alignment (1) + static ThisType makeEmpty() { return ThisType{ 0, 1 }; } + /// Make an invalid size. + static ThisType makeInvalid() { return ThisType{ 0, kInvalidAlignment }; } + /// Make a size with an amount of bytes and the alignment + static ThisType make(Count size, Count alignment) { return ThisType{size, alignment}; } + + /// Given a base type returns it's size + static ThisType makeFromBaseType(BaseType baseType); + + /// Multiply by a count. + /// Will return invalid if count < 0 or this is already invalid + ThisType operator*(Count count) const; + + Count size; + Count alignment; +}; + +struct ASTNaturalLayoutContext +{ + /// Given a type returns it's natural size. + /// Returns invalid size if types size could not be calculated + NaturalSize calcSize(Type* type); + + /// Ctor + ASTNaturalLayoutContext(ASTBuilder* astBuilder, DiagnosticSink* sink = nullptr); + +protected: + + /// Gets a count (positivie integer including 0). + /// <0 indicates error + Count _getCount(IntVal* intVal); + + /// The main implementation, assumes outer `calcSize` will perform caching + NaturalSize _calcSizeImpl(Type* type); + + Dictionary<Type*, NaturalSize> m_typeToSize; + + ASTBuilder* m_astBuilder; + DiagnosticSink* m_sink; +}; + +} // namespace Slang + +#endif diff --git a/source/slang/slang-ast-val.cpp b/source/slang/slang-ast-val.cpp index d8886f05b..21f876048 100644 --- a/source/slang/slang-ast-val.cpp +++ b/source/slang/slang-ast-val.cpp @@ -1400,13 +1400,6 @@ HashCode FuncCallIntVal::_getHashCodeOverride() return result; } -static bool nameIs(Name* name, const char* val) -{ - if (name && name->text.getUnownedSlice() == val) - return true; - return false; -} - Val* FuncCallIntVal::tryFoldImpl(ASTBuilder* astBuilder, Type* resultType, DeclRef<Decl> newFuncDecl, List<IntVal*>& newArgs, DiagnosticSink* sink) { // Are all args const now? @@ -1428,29 +1421,24 @@ Val* FuncCallIntVal::tryFoldImpl(ASTBuilder* astBuilder, Type* resultType, DeclR { // Evaluate the function. auto opName = newFuncDecl.getName(); + SLANG_ASSERT(opName); + + const auto opNameSlice = opName->text.getUnownedSlice(); + IntegerLiteralValue resultValue = 0; - if (nameIs(opName, "==")) - { - resultValue = constArgs[0]->value / constArgs[1]->value; - } + + // Define convenience macros. + // The last macro used in the list *must* be + // TERMINATING_CASE, as this handles the closing else, and matches if nothing else does. + #define BINARY_OPERATOR_CASE(op) \ - else if (nameIs(opName, #op)) \ + if (opNameSlice == toSlice(#op)) \ { \ resultValue = constArgs[0]->value op constArgs[1]->value; \ - } - BINARY_OPERATOR_CASE(>=) - BINARY_OPERATOR_CASE(<=) - BINARY_OPERATOR_CASE(>) - BINARY_OPERATOR_CASE(<) - BINARY_OPERATOR_CASE(!=) - BINARY_OPERATOR_CASE(<<) - BINARY_OPERATOR_CASE(>>) - BINARY_OPERATOR_CASE(&) - BINARY_OPERATOR_CASE(|) - BINARY_OPERATOR_CASE(^) -#undef BINARY_OPERATOR_CASE + } else + #define DIV_OPERATOR_CASE(op) \ - else if (nameIs(opName, #op)) \ + if (opNameSlice == toSlice(#op)) \ { \ if (constArgs[1]->value == 0) \ { \ @@ -1459,35 +1447,56 @@ Val* FuncCallIntVal::tryFoldImpl(ASTBuilder* astBuilder, Type* resultType, DeclR return nullptr; \ } \ resultValue = constArgs[0]->value op constArgs[1]->value; \ - } - DIV_OPERATOR_CASE(/) - DIV_OPERATOR_CASE(%) -#undef DIV_OPERATOR_CASE + } else + #define LOGICAL_OPERATOR_CASE(op) \ - else if (nameIs(opName, #op)) \ + if (opNameSlice == toSlice(#op)) \ { \ resultValue = (((constArgs[0]->value!=0) op (constArgs[1]->value!=0)) ? 1 : 0); \ + } else + + +#define SPECIAL_OPERATOR_CASE(op, IF_MATCH) \ + if (opNameSlice == toSlice(op)) \ + { \ + IF_MATCH \ + } else + +#define TERMINATING_CASE(MATCH) \ + { \ + MATCH \ } + + // Handle the cases using the macros + BINARY_OPERATOR_CASE(>=) + BINARY_OPERATOR_CASE(<=) + BINARY_OPERATOR_CASE(>) + BINARY_OPERATOR_CASE(<) + BINARY_OPERATOR_CASE(!=) + BINARY_OPERATOR_CASE(==) + BINARY_OPERATOR_CASE(<<) + BINARY_OPERATOR_CASE(>>) + BINARY_OPERATOR_CASE(&) + BINARY_OPERATOR_CASE(|) + BINARY_OPERATOR_CASE(^) + DIV_OPERATOR_CASE(/) + DIV_OPERATOR_CASE(%) LOGICAL_OPERATOR_CASE(&&) - LOGICAL_OPERATOR_CASE(|| ) -#undef LOGICAL_OPERATOR_CASE - else if (nameIs(opName, "!")) - { - resultValue = ((constArgs[0]->value != 0) ? 1 : 0); - } - else if (nameIs(opName, "~")) - { - resultValue = ~constArgs[0]->value; - } - else if (nameIs(opName, "?:")) - { - resultValue = constArgs[0]->value != 0 ? constArgs[1]->value : constArgs[2]->value; - } - else - { - SLANG_UNREACHABLE("constant folding of FuncCallIntVal"); - } + LOGICAL_OPERATOR_CASE(||) + // Special cases need their "operator" names quoted. + SPECIAL_OPERATOR_CASE("!", resultValue = ((constArgs[0]->value != 0) ? 1 : 0);) + SPECIAL_OPERATOR_CASE("~", resultValue = ~constArgs[0]->value;) + SPECIAL_OPERATOR_CASE("?:", resultValue = constArgs[0]->value != 0 ? constArgs[1]->value : constArgs[2]->value;) + TERMINATING_CASE(SLANG_UNREACHABLE("constant folding of FuncCallIntVal");) + return astBuilder->getIntVal(resultType, resultValue); + + // The macros for the cases are no longer needed so undef them all. +#undef BINARY_OPERATOR_CASE +#undef DIV_OPERATOR_CASE +#undef LOGICAL_OPERATOR_CASE +#undef SPECIAL_OPERATOR_CASE +#undef TERMINATING_CASE } return nullptr; } diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index 00ece3628..9dc359a5e 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -11,6 +11,8 @@ // // * `slang-check-conversion.cpp` is responsible for the logic of handling type conversion/coercion +#include "slang-ast-natural-layout.h" + #include "slang-lookup.h" #include "slang-ast-print.h" @@ -1279,8 +1281,6 @@ namespace Slang return nullptr; } - - // Let's not constant-fold operations with more than a certain number of arguments, for simplicity static const int kMaxArgs = 8; auto argCount = getArgCount(invokeExpr); @@ -1533,6 +1533,7 @@ namespace Slang SubstExpr<Expr> expr, ConstantFoldingCircularityInfo* circularityInfo) { + // Unwrap any "identity" expressions while (auto parenExpr = expr.as<ParenExpr>()) { @@ -1629,7 +1630,23 @@ namespace Slang if (val) return val; } + else if (auto sizeOfLikeExpr = as<SizeOfLikeExpr>(expr.getExpr())) + { + ASTNaturalLayoutContext context(getASTBuilder(), nullptr); + const auto size = context.calcSize(sizeOfLikeExpr->sizedType); + if (!size) + { + return nullptr; + } + auto value = as<AlignOfExpr>(sizeOfLikeExpr) ? + size.alignment : + size.size; + + // We can return as an IntVal + return getASTBuilder()->getIntVal(expr.getExpr()->type, value); + } + return nullptr; } @@ -2145,6 +2162,7 @@ namespace Slang return rs; } + Expr* SemanticsExprVisitor::visitSelectExpr(SelectExpr* expr) { auto result = visitInvokeExpr(expr); @@ -2695,6 +2713,67 @@ namespace Slang return expr; } + static bool _isSizeOfType(Type* type) + { + if (!type) + { + return false; + } + + if (as<ArithmeticExpressionType>(type) || + as<ArrayExpressionType>(type) || + as<PtrTypeBase>(type) || + as<TupleType>(type) || + as<GenericDeclRefType>(type)) + { + return true; + } + + if (as<DeclRefType>(type)) + { + return true; + } + + return false; + } + + Expr* SemanticsExprVisitor::visitSizeOfLikeExpr(SizeOfLikeExpr* sizeOfLikeExpr) + { + auto valueExpr = dispatch(sizeOfLikeExpr->value); + + Type* type = nullptr; + + if (as<TypeType>(valueExpr->type)) + { + TypeExp typeExp; + typeExp.exp = valueExpr; + + auto properTypeExpr = CoerceToProperType(typeExp); + + type = properTypeExpr.type; + } + else + { + // Is this a proper type? + TypeExp typeExp(valueExpr->type); + TypeExp properType = tryCoerceToProperType(typeExp); + + type = properType.type; + } + + if (!_isSizeOfType(type)) + { + getSink()->diagnose(sizeOfLikeExpr, Diagnostics::sizeOfArgumentIsInvalid); + + sizeOfLikeExpr->type = m_astBuilder->getErrorType(); + return sizeOfLikeExpr; + } + + sizeOfLikeExpr->sizedType = type; + + return sizeOfLikeExpr; + } + Expr* SemanticsExprVisitor::visitTypeCastExpr(TypeCastExpr * expr) { // Check the term we are applying first diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h index ac4d51549..26c96a72e 100644 --- a/source/slang/slang-check-impl.h +++ b/source/slang/slang-check-impl.h @@ -1933,6 +1933,8 @@ namespace Slang : SemanticsVisitor(outer) {} + Expr* visitSizeOfLikeExpr(SizeOfLikeExpr* expr); + Expr* visitIncompleteExpr(IncompleteExpr* expr); Expr* visitBoolLiteralExpr(BoolLiteralExpr* expr); Expr* visitNullPtrLiteralExpr(NullPtrLiteralExpr* expr); diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h index d38cde4e6..0fd962614 100644 --- a/source/slang/slang-diagnostic-defs.h +++ b/source/slang/slang-diagnostic-defs.h @@ -318,6 +318,8 @@ DIAGNOSTIC(30096, Error, differentialTypeShouldServeAsItsOwnDifferentialType, "t DIAGNOSTIC(30097, Error, functionNotMarkedAsDifferentiable, "function '$0' is not marked as $1-differentiable.") DIAGNOSTIC(30098, Error, nonStaticMemberFunctionNotAllowedAsDiffOperand, "non-static function reference '$0' is not allowed here.") +DIAGNOSTIC(30099, Error, sizeOfArgumentIsInvalid, "argument to sizeof is invalid") + DIAGNOSTIC(-1, Note, noteSeeUseOfDifferentialType, "see use of '$0' as Differential of '$1'.") // Attributes @@ -625,6 +627,9 @@ DIAGNOSTIC(41023, Error, getStringHashMustBeOnStringLiteral, "getStringHash can DIAGNOSTIC(41901, Error, unsupportedUseOfLValueForAutoDiff, "unsupported use of L-value for auto differentiation.") DIAGNOSTIC(41902, Error, cannotDifferentiateDynamicallyIndexedData, "cannot auto-differentiate mixed read/write access to dynamically indexed data in '$0'.") +DIAGNOSTIC(41903, Error, unableToSizeOf, "sizeof could not be performed for type '$0'.") +DIAGNOSTIC(41904, Error, unableToAlignOf, "alignof could not be performed for type '$0'.") + DIAGNOSTIC(42001, Error, invalidUseOfTorchTensorTypeInDeviceFunc, "invalid use of TorchTensor type in device/kernel functions. use `TensorView` instead.") // diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp index 827c69e50..4c050ffcc 100644 --- a/source/slang/slang-emit.cpp +++ b/source/slang/slang-emit.cpp @@ -36,6 +36,7 @@ #include "slang-ir-lower-optional-type.h" #include "slang-ir-lower-bit-cast.h" #include "slang-ir-lower-l-value-cast.h" +#include "slang-ir-lower-size-of.h" #include "slang-ir-lower-reinterpret.h" #include "slang-ir-loop-unroll.h" #include "slang-ir-metadata.h" @@ -867,6 +868,10 @@ Result linkAndOptimizeIR( legalizeUniformBufferLoad(irModule); } + // Lower sizeof/alignof + + lowerSizeOfLike(targetRequest, irModule, sink); + // Lower all the LValue implict casts (used for out/inout/ref scenarios) lowerLValueCast(targetRequest, irModule); diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h index de3735a55..636264a6d 100644 --- a/source/slang/slang-ir-inst-defs.h +++ b/source/slang/slang-ir-inst-defs.h @@ -934,6 +934,9 @@ INST(CastPtrToInt, CastPtrToInt, 1, 0) INST(CastIntToPtr, CastIntToPtr, 1, 0) INST(CastToVoid, castToVoid, 1, 0) +INST(SizeOf, sizeOf, 1, 0) +INST(AlignOf, alignOf, 1, 0) + INST(IsType, IsType, 3, 0) INST(ForwardDifferentiate, ForwardDifferentiate, 1, 0) diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index f2c00f406..18a0677b0 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -3622,6 +3622,12 @@ public: IRType* type, IRInst* val); + IRInst* emitSizeOf( + IRInst* sizedType); + + IRInst* emitAlignOf( + IRInst* sizedType); + IRInst* emitCastPtrToBool(IRInst* val); IRGlobalConstant* emitGlobalConstant( diff --git a/source/slang/slang-ir-lower-size-of.cpp b/source/slang/slang-ir-lower-size-of.cpp new file mode 100644 index 000000000..a8b599031 --- /dev/null +++ b/source/slang/slang-ir-lower-size-of.cpp @@ -0,0 +1,106 @@ +#include "slang-ir-lower-size-of.h" + +#include "slang-ir.h" +#include "slang-ir-insts.h" + +#include "slang-ir-layout.h" + +namespace Slang +{ + +struct SizeOfLikeLoweringContext +{ + void _addToWorkList(IRInst* inst) + { + if (!findOuterGeneric(inst) && !m_workList.contains(inst)) + { + m_workList.add(inst); + } + } + + void _processInst(IRInst* inst) + { + switch (inst->getOp()) + { + case kIROp_AlignOf: + case kIROp_SizeOf: + _processSizeOfLike(inst); + break; + default: + break; + } + } + + void processModule() + { + _addToWorkList(m_module->getModuleInst()); + + while (m_workList.getCount() != 0) + { + IRInst* inst = m_workList.getLast(); + m_workList.removeLast(); + + _processInst(inst); + + for (auto child = inst->getLastChild(); child; child = child->getPrevInst()) + { + _addToWorkList(child); + } + } + } + + void _processSizeOfLike(IRInst* sizeOfLikeInst) + { + auto typeOperand = as<IRType>(sizeOfLikeInst->getOperand(0)); + + IRSizeAndAlignment sizeAndAlignment; + + if (SLANG_FAILED(getNaturalSizeAndAlignment(m_targetReq, typeOperand, &sizeAndAlignment))) + { + // Output a diagnostic failure + if(sizeOfLikeInst->getOp() == kIROp_AlignOf) + { + m_sink->diagnose(sizeOfLikeInst, Diagnostics::unableToAlignOf, typeOperand); + } + else + { + m_sink->diagnose(sizeOfLikeInst, Diagnostics::unableToSizeOf, typeOperand); + } + + return; + } + + IRBuilder builder(m_module); + + const auto value = (sizeOfLikeInst->getOp() == kIROp_AlignOf) ? + sizeAndAlignment.alignment : + sizeAndAlignment.size; + + auto valueInst = builder.getIntValue(sizeOfLikeInst->getDataType(), value); + + // Replace all uses of sizeOfLikeInst with the value + sizeOfLikeInst->replaceUsesWith(valueInst); + // We don't need the instruction any more + sizeOfLikeInst->removeAndDeallocate(); + } + + SizeOfLikeLoweringContext(TargetRequest* targetReq, IRModule* module, DiagnosticSink* sink): + m_module(module), + m_targetReq(targetReq), + m_sink(sink) + { + } + + TargetRequest* m_targetReq; + DiagnosticSink* m_sink; + IRModule* m_module; + OrderedHashSet<IRInst*> m_workList; +}; + +void lowerSizeOfLike(TargetRequest* targetReq, IRModule* module, DiagnosticSink* sink) +{ + SizeOfLikeLoweringContext context(targetReq, module, sink); + context.processModule(); +} + +} // namespace Slang diff --git a/source/slang/slang-ir-lower-size-of.h b/source/slang/slang-ir-lower-size-of.h new file mode 100644 index 000000000..4205aa2f0 --- /dev/null +++ b/source/slang/slang-ir-lower-size-of.h @@ -0,0 +1,17 @@ +#ifndef SLANG_IR_LOWER_SIZE_OF_H +#define SLANG_IR_LOWER_SIZE_OF_H + +// This defines an IR pass that lowers sizeof/alignof. + +namespace Slang +{ + +struct IRModule; +class TargetRequest; +class DiagnosticSink; + +void lowerSizeOfLike(TargetRequest* target, IRModule* module, DiagnosticSink* sink); + +} // namespace Slang + +#endif diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index 789349b4c..74679de96 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -5268,6 +5268,31 @@ namespace Slang return inst; } + + IRInst* IRBuilder::emitSizeOf( + IRInst* sizedType) + { + auto inst = createInst<IRInst>( + this, + kIROp_SizeOf, + getUIntType(), + sizedType); + addInst(inst); + return inst; + } + + IRInst* IRBuilder::emitAlignOf( + IRInst* sizedType) + { + auto inst = createInst<IRInst>( + this, + kIROp_AlignOf, + getUIntType(), + sizedType); + addInst(inst); + return inst; + } + IRInst* IRBuilder::emitBitCast( IRType* type, IRInst* val) diff --git a/source/slang/slang-language-server-completion.cpp b/source/slang/slang-language-server-completion.cpp index d95f7d4e5..6bccca8d3 100644 --- a/source/slang/slang-language-server-completion.cpp +++ b/source/slang/slang-language-server-completion.cpp @@ -34,7 +34,8 @@ static const char* kStmtKeywords[] = { "extension", "associatedtype", "this", "namespace", "This", "using", "__generic", "__exported", "import", "enum", "break", "continue", "discard", "defer", "cbuffer", "tbuffer", "func", "is", - "as", "nullptr", "none", "true", "false", "functype"}; + "as", "nullptr", "none", "true", "false", "functype", + "sizeof", "alignof"}; static const char* hlslSemanticNames[] = { "register", diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index 85b17dafb..cb2c9129a 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -32,6 +32,9 @@ #include "slang-type-layout.h" #include "slang-visitor.h" +// Natural layout +#include "slang-ast-natural-layout.h" + namespace Slang { @@ -3377,6 +3380,43 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo> return LoweredValInfo::simple(arrayType->getElementCount()); } + LoweredValInfo visitSizeOfLikeExpr(SizeOfLikeExpr* sizeOfLikeExpr) + { + // Lets try and lower to a constant + ASTNaturalLayoutContext naturalLayoutContext(getASTBuilder(), nullptr); + + const auto size = naturalLayoutContext.calcSize(sizeOfLikeExpr->sizedType); + + auto builder = getBuilder(); + + if (!size) + { + auto sizedType = lowerType(context, sizeOfLikeExpr->sizedType); + + // We can create an inst + + IRInst* inst = nullptr; + + if (as<AlignOfExpr>(sizeOfLikeExpr)) + { + inst = builder->emitAlignOf(sizedType); + } + else + { + inst = builder->emitSizeOf(sizedType); + } + + return LoweredValInfo::simple(inst); + } + + const auto value = + as<SizeOfExpr>(sizeOfLikeExpr) ? + size.size : + size.alignment; + + return LoweredValInfo::simple(getBuilder()->getIntValue(builder->getUIntType(), value)); + } + LoweredValInfo visitOverloadedExpr(OverloadedExpr* /*expr*/) { SLANG_UNEXPECTED("overloaded expressions should not occur in checked AST"); diff --git a/source/slang/slang-parser.cpp b/source/slang/slang-parser.cpp index dcf835234..f790cd4d2 100644 --- a/source/slang/slang-parser.cpp +++ b/source/slang/slang-parser.cpp @@ -4940,9 +4940,6 @@ namespace Slang return Associativity::Left; } - - - Precedence GetOpLevel(Parser* parser, const Token& token) { switch(token.type) @@ -4998,12 +4995,15 @@ namespace Slang case TokenType::OpMod: return Precedence::Multiplicative; default: - if (token.getContent() == "is" || token.getContent() == "as") + { + const auto content = token.getContent(); + if (content == "is" || content == "as") { return Precedence::RelationalComparison; } return Precedence::Invalid; } + } } static Expr* parseOperator(Parser* parser) @@ -5060,7 +5060,9 @@ namespace Slang // Special case the "is" and "as" operators. if (opToken.type == TokenType::Identifier) { - if (opToken.getContent() == "is") + const auto content = opToken.getContent(); + + if (content == "is") { auto isExpr = parser->astBuilder->create<IsTypeExpr>(); isExpr->value = expr; @@ -5070,7 +5072,7 @@ namespace Slang expr = isExpr; continue; } - else if (opToken.getContent() == "as") + else if (content == "as") { auto asExpr = parser->astBuilder->create<AsTypeExpr>(); asExpr->value = expr; @@ -5246,6 +5248,40 @@ namespace Slang return parser->astBuilder->create<NoneLiteralExpr>(); } + static NodeBase* parseSizeOfExpr(Parser* parser, void* /*userData*/) + { + // We could have a type or a variable or an expression + SizeOfExpr* sizeOfExpr = parser->astBuilder->create<SizeOfExpr>(); + + parser->ReadMatchingToken(TokenType::LParent); + + // The return type is always a UInt + sizeOfExpr->type = parser->astBuilder->getUIntType(); + + sizeOfExpr->value = parser->ParseExpression(); + + parser->ReadMatchingToken(TokenType::RParent); + + return sizeOfExpr; + } + + static NodeBase* parseAlignOfExpr(Parser* parser, void* /*userData*/) + { + // We could have a type or a variable or an expression + AlignOfExpr* alignOfExpr = parser->astBuilder->create<AlignOfExpr>(); + + parser->ReadMatchingToken(TokenType::LParent); + + // The return type is always a UInt + alignOfExpr->type = parser->astBuilder->getUIntType(); + + alignOfExpr->value = parser->ParseExpression(); + + parser->ReadMatchingToken(TokenType::RParent); + + return alignOfExpr; + } + static NodeBase* parseTryExpr(Parser* parser, void* /*userData*/) { auto tryExpr = parser->astBuilder->create<TryExpr>(); @@ -6040,7 +6076,7 @@ namespace Slang } break; - // Call oepration `f(x)` + // Call operation `f(x)` case TokenType::LParent: { InvokeExpr* invokeExpr = parser->astBuilder->create<InvokeExpr>(); @@ -6144,8 +6180,11 @@ namespace Slang auto tokenType = peekTokenType(parser); switch( tokenType ) { - default: - if (parser->LookAheadToken("new")) + case TokenType::Identifier: + { + auto identifierToken = peekToken(parser); + const auto identifierTokenContent = identifierToken.getContent(); + if (identifierTokenContent == toSlice("new")) { NewExpr* newExpr = parser->astBuilder->create<NewExpr>(); parser->FillPosition(newExpr); @@ -6168,7 +6207,14 @@ namespace Slang } return newExpr; } + + return parsePostfixExpr(parser); + } + default: + { + return parsePostfixExpr(parser); + } case TokenType::OpNot: case TokenType::OpInc: case TokenType::OpDec: @@ -6873,7 +6919,9 @@ namespace Slang _makeParseExpr("__bwd_diff", parseBackwardDifferentiate), _makeParseExpr("fwd_diff", parseForwardDifferentiate), _makeParseExpr("bwd_diff", parseBackwardDifferentiate), - _makeParseExpr("__dispatch_kernel", parseDispatchKernel) + _makeParseExpr("__dispatch_kernel", parseDispatchKernel), + _makeParseExpr("sizeof", parseSizeOfExpr), + _makeParseExpr("alignof", parseAlignOfExpr), }; ConstArrayView<SyntaxParseInfo> getSyntaxParseInfos() |
