summaryrefslogtreecommitdiff
path: root/source
diff options
context:
space:
mode:
authorjsmall-nvidia <jsmall@nvidia.com>2023-07-05 13:23:14 -0400
committerGitHub <noreply@github.com>2023-07-05 13:23:14 -0400
commit69450a2be7575aa4f984b9ae2824da0e5634c9f0 (patch)
treed554404f441af7fd113737cae8e1bde4897a814e /source
parentf9b73eab7edcedc9dc2c7825fcd4171631d14ac7 (diff)
Initial sizeof/alignof implementation. (#2954)
* Initial sizeof implementation. * Small macro improvement. * Fix some typos. * Refactor NaturalSize. Add more sizeof tests. * Use _makeParseExpr to add sizeof support. * Add size-of.slang diagnostic result. * Fix typo in folding with macro change. * Add a sizeof test of This. * Some more NaturalSize coverage. * Simple alignof support. * Testing for alignof. * Added 8 bit enum to check enums values are correctly sized. * Add alignof to completion. * Lower sizeof/alignof to IR. sizeof/alignof IR pass. Tests for simple generic scenarios. * Make append handle invalid properly. Improve comments. --------- Co-authored-by: Theresa Foley <10618364+tangent-vector@users.noreply.github.com>
Diffstat (limited to 'source')
-rw-r--r--source/slang/slang-ast-expr.h21
-rw-r--r--source/slang/slang-ast-natural-layout.cpp249
-rw-r--r--source/slang/slang-ast-natural-layout.h103
-rw-r--r--source/slang/slang-ast-val.cpp105
-rw-r--r--source/slang/slang-check-expr.cpp83
-rw-r--r--source/slang/slang-check-impl.h2
-rw-r--r--source/slang/slang-diagnostic-defs.h5
-rw-r--r--source/slang/slang-emit.cpp5
-rw-r--r--source/slang/slang-ir-inst-defs.h3
-rw-r--r--source/slang/slang-ir-insts.h6
-rw-r--r--source/slang/slang-ir-lower-size-of.cpp106
-rw-r--r--source/slang/slang-ir-lower-size-of.h17
-rw-r--r--source/slang/slang-ir.cpp25
-rw-r--r--source/slang/slang-language-server-completion.cpp3
-rw-r--r--source/slang/slang-lower-to-ir.cpp40
-rw-r--r--source/slang/slang-parser.cpp68
16 files changed, 780 insertions, 61 deletions
diff --git a/source/slang/slang-ast-expr.h b/source/slang/slang-ast-expr.h
index 36d6546de..c441e1b9b 100644
--- a/source/slang/slang-ast-expr.h
+++ b/source/slang/slang-ast-expr.h
@@ -376,6 +376,27 @@ class AsTypeExpr : public Expr
};
+class SizeOfLikeExpr : public Expr
+{
+ SLANG_AST_CLASS(SizeOfLikeExpr);
+
+ // Set during the parse, could be an expression, a variable or a type
+ Expr* value = nullptr;
+
+ // The type the size/alignment needs to operate on. Set during traversal of SemanticsExprVisitor
+ Type* sizedType = nullptr;
+};
+
+class SizeOfExpr : public SizeOfLikeExpr
+{
+ SLANG_AST_CLASS(SizeOfExpr);
+};
+
+class AlignOfExpr : public SizeOfLikeExpr
+{
+ SLANG_AST_CLASS(AlignOfExpr);
+};
+
class MakeOptionalExpr : public Expr
{
SLANG_AST_CLASS(MakeOptionalExpr)
diff --git a/source/slang/slang-ast-natural-layout.cpp b/source/slang/slang-ast-natural-layout.cpp
new file mode 100644
index 000000000..1789c5cea
--- /dev/null
+++ b/source/slang/slang-ast-natural-layout.cpp
@@ -0,0 +1,249 @@
+// slang-ast-natural-layout.cpp
+#include "slang-ast-natural-layout.h"
+
+#include "slang-ast-builder.h"
+
+// For BaseInfo
+#include "slang-compiler.h"
+
+namespace Slang
+{
+
+/* !!!!!!!!!!!!!!!!!!!!!!!!! NaturalSize !!!!!!!!!!!!!!!!!!!!!!!!!!!! */
+
+
+NaturalSize NaturalSize::operator*(Count count) const
+{
+ // If the count is < 0 or the size is invalid, the result is invalid
+ if (isInvalid() || count < 0)
+ {
+ return makeInvalid();
+ }
+
+ if (count <= 0)
+ {
+ // If the count is 0, in effect the result doesn't take up any space
+ return makeEmpty();
+ }
+ else
+ {
+ // We don't want to produce an aligned size, as we allow the last element to not
+ // take up a whole stride (only up to size)
+ return make(size + (getStride() * (count - 1)), alignment);
+ }
+}
+
+/* static */NaturalSize NaturalSize::makeFromBaseType(BaseType baseType)
+{
+ // Special case void
+ if (baseType == BaseType::Void)
+ {
+ return makeEmpty();
+ }
+ else
+ {
+ // In "natural" layout the alignment of a base type is always the same
+ // as the size of the type itself
+ auto info = BaseTypeInfo::getInfo(baseType);
+ return make(info.sizeInBytes, info.sizeInBytes);
+ }
+}
+
+/* static */NaturalSize NaturalSize::calcUnion(NaturalSize a, NaturalSize b)
+{
+ const auto alignment = maxAlignment(a.alignment, b.alignment);
+ Count size = (alignment == kInvalidAlignment) ? 0 : Math::Max(a.size, b.size);
+ return make(size, alignment);
+}
+
+/* !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! ASTNaturalLayoutContext !!!!!!!!!!!!!!!!!!!!!!!!!!!! */
+
+ASTNaturalLayoutContext::ASTNaturalLayoutContext(ASTBuilder* astBuilder, DiagnosticSink* sink):
+ m_astBuilder(astBuilder),
+ m_sink(sink)
+{
+ // A null type always maps to invalid
+ m_typeToSize.add(nullptr, NaturalSize::makeInvalid());
+}
+
+Count ASTNaturalLayoutContext::_getCount(IntVal* intVal)
+{
+ if (auto constIntVal = as<ConstantIntVal>(intVal))
+ {
+ if (constIntVal->value >= 0)
+ {
+ return Count(constIntVal->value);
+ }
+ }
+
+ if (m_sink)
+ {
+ // Could output an error
+ }
+
+ return -1;
+}
+
+NaturalSize ASTNaturalLayoutContext::calcSize(Type* type)
+{
+ if (auto sizePtr = m_typeToSize.tryGetValue(type))
+ {
+ return *sizePtr;
+ }
+
+ // Calc the size
+ const NaturalSize size = _calcSizeImpl(type);
+
+ // We want to add to the cache, but we need to special case
+ // in case there is an aggregate type that `poisoned` the cache entry, to stop infinite recursion.
+ //
+ // A requirement is that when the agg type completes it must set the cache entry, and return the same result.
+ if (auto foundSize = m_typeToSize.tryGetValueOrAdd(type, size))
+ {
+ // If there is a found size, it must match. If not we update the state as invalid.
+ if (*foundSize != size)
+ {
+ *foundSize = NaturalSize::makeInvalid();
+ return *foundSize;
+ }
+ }
+
+ return size;
+}
+
+NaturalSize ASTNaturalLayoutContext::_calcSizeImpl(Type* type)
+{
+ if (VectorExpressionType* vecType = as<VectorExpressionType>(type))
+ {
+ const Count elementCount = _getCount(vecType->elementCount);
+ return (elementCount > 0) ?
+ calcSize(vecType->elementType) * elementCount :
+ NaturalSize::makeInvalid();
+ }
+ else if (auto matType = as<MatrixExpressionType>(type))
+ {
+ const Count colCount = _getCount(matType->getColumnCount());
+ const Count rowCount = _getCount(matType->getRowCount());
+ return (colCount > 0 && rowCount > 0) ?
+ calcSize(matType->getElementType()) * (colCount * rowCount) :
+ NaturalSize::makeInvalid();
+ }
+ else if (auto basicType = as<BasicExpressionType>(type))
+ {
+ return NaturalSize::makeFromBaseType(basicType->baseType);
+ }
+ else if (as<PtrTypeBase>(type) || as<NullPtrType>(type))
+ {
+ // We assume 64 bits/8 bytes across the board
+ return NaturalSize::makeFromBaseType(BaseType::UInt64);
+ }
+ else if (auto arrayType = as<ArrayExpressionType>(type))
+ {
+ const Count elementCount = _getCount(arrayType->getElementCount());
+ return (elementCount > 0) ?
+ calcSize(arrayType->getElementType()) * elementCount :
+ NaturalSize::makeInvalid();
+ }
+ else if (auto namedType = as<NamedExpressionType>(type))
+ {
+ return calcSize(namedType->innerType);
+ }
+ else if (const auto tupleType = as<TupleType>(type))
+ {
+ // Initialize empty
+ NaturalSize size = NaturalSize::makeEmpty();
+
+ // Accumulate over all the member types
+ for (auto cur : tupleType->memberTypes)
+ {
+ const auto curSize = calcSize(cur);
+ if (!curSize)
+ {
+ return NaturalSize::makeInvalid();
+ }
+ size.append(curSize);
+ }
+
+ return size;
+ }
+ else if (const auto taggedUnion = as<TaggedUnionType>(type))
+ {
+ NaturalSize size = NaturalSize::makeInvalid();
+
+ for( auto caseType : taggedUnion->caseTypes )
+ {
+ const NaturalSize caseSize = calcSize(caseType);
+ if (!caseSize)
+ {
+ return NaturalSize::makeInvalid();
+ }
+ size = NaturalSize::calcUnion(size, caseSize);
+ }
+
+ // After we've computed the size required to hold all the
+ // case types, we will allocate space for the tag field.
+
+ // Currently we assume uint32_t on all targets
+ size.append(NaturalSize::makeFromBaseType(BaseType::UInt));
+
+ return size;
+ }
+ else if( auto declRefType = as<DeclRefType>(type) )
+ {
+ if (const auto enumDeclRef = declRefType->declRef.as<EnumDecl>())
+ {
+ Type* tagType = getTagType(m_astBuilder, enumDeclRef);
+ return calcSize(tagType);
+ }
+ else if(const auto structDeclRef = declRefType->declRef.as<StructDecl>())
+ {
+ // Poison the cache whilst we construct
+ m_typeToSize.add(type, NaturalSize::makeInvalid());
+
+ // Initialize empty
+ NaturalSize size = NaturalSize::makeEmpty();
+
+ for (auto inherited : structDeclRef.getDecl()->getMembersOfType<InheritanceDecl>())
+ {
+ // Look for a struct type that it inherits from
+ if (auto inheritedDeclRef = as<DeclRefType>(inherited->base.type))
+ {
+ if (auto parentDecl = inheritedDeclRef->declRef.as<StructDecl>())
+ {
+ // We can only inherit from one thing
+ size = calcSize(inherited->base.type);
+ if (!size)
+ {
+ return size;
+ }
+ break;
+ }
+ }
+ }
+
+ // Accumulate over all of the fields
+ for (auto field : structDeclRef.getDecl()->getFields())
+ {
+ const auto fieldSize = calcSize(field->getType());
+ if (!fieldSize)
+ {
+ return NaturalSize::makeInvalid();
+ }
+ size.append(fieldSize);
+ }
+
+ // Set the cached result to the size.
+ m_typeToSize.set(type, size);
+
+ return size;
+ }
+ else if (const auto typeDef = declRefType->declRef.as<TypeDefDecl>())
+ {
+ return calcSize(typeDef.getDecl()->type);
+ }
+ }
+
+ return NaturalSize::makeInvalid();
+}
+
+} // namespace Slang
diff --git a/source/slang/slang-ast-natural-layout.h b/source/slang/slang-ast-natural-layout.h
new file mode 100644
index 000000000..4a165973d
--- /dev/null
+++ b/source/slang/slang-ast-natural-layout.h
@@ -0,0 +1,103 @@
+#ifndef SLANG_AST_NATURAL_LAYOUT_H
+#define SLANG_AST_NATURAL_LAYOUT_H
+
+#include "slang-ast-base.h"
+
+namespace Slang
+{
+
+struct NaturalSize
+{
+ typedef NaturalSize ThisType;
+
+ // We are going to use 0 as invalid for alignment. This has a few nice propeties
+ //
+ // * Will naturally produce 0 size when used with `calcAligned` operation
+ // * Is fast to test
+ // * Is easy to make a fast 'max' such that a max with invalid always returns `invalid`
+ //
+ // We also desire that when invalid the `size` member is 0.
+ // This is so that equality testing doesn't require anything special.
+ SLANG_FORCE_INLINE static Count calcAligned(Count size, Count alignment) { return (size + alignment - 1) & ~(alignment - 1); }
+ // Use to get the max of two alignments. Uses some maths such that `invalid` is always max
+ SLANG_FORCE_INLINE static Count maxAlignment(Count a, Count b) { return (UCount(a) - 1) > (UCount(b) - 1) ? a : b; }
+
+ /// Given two sizes, returns a result that can hold the union.
+ static NaturalSize calcUnion(NaturalSize a, NaturalSize b);
+
+ /// Value chosen such that normal combining operations produce an invalid result
+ /// as typically a max.
+ static const Count kInvalidAlignment = 0;
+
+ /// Get the stride, which is equivalent to the size aligned
+ SLANG_FORCE_INLINE Count getStride() const { return calcAligned(size, alignment); }
+
+ /// Append rhs to this.
+ /// If rhs is invalid or this is the result will also be invalid
+ void append(const ThisType& rhs)
+ {
+ const auto newAlignment = maxAlignment(alignment, rhs.alignment);
+
+ // If the new alignment is valid we calculate the size, else it's 0
+ size = (newAlignment != kInvalidAlignment) ?
+ (calcAligned(size, rhs.alignment) + rhs.size) :
+ 0;
+
+ // Set the new alignment
+ alignment = newAlignment;
+ }
+
+ SLANG_FORCE_INLINE bool isInvalid() const { return alignment == kInvalidAlignment; }
+ SLANG_FORCE_INLINE bool isValid() const { return !isInvalid(); }
+
+ bool operator==(const ThisType& rhs) const { return size == rhs.size && alignment == rhs.alignment; }
+ bool operator!=(const ThisType& rhs) const { return !(*this == rhs); }
+
+ /// Converts to bool to make testing convenient
+ operator bool() const { return isValid(); }
+
+ /// An empty size. It consumes 0 bytes and has the lowest alignment (1)
+ static ThisType makeEmpty() { return ThisType{ 0, 1 }; }
+ /// Make an invalid size.
+ static ThisType makeInvalid() { return ThisType{ 0, kInvalidAlignment }; }
+ /// Make a size with an amount of bytes and the alignment
+ static ThisType make(Count size, Count alignment) { return ThisType{size, alignment}; }
+
+ /// Given a base type returns it's size
+ static ThisType makeFromBaseType(BaseType baseType);
+
+ /// Multiply by a count.
+ /// Will return invalid if count < 0 or this is already invalid
+ ThisType operator*(Count count) const;
+
+ Count size;
+ Count alignment;
+};
+
+struct ASTNaturalLayoutContext
+{
+ /// Given a type returns it's natural size.
+ /// Returns invalid size if types size could not be calculated
+ NaturalSize calcSize(Type* type);
+
+ /// Ctor
+ ASTNaturalLayoutContext(ASTBuilder* astBuilder, DiagnosticSink* sink = nullptr);
+
+protected:
+
+ /// Gets a count (positivie integer including 0).
+ /// <0 indicates error
+ Count _getCount(IntVal* intVal);
+
+ /// The main implementation, assumes outer `calcSize` will perform caching
+ NaturalSize _calcSizeImpl(Type* type);
+
+ Dictionary<Type*, NaturalSize> m_typeToSize;
+
+ ASTBuilder* m_astBuilder;
+ DiagnosticSink* m_sink;
+};
+
+} // namespace Slang
+
+#endif
diff --git a/source/slang/slang-ast-val.cpp b/source/slang/slang-ast-val.cpp
index d8886f05b..21f876048 100644
--- a/source/slang/slang-ast-val.cpp
+++ b/source/slang/slang-ast-val.cpp
@@ -1400,13 +1400,6 @@ HashCode FuncCallIntVal::_getHashCodeOverride()
return result;
}
-static bool nameIs(Name* name, const char* val)
-{
- if (name && name->text.getUnownedSlice() == val)
- return true;
- return false;
-}
-
Val* FuncCallIntVal::tryFoldImpl(ASTBuilder* astBuilder, Type* resultType, DeclRef<Decl> newFuncDecl, List<IntVal*>& newArgs, DiagnosticSink* sink)
{
// Are all args const now?
@@ -1428,29 +1421,24 @@ Val* FuncCallIntVal::tryFoldImpl(ASTBuilder* astBuilder, Type* resultType, DeclR
{
// Evaluate the function.
auto opName = newFuncDecl.getName();
+ SLANG_ASSERT(opName);
+
+ const auto opNameSlice = opName->text.getUnownedSlice();
+
IntegerLiteralValue resultValue = 0;
- if (nameIs(opName, "=="))
- {
- resultValue = constArgs[0]->value / constArgs[1]->value;
- }
+
+ // Define convenience macros.
+ // The last macro used in the list *must* be
+ // TERMINATING_CASE, as this handles the closing else, and matches if nothing else does.
+
#define BINARY_OPERATOR_CASE(op) \
- else if (nameIs(opName, #op)) \
+ if (opNameSlice == toSlice(#op)) \
{ \
resultValue = constArgs[0]->value op constArgs[1]->value; \
- }
- BINARY_OPERATOR_CASE(>=)
- BINARY_OPERATOR_CASE(<=)
- BINARY_OPERATOR_CASE(>)
- BINARY_OPERATOR_CASE(<)
- BINARY_OPERATOR_CASE(!=)
- BINARY_OPERATOR_CASE(<<)
- BINARY_OPERATOR_CASE(>>)
- BINARY_OPERATOR_CASE(&)
- BINARY_OPERATOR_CASE(|)
- BINARY_OPERATOR_CASE(^)
-#undef BINARY_OPERATOR_CASE
+ } else
+
#define DIV_OPERATOR_CASE(op) \
- else if (nameIs(opName, #op)) \
+ if (opNameSlice == toSlice(#op)) \
{ \
if (constArgs[1]->value == 0) \
{ \
@@ -1459,35 +1447,56 @@ Val* FuncCallIntVal::tryFoldImpl(ASTBuilder* astBuilder, Type* resultType, DeclR
return nullptr; \
} \
resultValue = constArgs[0]->value op constArgs[1]->value; \
- }
- DIV_OPERATOR_CASE(/)
- DIV_OPERATOR_CASE(%)
-#undef DIV_OPERATOR_CASE
+ } else
+
#define LOGICAL_OPERATOR_CASE(op) \
- else if (nameIs(opName, #op)) \
+ if (opNameSlice == toSlice(#op)) \
{ \
resultValue = (((constArgs[0]->value!=0) op (constArgs[1]->value!=0)) ? 1 : 0); \
+ } else
+
+
+#define SPECIAL_OPERATOR_CASE(op, IF_MATCH) \
+ if (opNameSlice == toSlice(op)) \
+ { \
+ IF_MATCH \
+ } else
+
+#define TERMINATING_CASE(MATCH) \
+ { \
+ MATCH \
}
+
+ // Handle the cases using the macros
+ BINARY_OPERATOR_CASE(>=)
+ BINARY_OPERATOR_CASE(<=)
+ BINARY_OPERATOR_CASE(>)
+ BINARY_OPERATOR_CASE(<)
+ BINARY_OPERATOR_CASE(!=)
+ BINARY_OPERATOR_CASE(==)
+ BINARY_OPERATOR_CASE(<<)
+ BINARY_OPERATOR_CASE(>>)
+ BINARY_OPERATOR_CASE(&)
+ BINARY_OPERATOR_CASE(|)
+ BINARY_OPERATOR_CASE(^)
+ DIV_OPERATOR_CASE(/)
+ DIV_OPERATOR_CASE(%)
LOGICAL_OPERATOR_CASE(&&)
- LOGICAL_OPERATOR_CASE(|| )
-#undef LOGICAL_OPERATOR_CASE
- else if (nameIs(opName, "!"))
- {
- resultValue = ((constArgs[0]->value != 0) ? 1 : 0);
- }
- else if (nameIs(opName, "~"))
- {
- resultValue = ~constArgs[0]->value;
- }
- else if (nameIs(opName, "?:"))
- {
- resultValue = constArgs[0]->value != 0 ? constArgs[1]->value : constArgs[2]->value;
- }
- else
- {
- SLANG_UNREACHABLE("constant folding of FuncCallIntVal");
- }
+ LOGICAL_OPERATOR_CASE(||)
+ // Special cases need their "operator" names quoted.
+ SPECIAL_OPERATOR_CASE("!", resultValue = ((constArgs[0]->value != 0) ? 1 : 0);)
+ SPECIAL_OPERATOR_CASE("~", resultValue = ~constArgs[0]->value;)
+ SPECIAL_OPERATOR_CASE("?:", resultValue = constArgs[0]->value != 0 ? constArgs[1]->value : constArgs[2]->value;)
+ TERMINATING_CASE(SLANG_UNREACHABLE("constant folding of FuncCallIntVal");)
+
return astBuilder->getIntVal(resultType, resultValue);
+
+ // The macros for the cases are no longer needed so undef them all.
+#undef BINARY_OPERATOR_CASE
+#undef DIV_OPERATOR_CASE
+#undef LOGICAL_OPERATOR_CASE
+#undef SPECIAL_OPERATOR_CASE
+#undef TERMINATING_CASE
}
return nullptr;
}
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp
index 00ece3628..9dc359a5e 100644
--- a/source/slang/slang-check-expr.cpp
+++ b/source/slang/slang-check-expr.cpp
@@ -11,6 +11,8 @@
//
// * `slang-check-conversion.cpp` is responsible for the logic of handling type conversion/coercion
+#include "slang-ast-natural-layout.h"
+
#include "slang-lookup.h"
#include "slang-ast-print.h"
@@ -1279,8 +1281,6 @@ namespace Slang
return nullptr;
}
-
-
// Let's not constant-fold operations with more than a certain number of arguments, for simplicity
static const int kMaxArgs = 8;
auto argCount = getArgCount(invokeExpr);
@@ -1533,6 +1533,7 @@ namespace Slang
SubstExpr<Expr> expr,
ConstantFoldingCircularityInfo* circularityInfo)
{
+
// Unwrap any "identity" expressions
while (auto parenExpr = expr.as<ParenExpr>())
{
@@ -1629,7 +1630,23 @@ namespace Slang
if (val)
return val;
}
+ else if (auto sizeOfLikeExpr = as<SizeOfLikeExpr>(expr.getExpr()))
+ {
+ ASTNaturalLayoutContext context(getASTBuilder(), nullptr);
+ const auto size = context.calcSize(sizeOfLikeExpr->sizedType);
+ if (!size)
+ {
+ return nullptr;
+ }
+ auto value = as<AlignOfExpr>(sizeOfLikeExpr) ?
+ size.alignment :
+ size.size;
+
+ // We can return as an IntVal
+ return getASTBuilder()->getIntVal(expr.getExpr()->type, value);
+ }
+
return nullptr;
}
@@ -2145,6 +2162,7 @@ namespace Slang
return rs;
}
+
Expr* SemanticsExprVisitor::visitSelectExpr(SelectExpr* expr)
{
auto result = visitInvokeExpr(expr);
@@ -2695,6 +2713,67 @@ namespace Slang
return expr;
}
+ static bool _isSizeOfType(Type* type)
+ {
+ if (!type)
+ {
+ return false;
+ }
+
+ if (as<ArithmeticExpressionType>(type) ||
+ as<ArrayExpressionType>(type) ||
+ as<PtrTypeBase>(type) ||
+ as<TupleType>(type) ||
+ as<GenericDeclRefType>(type))
+ {
+ return true;
+ }
+
+ if (as<DeclRefType>(type))
+ {
+ return true;
+ }
+
+ return false;
+ }
+
+ Expr* SemanticsExprVisitor::visitSizeOfLikeExpr(SizeOfLikeExpr* sizeOfLikeExpr)
+ {
+ auto valueExpr = dispatch(sizeOfLikeExpr->value);
+
+ Type* type = nullptr;
+
+ if (as<TypeType>(valueExpr->type))
+ {
+ TypeExp typeExp;
+ typeExp.exp = valueExpr;
+
+ auto properTypeExpr = CoerceToProperType(typeExp);
+
+ type = properTypeExpr.type;
+ }
+ else
+ {
+ // Is this a proper type?
+ TypeExp typeExp(valueExpr->type);
+ TypeExp properType = tryCoerceToProperType(typeExp);
+
+ type = properType.type;
+ }
+
+ if (!_isSizeOfType(type))
+ {
+ getSink()->diagnose(sizeOfLikeExpr, Diagnostics::sizeOfArgumentIsInvalid);
+
+ sizeOfLikeExpr->type = m_astBuilder->getErrorType();
+ return sizeOfLikeExpr;
+ }
+
+ sizeOfLikeExpr->sizedType = type;
+
+ return sizeOfLikeExpr;
+ }
+
Expr* SemanticsExprVisitor::visitTypeCastExpr(TypeCastExpr * expr)
{
// Check the term we are applying first
diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h
index ac4d51549..26c96a72e 100644
--- a/source/slang/slang-check-impl.h
+++ b/source/slang/slang-check-impl.h
@@ -1933,6 +1933,8 @@ namespace Slang
: SemanticsVisitor(outer)
{}
+ Expr* visitSizeOfLikeExpr(SizeOfLikeExpr* expr);
+
Expr* visitIncompleteExpr(IncompleteExpr* expr);
Expr* visitBoolLiteralExpr(BoolLiteralExpr* expr);
Expr* visitNullPtrLiteralExpr(NullPtrLiteralExpr* expr);
diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h
index d38cde4e6..0fd962614 100644
--- a/source/slang/slang-diagnostic-defs.h
+++ b/source/slang/slang-diagnostic-defs.h
@@ -318,6 +318,8 @@ DIAGNOSTIC(30096, Error, differentialTypeShouldServeAsItsOwnDifferentialType, "t
DIAGNOSTIC(30097, Error, functionNotMarkedAsDifferentiable, "function '$0' is not marked as $1-differentiable.")
DIAGNOSTIC(30098, Error, nonStaticMemberFunctionNotAllowedAsDiffOperand, "non-static function reference '$0' is not allowed here.")
+DIAGNOSTIC(30099, Error, sizeOfArgumentIsInvalid, "argument to sizeof is invalid")
+
DIAGNOSTIC(-1, Note, noteSeeUseOfDifferentialType, "see use of '$0' as Differential of '$1'.")
// Attributes
@@ -625,6 +627,9 @@ DIAGNOSTIC(41023, Error, getStringHashMustBeOnStringLiteral, "getStringHash can
DIAGNOSTIC(41901, Error, unsupportedUseOfLValueForAutoDiff, "unsupported use of L-value for auto differentiation.")
DIAGNOSTIC(41902, Error, cannotDifferentiateDynamicallyIndexedData, "cannot auto-differentiate mixed read/write access to dynamically indexed data in '$0'.")
+DIAGNOSTIC(41903, Error, unableToSizeOf, "sizeof could not be performed for type '$0'.")
+DIAGNOSTIC(41904, Error, unableToAlignOf, "alignof could not be performed for type '$0'.")
+
DIAGNOSTIC(42001, Error, invalidUseOfTorchTensorTypeInDeviceFunc, "invalid use of TorchTensor type in device/kernel functions. use `TensorView` instead.")
//
diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp
index 827c69e50..4c050ffcc 100644
--- a/source/slang/slang-emit.cpp
+++ b/source/slang/slang-emit.cpp
@@ -36,6 +36,7 @@
#include "slang-ir-lower-optional-type.h"
#include "slang-ir-lower-bit-cast.h"
#include "slang-ir-lower-l-value-cast.h"
+#include "slang-ir-lower-size-of.h"
#include "slang-ir-lower-reinterpret.h"
#include "slang-ir-loop-unroll.h"
#include "slang-ir-metadata.h"
@@ -867,6 +868,10 @@ Result linkAndOptimizeIR(
legalizeUniformBufferLoad(irModule);
}
+ // Lower sizeof/alignof
+
+ lowerSizeOfLike(targetRequest, irModule, sink);
+
// Lower all the LValue implict casts (used for out/inout/ref scenarios)
lowerLValueCast(targetRequest, irModule);
diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h
index de3735a55..636264a6d 100644
--- a/source/slang/slang-ir-inst-defs.h
+++ b/source/slang/slang-ir-inst-defs.h
@@ -934,6 +934,9 @@ INST(CastPtrToInt, CastPtrToInt, 1, 0)
INST(CastIntToPtr, CastIntToPtr, 1, 0)
INST(CastToVoid, castToVoid, 1, 0)
+INST(SizeOf, sizeOf, 1, 0)
+INST(AlignOf, alignOf, 1, 0)
+
INST(IsType, IsType, 3, 0)
INST(ForwardDifferentiate, ForwardDifferentiate, 1, 0)
diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h
index f2c00f406..18a0677b0 100644
--- a/source/slang/slang-ir-insts.h
+++ b/source/slang/slang-ir-insts.h
@@ -3622,6 +3622,12 @@ public:
IRType* type,
IRInst* val);
+ IRInst* emitSizeOf(
+ IRInst* sizedType);
+
+ IRInst* emitAlignOf(
+ IRInst* sizedType);
+
IRInst* emitCastPtrToBool(IRInst* val);
IRGlobalConstant* emitGlobalConstant(
diff --git a/source/slang/slang-ir-lower-size-of.cpp b/source/slang/slang-ir-lower-size-of.cpp
new file mode 100644
index 000000000..a8b599031
--- /dev/null
+++ b/source/slang/slang-ir-lower-size-of.cpp
@@ -0,0 +1,106 @@
+#include "slang-ir-lower-size-of.h"
+
+#include "slang-ir.h"
+#include "slang-ir-insts.h"
+
+#include "slang-ir-layout.h"
+
+namespace Slang
+{
+
+struct SizeOfLikeLoweringContext
+{
+ void _addToWorkList(IRInst* inst)
+ {
+ if (!findOuterGeneric(inst) && !m_workList.contains(inst))
+ {
+ m_workList.add(inst);
+ }
+ }
+
+ void _processInst(IRInst* inst)
+ {
+ switch (inst->getOp())
+ {
+ case kIROp_AlignOf:
+ case kIROp_SizeOf:
+ _processSizeOfLike(inst);
+ break;
+ default:
+ break;
+ }
+ }
+
+ void processModule()
+ {
+ _addToWorkList(m_module->getModuleInst());
+
+ while (m_workList.getCount() != 0)
+ {
+ IRInst* inst = m_workList.getLast();
+ m_workList.removeLast();
+
+ _processInst(inst);
+
+ for (auto child = inst->getLastChild(); child; child = child->getPrevInst())
+ {
+ _addToWorkList(child);
+ }
+ }
+ }
+
+ void _processSizeOfLike(IRInst* sizeOfLikeInst)
+ {
+ auto typeOperand = as<IRType>(sizeOfLikeInst->getOperand(0));
+
+ IRSizeAndAlignment sizeAndAlignment;
+
+ if (SLANG_FAILED(getNaturalSizeAndAlignment(m_targetReq, typeOperand, &sizeAndAlignment)))
+ {
+ // Output a diagnostic failure
+ if(sizeOfLikeInst->getOp() == kIROp_AlignOf)
+ {
+ m_sink->diagnose(sizeOfLikeInst, Diagnostics::unableToAlignOf, typeOperand);
+ }
+ else
+ {
+ m_sink->diagnose(sizeOfLikeInst, Diagnostics::unableToSizeOf, typeOperand);
+ }
+
+ return;
+ }
+
+ IRBuilder builder(m_module);
+
+ const auto value = (sizeOfLikeInst->getOp() == kIROp_AlignOf) ?
+ sizeAndAlignment.alignment :
+ sizeAndAlignment.size;
+
+ auto valueInst = builder.getIntValue(sizeOfLikeInst->getDataType(), value);
+
+ // Replace all uses of sizeOfLikeInst with the value
+ sizeOfLikeInst->replaceUsesWith(valueInst);
+ // We don't need the instruction any more
+ sizeOfLikeInst->removeAndDeallocate();
+ }
+
+ SizeOfLikeLoweringContext(TargetRequest* targetReq, IRModule* module, DiagnosticSink* sink):
+ m_module(module),
+ m_targetReq(targetReq),
+ m_sink(sink)
+ {
+ }
+
+ TargetRequest* m_targetReq;
+ DiagnosticSink* m_sink;
+ IRModule* m_module;
+ OrderedHashSet<IRInst*> m_workList;
+};
+
+void lowerSizeOfLike(TargetRequest* targetReq, IRModule* module, DiagnosticSink* sink)
+{
+ SizeOfLikeLoweringContext context(targetReq, module, sink);
+ context.processModule();
+}
+
+} // namespace Slang
diff --git a/source/slang/slang-ir-lower-size-of.h b/source/slang/slang-ir-lower-size-of.h
new file mode 100644
index 000000000..4205aa2f0
--- /dev/null
+++ b/source/slang/slang-ir-lower-size-of.h
@@ -0,0 +1,17 @@
+#ifndef SLANG_IR_LOWER_SIZE_OF_H
+#define SLANG_IR_LOWER_SIZE_OF_H
+
+// This defines an IR pass that lowers sizeof/alignof.
+
+namespace Slang
+{
+
+struct IRModule;
+class TargetRequest;
+class DiagnosticSink;
+
+void lowerSizeOfLike(TargetRequest* target, IRModule* module, DiagnosticSink* sink);
+
+} // namespace Slang
+
+#endif
diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp
index 789349b4c..74679de96 100644
--- a/source/slang/slang-ir.cpp
+++ b/source/slang/slang-ir.cpp
@@ -5268,6 +5268,31 @@ namespace Slang
return inst;
}
+
+ IRInst* IRBuilder::emitSizeOf(
+ IRInst* sizedType)
+ {
+ auto inst = createInst<IRInst>(
+ this,
+ kIROp_SizeOf,
+ getUIntType(),
+ sizedType);
+ addInst(inst);
+ return inst;
+ }
+
+ IRInst* IRBuilder::emitAlignOf(
+ IRInst* sizedType)
+ {
+ auto inst = createInst<IRInst>(
+ this,
+ kIROp_AlignOf,
+ getUIntType(),
+ sizedType);
+ addInst(inst);
+ return inst;
+ }
+
IRInst* IRBuilder::emitBitCast(
IRType* type,
IRInst* val)
diff --git a/source/slang/slang-language-server-completion.cpp b/source/slang/slang-language-server-completion.cpp
index d95f7d4e5..6bccca8d3 100644
--- a/source/slang/slang-language-server-completion.cpp
+++ b/source/slang/slang-language-server-completion.cpp
@@ -34,7 +34,8 @@ static const char* kStmtKeywords[] = {
"extension", "associatedtype", "this", "namespace", "This", "using",
"__generic", "__exported", "import", "enum", "break", "continue",
"discard", "defer", "cbuffer", "tbuffer", "func", "is",
- "as", "nullptr", "none", "true", "false", "functype"};
+ "as", "nullptr", "none", "true", "false", "functype",
+ "sizeof", "alignof"};
static const char* hlslSemanticNames[] = {
"register",
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index 85b17dafb..cb2c9129a 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -32,6 +32,9 @@
#include "slang-type-layout.h"
#include "slang-visitor.h"
+// Natural layout
+#include "slang-ast-natural-layout.h"
+
namespace Slang
{
@@ -3377,6 +3380,43 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo>
return LoweredValInfo::simple(arrayType->getElementCount());
}
+ LoweredValInfo visitSizeOfLikeExpr(SizeOfLikeExpr* sizeOfLikeExpr)
+ {
+ // Lets try and lower to a constant
+ ASTNaturalLayoutContext naturalLayoutContext(getASTBuilder(), nullptr);
+
+ const auto size = naturalLayoutContext.calcSize(sizeOfLikeExpr->sizedType);
+
+ auto builder = getBuilder();
+
+ if (!size)
+ {
+ auto sizedType = lowerType(context, sizeOfLikeExpr->sizedType);
+
+ // We can create an inst
+
+ IRInst* inst = nullptr;
+
+ if (as<AlignOfExpr>(sizeOfLikeExpr))
+ {
+ inst = builder->emitAlignOf(sizedType);
+ }
+ else
+ {
+ inst = builder->emitSizeOf(sizedType);
+ }
+
+ return LoweredValInfo::simple(inst);
+ }
+
+ const auto value =
+ as<SizeOfExpr>(sizeOfLikeExpr) ?
+ size.size :
+ size.alignment;
+
+ return LoweredValInfo::simple(getBuilder()->getIntValue(builder->getUIntType(), value));
+ }
+
LoweredValInfo visitOverloadedExpr(OverloadedExpr* /*expr*/)
{
SLANG_UNEXPECTED("overloaded expressions should not occur in checked AST");
diff --git a/source/slang/slang-parser.cpp b/source/slang/slang-parser.cpp
index dcf835234..f790cd4d2 100644
--- a/source/slang/slang-parser.cpp
+++ b/source/slang/slang-parser.cpp
@@ -4940,9 +4940,6 @@ namespace Slang
return Associativity::Left;
}
-
-
-
Precedence GetOpLevel(Parser* parser, const Token& token)
{
switch(token.type)
@@ -4998,12 +4995,15 @@ namespace Slang
case TokenType::OpMod:
return Precedence::Multiplicative;
default:
- if (token.getContent() == "is" || token.getContent() == "as")
+ {
+ const auto content = token.getContent();
+ if (content == "is" || content == "as")
{
return Precedence::RelationalComparison;
}
return Precedence::Invalid;
}
+ }
}
static Expr* parseOperator(Parser* parser)
@@ -5060,7 +5060,9 @@ namespace Slang
// Special case the "is" and "as" operators.
if (opToken.type == TokenType::Identifier)
{
- if (opToken.getContent() == "is")
+ const auto content = opToken.getContent();
+
+ if (content == "is")
{
auto isExpr = parser->astBuilder->create<IsTypeExpr>();
isExpr->value = expr;
@@ -5070,7 +5072,7 @@ namespace Slang
expr = isExpr;
continue;
}
- else if (opToken.getContent() == "as")
+ else if (content == "as")
{
auto asExpr = parser->astBuilder->create<AsTypeExpr>();
asExpr->value = expr;
@@ -5246,6 +5248,40 @@ namespace Slang
return parser->astBuilder->create<NoneLiteralExpr>();
}
+ static NodeBase* parseSizeOfExpr(Parser* parser, void* /*userData*/)
+ {
+ // We could have a type or a variable or an expression
+ SizeOfExpr* sizeOfExpr = parser->astBuilder->create<SizeOfExpr>();
+
+ parser->ReadMatchingToken(TokenType::LParent);
+
+ // The return type is always a UInt
+ sizeOfExpr->type = parser->astBuilder->getUIntType();
+
+ sizeOfExpr->value = parser->ParseExpression();
+
+ parser->ReadMatchingToken(TokenType::RParent);
+
+ return sizeOfExpr;
+ }
+
+ static NodeBase* parseAlignOfExpr(Parser* parser, void* /*userData*/)
+ {
+ // We could have a type or a variable or an expression
+ AlignOfExpr* alignOfExpr = parser->astBuilder->create<AlignOfExpr>();
+
+ parser->ReadMatchingToken(TokenType::LParent);
+
+ // The return type is always a UInt
+ alignOfExpr->type = parser->astBuilder->getUIntType();
+
+ alignOfExpr->value = parser->ParseExpression();
+
+ parser->ReadMatchingToken(TokenType::RParent);
+
+ return alignOfExpr;
+ }
+
static NodeBase* parseTryExpr(Parser* parser, void* /*userData*/)
{
auto tryExpr = parser->astBuilder->create<TryExpr>();
@@ -6040,7 +6076,7 @@ namespace Slang
}
break;
- // Call oepration `f(x)`
+ // Call operation `f(x)`
case TokenType::LParent:
{
InvokeExpr* invokeExpr = parser->astBuilder->create<InvokeExpr>();
@@ -6144,8 +6180,11 @@ namespace Slang
auto tokenType = peekTokenType(parser);
switch( tokenType )
{
- default:
- if (parser->LookAheadToken("new"))
+ case TokenType::Identifier:
+ {
+ auto identifierToken = peekToken(parser);
+ const auto identifierTokenContent = identifierToken.getContent();
+ if (identifierTokenContent == toSlice("new"))
{
NewExpr* newExpr = parser->astBuilder->create<NewExpr>();
parser->FillPosition(newExpr);
@@ -6168,7 +6207,14 @@ namespace Slang
}
return newExpr;
}
+
+
return parsePostfixExpr(parser);
+ }
+ default:
+ {
+ return parsePostfixExpr(parser);
+ }
case TokenType::OpNot:
case TokenType::OpInc:
case TokenType::OpDec:
@@ -6873,7 +6919,9 @@ namespace Slang
_makeParseExpr("__bwd_diff", parseBackwardDifferentiate),
_makeParseExpr("fwd_diff", parseForwardDifferentiate),
_makeParseExpr("bwd_diff", parseBackwardDifferentiate),
- _makeParseExpr("__dispatch_kernel", parseDispatchKernel)
+ _makeParseExpr("__dispatch_kernel", parseDispatchKernel),
+ _makeParseExpr("sizeof", parseSizeOfExpr),
+ _makeParseExpr("alignof", parseAlignOfExpr),
};
ConstArrayView<SyntaxParseInfo> getSyntaxParseInfos()