summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorJulius Ikkala <julius.ikkala@gmail.com>2025-05-22 22:10:42 +0300
committerGitHub <noreply@github.com>2025-05-22 22:10:42 +0300
commitce238dd878038bf857968931773cc9b10f3b225d (patch)
tree2e29a5191fff5eb85a5a7895fd68b7b285bcb198
parent27c6e9b01f7386263bde90e16812be46327015c2 (diff)
Make sizeof(T) & alignof(T) of generic types work as compile-time constants (#7213)
* Make sizeof(generic) work as compile-time constant * format code --------- Co-authored-by: slangbot <186143334+slangbot@users.noreply.github.com>
-rw-r--r--source/slang/slang-ast-val.cpp101
-rw-r--r--source/slang/slang-ast-val.h52
-rw-r--r--source/slang/slang-check-expr.cpp34
-rw-r--r--source/slang/slang-lower-to-ir.cpp16
-rw-r--r--source/slang/slang-mangle.cpp15
-rw-r--r--tests/hlsl-intrinsic/size-of/size-of-compile-time.slang41
6 files changed, 237 insertions, 22 deletions
diff --git a/source/slang/slang-ast-val.cpp b/source/slang/slang-ast-val.cpp
index 1cdca0440..92e170515 100644
--- a/source/slang/slang-ast-val.cpp
+++ b/source/slang/slang-ast-val.cpp
@@ -3,6 +3,7 @@
#include "slang-ast-builder.h"
#include "slang-ast-dispatch.h"
+#include "slang-ast-natural-layout.h"
#include "slang-check-impl.h"
#include "slang-diagnostics.h"
#include "slang-mangle.h"
@@ -1742,6 +1743,106 @@ Val* FuncCallIntVal::_substituteImplOverride(
return this;
}
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! SizeOfIntVal !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
+
+void SizeOfIntVal::_toTextOverride(StringBuilder& out)
+{
+ out << "sizeof(";
+ getTypeArg()->toText(out);
+ out << ")";
+}
+
+Val* SizeOfIntVal::tryFoldOrNull(ASTBuilder* astBuilder, Type* intType, Type* newType)
+{
+ ASTNaturalLayoutContext context(astBuilder, nullptr);
+ const auto size = context.calcSize(newType);
+
+ if (!size)
+ return nullptr;
+
+ return astBuilder->getIntVal(intType, size.size);
+}
+
+Val* SizeOfIntVal::tryFold(ASTBuilder* astBuilder, Type* intType, Type* newType)
+{
+ if (auto result = tryFoldOrNull(astBuilder, intType, newType))
+ return result;
+ auto result = astBuilder->getOrCreate<SizeOfIntVal>(intType, newType);
+ return result;
+}
+
+Val* SizeOfIntVal::_substituteImplOverride(
+ ASTBuilder* astBuilder,
+ SubstitutionSet subst,
+ int* ioDiff)
+{
+ int diff = 0;
+ auto newType = as<Type>(getTypeArg()->substituteImpl(astBuilder, subst, &diff));
+ if (!diff)
+ return this;
+
+ (*ioDiff)++;
+ return tryFold(astBuilder, getType(), newType);
+}
+
+Val* SizeOfIntVal::_resolveImplOverride()
+{
+ auto resolvedTypeArg = getTypeArg()->resolve();
+ if (resolvedTypeArg == getTypeArg())
+ return this;
+ return tryFold(getCurrentASTBuilder(), getType(), as<Type>(resolvedTypeArg));
+}
+
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! AlignOfIntVal !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
+
+void AlignOfIntVal::_toTextOverride(StringBuilder& out)
+{
+ out << "alignof(";
+ getTypeArg()->toText(out);
+ out << ")";
+}
+
+Val* AlignOfIntVal::tryFoldOrNull(ASTBuilder* astBuilder, Type* intType, Type* newType)
+{
+ ASTNaturalLayoutContext context(astBuilder, nullptr);
+ const auto size = context.calcSize(newType);
+
+ if (!size)
+ return nullptr;
+
+ return astBuilder->getIntVal(intType, size.alignment);
+}
+
+Val* AlignOfIntVal::tryFold(ASTBuilder* astBuilder, Type* intType, Type* newType)
+{
+ if (auto result = tryFoldOrNull(astBuilder, intType, newType))
+ return result;
+ auto result = astBuilder->getOrCreate<AlignOfIntVal>(intType, newType);
+ return result;
+}
+
+Val* AlignOfIntVal::_substituteImplOverride(
+ ASTBuilder* astBuilder,
+ SubstitutionSet subst,
+ int* ioDiff)
+{
+ int diff = 0;
+ auto newType = as<Type>(getTypeArg()->substituteImpl(astBuilder, subst, &diff));
+ if (!diff)
+ return this;
+
+ (*ioDiff)++;
+ return tryFold(astBuilder, getType(), newType);
+}
+
+Val* AlignOfIntVal::_resolveImplOverride()
+{
+ auto resolvedTypeArg = getTypeArg()->resolve();
+ if (resolvedTypeArg == getTypeArg())
+ return this;
+ return tryFold(getCurrentASTBuilder(), getType(), as<Type>(resolvedTypeArg));
+}
+
// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! CountOfIntVal !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
void CountOfIntVal::_toTextOverride(StringBuilder& out)
diff --git a/source/slang/slang-ast-val.h b/source/slang/slang-ast-val.h
index 2b4c7ed22..a8b969a94 100644
--- a/source/slang/slang-ast-val.h
+++ b/source/slang/slang-ast-val.h
@@ -255,21 +255,65 @@ class FuncCallIntVal : public IntVal
Val* _linkTimeResolveOverride(Dictionary<String, IntVal*>& map);
};
-FIDDLE()
-class CountOfIntVal : public IntVal
+FIDDLE(abstract)
+class SizeOfLikeIntVal : public IntVal
{
FIDDLE(...)
- CountOfIntVal(Type* inType, Type* typeArg) { setOperands(inType, typeArg); }
+ SizeOfLikeIntVal(Type* inType, Type* typeArg) { setOperands(inType, typeArg); }
Val* getTypeArg() { return getOperand(1); }
+ bool _isLinkTimeValOverride() { return false; }
+};
+
+FIDDLE()
+class SizeOfIntVal : public SizeOfLikeIntVal
+{
+ FIDDLE(...)
+ SizeOfIntVal(Type* inType, Type* typeArg)
+ : SizeOfLikeIntVal(inType, typeArg)
+ {
+ }
+
void _toTextOverride(StringBuilder& out);
Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff);
Val* _resolveImplOverride();
- bool _isLinkTimeValOverride() { return false; }
static Val* tryFoldOrNull(ASTBuilder* astBuilder, Type* intType, Type* newType);
+ static Val* tryFold(ASTBuilder* astBuilder, Type* intType, Type* newType);
+};
+
+FIDDLE()
+class AlignOfIntVal : public SizeOfLikeIntVal
+{
+ FIDDLE(...)
+ AlignOfIntVal(Type* inType, Type* typeArg)
+ : SizeOfLikeIntVal(inType, typeArg)
+ {
+ }
+ void _toTextOverride(StringBuilder& out);
+ Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff);
+ Val* _resolveImplOverride();
+
+ static Val* tryFoldOrNull(ASTBuilder* astBuilder, Type* intType, Type* newType);
+ static Val* tryFold(ASTBuilder* astBuilder, Type* intType, Type* newType);
+};
+
+FIDDLE()
+class CountOfIntVal : public SizeOfLikeIntVal
+{
+ FIDDLE(...)
+ CountOfIntVal(Type* inType, Type* typeArg)
+ : SizeOfLikeIntVal(inType, typeArg)
+ {
+ }
+
+ void _toTextOverride(StringBuilder& out);
+ Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff);
+ Val* _resolveImplOverride();
+
+ static Val* tryFoldOrNull(ASTBuilder* astBuilder, Type* intType, Type* newType);
static Val* tryFold(ASTBuilder* astBuilder, Type* intType, Type* newType);
};
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp
index ad36a7e4a..db507c060 100644
--- a/source/slang/slang-check-expr.cpp
+++ b/source/slang/slang-check-expr.cpp
@@ -2060,13 +2060,25 @@ IntVal* SemanticsVisitor::tryConstantFoldExpr(
}
}
- if (auto countOfExpr = expr.as<CountOfExpr>())
+ if (auto sizeOfLikeExpr = expr.as<SizeOfLikeExpr>())
{
- auto type =
- as<Type>(countOfExpr.getExpr()->sizedType->substitute(m_astBuilder, expr.getSubsts()));
- if (type)
+ auto type = as<Type>(
+ sizeOfLikeExpr.getExpr()->sizedType->substitute(m_astBuilder, expr.getSubsts()));
+
+ if (auto sizeOfExpr = expr.as<SizeOfExpr>())
+ {
+ return as<IntVal>(SizeOfIntVal::tryFold(m_astBuilder, expr.getExpr()->type.type, type));
+ }
+ else if (auto alignOfExpr = expr.as<AlignOfExpr>())
+ {
+ return as<IntVal>(
+ AlignOfIntVal::tryFold(m_astBuilder, expr.getExpr()->type.type, type));
+ }
+ else if (auto countOfExpr = expr.as<CountOfExpr>())
+ {
return as<IntVal>(
CountOfIntVal::tryFold(m_astBuilder, expr.getExpr()->type.type, type));
+ }
}
// it is possible that we are referring to a generic value param
@@ -2159,20 +2171,6 @@ IntVal* SemanticsVisitor::tryConstantFoldExpr(
if (val)
return val;
}
- else if (auto sizeOfLikeExpr = as<SizeOfLikeExpr>(expr.getExpr()))
- {
- ASTNaturalLayoutContext context(getASTBuilder(), nullptr);
- const auto size = context.calcSize(sizeOfLikeExpr->sizedType);
- if (!size)
- {
- return nullptr;
- }
-
- auto value = as<AlignOfExpr>(sizeOfLikeExpr) ? size.alignment : size.size;
-
- // We can return as an IntVal
- return getASTBuilder()->getIntVal(expr.getExpr()->type, value);
- }
else if (auto indexExpr = expr.as<IndexExpr>())
{
return tryFoldIndexExpr(indexExpr.getExpr(), kind, circularityInfo);
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index f8946f5dc..9920075fe 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -1636,6 +1636,22 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower
return LoweredValInfo::simple(resultVal);
}
+ LoweredValInfo visitSizeOfIntVal(SizeOfIntVal* val)
+ {
+ auto irBuilder = getBuilder();
+ auto typeArg = lowerType(context, as<Type>(val->getTypeArg()));
+ auto count = irBuilder->emitSizeOf(typeArg);
+ return LoweredValInfo::simple(count);
+ }
+
+ LoweredValInfo visitAlignOfIntVal(AlignOfIntVal* val)
+ {
+ auto irBuilder = getBuilder();
+ auto typeArg = lowerType(context, as<Type>(val->getTypeArg()));
+ auto count = irBuilder->emitAlignOf(typeArg);
+ return LoweredValInfo::simple(count);
+ }
+
LoweredValInfo visitCountOfIntVal(CountOfIntVal* val)
{
auto irBuilder = getBuilder();
diff --git a/source/slang/slang-mangle.cpp b/source/slang/slang-mangle.cpp
index 056c7accb..ea620ebb2 100644
--- a/source/slang/slang-mangle.cpp
+++ b/source/slang/slang-mangle.cpp
@@ -357,6 +357,21 @@ void emitVal(ManglingContext* context, Val* val)
emitVal(context, lookupIntVal->getWitness());
emitName(context, lookupIntVal->getKey()->getName());
}
+ else if (auto sizeOfIntVal = dynamicCast<SizeOfIntVal>(val))
+ {
+ emitRaw(context, "KSO");
+ emitVal(context, sizeOfIntVal->getTypeArg());
+ }
+ else if (auto alignOfIntVal = dynamicCast<AlignOfIntVal>(val))
+ {
+ emitRaw(context, "KAO");
+ emitVal(context, alignOfIntVal->getTypeArg());
+ }
+ else if (auto countOfIntVal = dynamicCast<CountOfIntVal>(val))
+ {
+ emitRaw(context, "KCO");
+ emitVal(context, countOfIntVal->getTypeArg());
+ }
else if (const auto polynomialIntVal = dynamicCast<PolynomialIntVal>(val))
{
emitRaw(context, "KX");
diff --git a/tests/hlsl-intrinsic/size-of/size-of-compile-time.slang b/tests/hlsl-intrinsic/size-of/size-of-compile-time.slang
new file mode 100644
index 000000000..64e2f640b
--- /dev/null
+++ b/tests/hlsl-intrinsic/size-of/size-of-compile-time.slang
@@ -0,0 +1,41 @@
+//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-cpu -compute -shaderobj
+//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-slang -compute -shaderobj
+//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-slang -compute -dx12 -shaderobj
+//TEST(compute, vulkan):COMPARE_COMPUTE(filecheck-buffer=CHECK):-vk -compute -shaderobj
+//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-cuda -compute -shaderobj
+
+//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name outputBuffer
+
+RWStructuredBuffer<int> outputBuffer;
+
+struct Thing<T>
+{
+ uint8_t data[sizeof(T)];
+};
+
+struct AlignThing<T>
+{
+ uint8_t data[alignof(T)];
+};
+
+[numthreads(4, 1, 1)]
+void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ const int idx = asint(dispatchThreadID.x);
+
+ int size = 0;
+
+ switch (idx)
+ {
+ case 0: size = sizeof(Thing<float>); break;
+ case 1: size = sizeof(Thing<float3>); break;
+ case 2: size = sizeof(AlignThing<float>); break;
+ case 3: size = sizeof(AlignThing<float3>); break;
+ }
+
+ // CHECK: 4
+ // CHECK-NEXT: C
+ // CHECK: 4
+ // CHECK-NEXT: 4
+ outputBuffer[idx] = size;
+}