diff options
| author | Julius Ikkala <julius.ikkala@gmail.com> | 2025-05-22 22:10:42 +0300 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-05-22 22:10:42 +0300 |
| commit | ce238dd878038bf857968931773cc9b10f3b225d (patch) | |
| tree | 2e29a5191fff5eb85a5a7895fd68b7b285bcb198 | |
| parent | 27c6e9b01f7386263bde90e16812be46327015c2 (diff) | |
Make sizeof(T) & alignof(T) of generic types work as compile-time constants (#7213)
* Make sizeof(generic) work as compile-time constant
* format code
---------
Co-authored-by: slangbot <186143334+slangbot@users.noreply.github.com>
| -rw-r--r-- | source/slang/slang-ast-val.cpp | 101 | ||||
| -rw-r--r-- | source/slang/slang-ast-val.h | 52 | ||||
| -rw-r--r-- | source/slang/slang-check-expr.cpp | 34 | ||||
| -rw-r--r-- | source/slang/slang-lower-to-ir.cpp | 16 | ||||
| -rw-r--r-- | source/slang/slang-mangle.cpp | 15 | ||||
| -rw-r--r-- | tests/hlsl-intrinsic/size-of/size-of-compile-time.slang | 41 |
6 files changed, 237 insertions, 22 deletions
diff --git a/source/slang/slang-ast-val.cpp b/source/slang/slang-ast-val.cpp index 1cdca0440..92e170515 100644 --- a/source/slang/slang-ast-val.cpp +++ b/source/slang/slang-ast-val.cpp @@ -3,6 +3,7 @@ #include "slang-ast-builder.h" #include "slang-ast-dispatch.h" +#include "slang-ast-natural-layout.h" #include "slang-check-impl.h" #include "slang-diagnostics.h" #include "slang-mangle.h" @@ -1742,6 +1743,106 @@ Val* FuncCallIntVal::_substituteImplOverride( return this; } +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! SizeOfIntVal !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + +void SizeOfIntVal::_toTextOverride(StringBuilder& out) +{ + out << "sizeof("; + getTypeArg()->toText(out); + out << ")"; +} + +Val* SizeOfIntVal::tryFoldOrNull(ASTBuilder* astBuilder, Type* intType, Type* newType) +{ + ASTNaturalLayoutContext context(astBuilder, nullptr); + const auto size = context.calcSize(newType); + + if (!size) + return nullptr; + + return astBuilder->getIntVal(intType, size.size); +} + +Val* SizeOfIntVal::tryFold(ASTBuilder* astBuilder, Type* intType, Type* newType) +{ + if (auto result = tryFoldOrNull(astBuilder, intType, newType)) + return result; + auto result = astBuilder->getOrCreate<SizeOfIntVal>(intType, newType); + return result; +} + +Val* SizeOfIntVal::_substituteImplOverride( + ASTBuilder* astBuilder, + SubstitutionSet subst, + int* ioDiff) +{ + int diff = 0; + auto newType = as<Type>(getTypeArg()->substituteImpl(astBuilder, subst, &diff)); + if (!diff) + return this; + + (*ioDiff)++; + return tryFold(astBuilder, getType(), newType); +} + +Val* SizeOfIntVal::_resolveImplOverride() +{ + auto resolvedTypeArg = getTypeArg()->resolve(); + if (resolvedTypeArg == getTypeArg()) + return this; + return tryFold(getCurrentASTBuilder(), getType(), as<Type>(resolvedTypeArg)); +} + +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! AlignOfIntVal !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + +void AlignOfIntVal::_toTextOverride(StringBuilder& out) +{ + out << "alignof("; + getTypeArg()->toText(out); + out << ")"; +} + +Val* AlignOfIntVal::tryFoldOrNull(ASTBuilder* astBuilder, Type* intType, Type* newType) +{ + ASTNaturalLayoutContext context(astBuilder, nullptr); + const auto size = context.calcSize(newType); + + if (!size) + return nullptr; + + return astBuilder->getIntVal(intType, size.alignment); +} + +Val* AlignOfIntVal::tryFold(ASTBuilder* astBuilder, Type* intType, Type* newType) +{ + if (auto result = tryFoldOrNull(astBuilder, intType, newType)) + return result; + auto result = astBuilder->getOrCreate<AlignOfIntVal>(intType, newType); + return result; +} + +Val* AlignOfIntVal::_substituteImplOverride( + ASTBuilder* astBuilder, + SubstitutionSet subst, + int* ioDiff) +{ + int diff = 0; + auto newType = as<Type>(getTypeArg()->substituteImpl(astBuilder, subst, &diff)); + if (!diff) + return this; + + (*ioDiff)++; + return tryFold(astBuilder, getType(), newType); +} + +Val* AlignOfIntVal::_resolveImplOverride() +{ + auto resolvedTypeArg = getTypeArg()->resolve(); + if (resolvedTypeArg == getTypeArg()) + return this; + return tryFold(getCurrentASTBuilder(), getType(), as<Type>(resolvedTypeArg)); +} + // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! CountOfIntVal !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! void CountOfIntVal::_toTextOverride(StringBuilder& out) diff --git a/source/slang/slang-ast-val.h b/source/slang/slang-ast-val.h index 2b4c7ed22..a8b969a94 100644 --- a/source/slang/slang-ast-val.h +++ b/source/slang/slang-ast-val.h @@ -255,21 +255,65 @@ class FuncCallIntVal : public IntVal Val* _linkTimeResolveOverride(Dictionary<String, IntVal*>& map); }; -FIDDLE() -class CountOfIntVal : public IntVal +FIDDLE(abstract) +class SizeOfLikeIntVal : public IntVal { FIDDLE(...) - CountOfIntVal(Type* inType, Type* typeArg) { setOperands(inType, typeArg); } + SizeOfLikeIntVal(Type* inType, Type* typeArg) { setOperands(inType, typeArg); } Val* getTypeArg() { return getOperand(1); } + bool _isLinkTimeValOverride() { return false; } +}; + +FIDDLE() +class SizeOfIntVal : public SizeOfLikeIntVal +{ + FIDDLE(...) + SizeOfIntVal(Type* inType, Type* typeArg) + : SizeOfLikeIntVal(inType, typeArg) + { + } + void _toTextOverride(StringBuilder& out); Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff); Val* _resolveImplOverride(); - bool _isLinkTimeValOverride() { return false; } static Val* tryFoldOrNull(ASTBuilder* astBuilder, Type* intType, Type* newType); + static Val* tryFold(ASTBuilder* astBuilder, Type* intType, Type* newType); +}; + +FIDDLE() +class AlignOfIntVal : public SizeOfLikeIntVal +{ + FIDDLE(...) + AlignOfIntVal(Type* inType, Type* typeArg) + : SizeOfLikeIntVal(inType, typeArg) + { + } + void _toTextOverride(StringBuilder& out); + Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff); + Val* _resolveImplOverride(); + + static Val* tryFoldOrNull(ASTBuilder* astBuilder, Type* intType, Type* newType); + static Val* tryFold(ASTBuilder* astBuilder, Type* intType, Type* newType); +}; + +FIDDLE() +class CountOfIntVal : public SizeOfLikeIntVal +{ + FIDDLE(...) + CountOfIntVal(Type* inType, Type* typeArg) + : SizeOfLikeIntVal(inType, typeArg) + { + } + + void _toTextOverride(StringBuilder& out); + Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff); + Val* _resolveImplOverride(); + + static Val* tryFoldOrNull(ASTBuilder* astBuilder, Type* intType, Type* newType); static Val* tryFold(ASTBuilder* astBuilder, Type* intType, Type* newType); }; diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index ad36a7e4a..db507c060 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -2060,13 +2060,25 @@ IntVal* SemanticsVisitor::tryConstantFoldExpr( } } - if (auto countOfExpr = expr.as<CountOfExpr>()) + if (auto sizeOfLikeExpr = expr.as<SizeOfLikeExpr>()) { - auto type = - as<Type>(countOfExpr.getExpr()->sizedType->substitute(m_astBuilder, expr.getSubsts())); - if (type) + auto type = as<Type>( + sizeOfLikeExpr.getExpr()->sizedType->substitute(m_astBuilder, expr.getSubsts())); + + if (auto sizeOfExpr = expr.as<SizeOfExpr>()) + { + return as<IntVal>(SizeOfIntVal::tryFold(m_astBuilder, expr.getExpr()->type.type, type)); + } + else if (auto alignOfExpr = expr.as<AlignOfExpr>()) + { + return as<IntVal>( + AlignOfIntVal::tryFold(m_astBuilder, expr.getExpr()->type.type, type)); + } + else if (auto countOfExpr = expr.as<CountOfExpr>()) + { return as<IntVal>( CountOfIntVal::tryFold(m_astBuilder, expr.getExpr()->type.type, type)); + } } // it is possible that we are referring to a generic value param @@ -2159,20 +2171,6 @@ IntVal* SemanticsVisitor::tryConstantFoldExpr( if (val) return val; } - else if (auto sizeOfLikeExpr = as<SizeOfLikeExpr>(expr.getExpr())) - { - ASTNaturalLayoutContext context(getASTBuilder(), nullptr); - const auto size = context.calcSize(sizeOfLikeExpr->sizedType); - if (!size) - { - return nullptr; - } - - auto value = as<AlignOfExpr>(sizeOfLikeExpr) ? size.alignment : size.size; - - // We can return as an IntVal - return getASTBuilder()->getIntVal(expr.getExpr()->type, value); - } else if (auto indexExpr = expr.as<IndexExpr>()) { return tryFoldIndexExpr(indexExpr.getExpr(), kind, circularityInfo); diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index f8946f5dc..9920075fe 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -1636,6 +1636,22 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower return LoweredValInfo::simple(resultVal); } + LoweredValInfo visitSizeOfIntVal(SizeOfIntVal* val) + { + auto irBuilder = getBuilder(); + auto typeArg = lowerType(context, as<Type>(val->getTypeArg())); + auto count = irBuilder->emitSizeOf(typeArg); + return LoweredValInfo::simple(count); + } + + LoweredValInfo visitAlignOfIntVal(AlignOfIntVal* val) + { + auto irBuilder = getBuilder(); + auto typeArg = lowerType(context, as<Type>(val->getTypeArg())); + auto count = irBuilder->emitAlignOf(typeArg); + return LoweredValInfo::simple(count); + } + LoweredValInfo visitCountOfIntVal(CountOfIntVal* val) { auto irBuilder = getBuilder(); diff --git a/source/slang/slang-mangle.cpp b/source/slang/slang-mangle.cpp index 056c7accb..ea620ebb2 100644 --- a/source/slang/slang-mangle.cpp +++ b/source/slang/slang-mangle.cpp @@ -357,6 +357,21 @@ void emitVal(ManglingContext* context, Val* val) emitVal(context, lookupIntVal->getWitness()); emitName(context, lookupIntVal->getKey()->getName()); } + else if (auto sizeOfIntVal = dynamicCast<SizeOfIntVal>(val)) + { + emitRaw(context, "KSO"); + emitVal(context, sizeOfIntVal->getTypeArg()); + } + else if (auto alignOfIntVal = dynamicCast<AlignOfIntVal>(val)) + { + emitRaw(context, "KAO"); + emitVal(context, alignOfIntVal->getTypeArg()); + } + else if (auto countOfIntVal = dynamicCast<CountOfIntVal>(val)) + { + emitRaw(context, "KCO"); + emitVal(context, countOfIntVal->getTypeArg()); + } else if (const auto polynomialIntVal = dynamicCast<PolynomialIntVal>(val)) { emitRaw(context, "KX"); diff --git a/tests/hlsl-intrinsic/size-of/size-of-compile-time.slang b/tests/hlsl-intrinsic/size-of/size-of-compile-time.slang new file mode 100644 index 000000000..64e2f640b --- /dev/null +++ b/tests/hlsl-intrinsic/size-of/size-of-compile-time.slang @@ -0,0 +1,41 @@ +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-cpu -compute -shaderobj +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-slang -compute -shaderobj +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-slang -compute -dx12 -shaderobj +//TEST(compute, vulkan):COMPARE_COMPUTE(filecheck-buffer=CHECK):-vk -compute -shaderobj +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-cuda -compute -shaderobj + +//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name outputBuffer + +RWStructuredBuffer<int> outputBuffer; + +struct Thing<T> +{ + uint8_t data[sizeof(T)]; +}; + +struct AlignThing<T> +{ + uint8_t data[alignof(T)]; +}; + +[numthreads(4, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + const int idx = asint(dispatchThreadID.x); + + int size = 0; + + switch (idx) + { + case 0: size = sizeof(Thing<float>); break; + case 1: size = sizeof(Thing<float3>); break; + case 2: size = sizeof(AlignThing<float>); break; + case 3: size = sizeof(AlignThing<float3>); break; + } + + // CHECK: 4 + // CHECK-NEXT: C + // CHECK: 4 + // CHECK-NEXT: 4 + outputBuffer[idx] = size; +} |
