summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2024-10-09 17:15:20 -0700
committerGitHub <noreply@github.com>2024-10-09 17:15:20 -0700
commit75481ea3b0654eeb727cabc718258984e7753e02 (patch)
tree4d2e56517d1413ca45c4038049035fd971ac903d
parentb8aab84e2c4c3e6d91d75ffcebfcc2f6e84da01c (diff)
Support constant folding for static array access. (#5248)
* Support constant folding for static array access. * Fix test.
-rw-r--r--source/slang/slang-check-expr.cpp52
-rw-r--r--source/slang/slang-check-impl.h5
-rw-r--r--tests/language-feature/constants/static-array-indexing.slang18
3 files changed, 74 insertions, 1 deletions
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp
index 842ffb527..ec4182059 100644
--- a/source/slang/slang-check-expr.cpp
+++ b/source/slang/slang-check-expr.cpp
@@ -2021,10 +2021,60 @@ namespace Slang
// We can return as an IntVal
return getASTBuilder()->getIntVal(expr.getExpr()->type, value);
}
-
+ else if (auto indexExpr = expr.as<IndexExpr>())
+ {
+ return tryFoldIndexExpr(indexExpr.getExpr(), kind, circularityInfo);
+ }
return nullptr;
}
+ IntVal* SemanticsVisitor::tryFoldIndexExpr(
+ SubstExpr<IndexExpr> expr,
+ ConstantFoldingKind kind,
+ ConstantFoldingCircularityInfo* circularityInfo)
+ {
+ // Ad-hoc constant folding for index expressions.
+ // TOOD: we should generalize this by extending `Val` to support compile-time constants that are
+ // not just integers, but also arrays and structs etc, so that we can independently fold
+ // the base expression and the index expression, and then form a ElementExtractVal() from an
+ // index expr.
+ // For now we just specialize case for array expression that is an initialization list.
+ // And this won't work if the array is a link-time constant.
+ //
+ auto declRefExpr = as<DeclRefExpr>(expr.getExpr()->baseExpression);
+ if (!declRefExpr)
+ return nullptr;
+ auto varDecl = as<VarDecl>(declRefExpr->declRef.getDecl());
+ if (!varDecl)
+ return nullptr;
+ auto type = varDecl->getType();
+ if (!type)
+ return nullptr;
+ auto arrayType = as<ArrayExpressionType>(type);
+ if (!arrayType)
+ return nullptr;
+ if (!varDecl->hasModifier<ConstModifier>())
+ return nullptr;
+ if (isGlobalDecl(varDecl) && !varDecl->hasModifier<HLSLStaticModifier>())
+ return nullptr;
+ if (!varDecl->initExpr)
+ return nullptr;
+ auto arrayContentExpr = as<InitializerListExpr>(varDecl->initExpr);
+ if (!arrayContentExpr)
+ return nullptr;
+ if (expr.getExpr()->indexExprs.getCount() != 1)
+ return nullptr;
+ auto indexVal = as<ConstantIntVal>(tryFoldIntegerConstantExpression(
+ expr.getExpr()->indexExprs[0], kind, circularityInfo));
+ if (!indexVal)
+ return nullptr;
+ auto index = indexVal->getValue();
+ if (index < 0 || index >= arrayContentExpr->args.getCount())
+ return nullptr;
+ auto elementExpr = arrayContentExpr->args[Index(index)];
+ return tryFoldIntegerConstantExpression(elementExpr, kind, circularityInfo);
+ }
+
IntVal* SemanticsVisitor::tryFoldIntegerConstantExpression(
SubstExpr<Expr> expr,
ConstantFoldingKind kind,
diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h
index 2f3eae0c5..5d70c36d5 100644
--- a/source/slang/slang-check-impl.h
+++ b/source/slang/slang-check-impl.h
@@ -2089,6 +2089,11 @@ namespace Slang
ConstantFoldingKind kind,
ConstantFoldingCircularityInfo* circularityInfo);
+ IntVal* tryFoldIndexExpr(
+ SubstExpr<IndexExpr> expr,
+ ConstantFoldingKind kind,
+ ConstantFoldingCircularityInfo* circularityInfo);
+
// Enforce that an expression resolves to an integer constant, and get its value
enum class IntegerConstantExpressionCoercionType
{
diff --git a/tests/language-feature/constants/static-array-indexing.slang b/tests/language-feature/constants/static-array-indexing.slang
new file mode 100644
index 000000000..0a7963b34
--- /dev/null
+++ b/tests/language-feature/constants/static-array-indexing.slang
@@ -0,0 +1,18 @@
+//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):
+
+int check<int v>()
+{
+ return v;
+}
+
+//TEST_INPUT:ubuffer(data=[0], stride=4):out,name=outputBuffer
+RWStructuredBuffer<int> outputBuffer;
+
+[numthreads(1, 1, 1)]
+void computeMain(int3 dispatchThreadID : SV_DispatchThreadID)
+{
+ int tid = dispatchThreadID.x;
+ const int a[] = { 1, 2, 3, 4 };
+ // CHECK: 4
+ outputBuffer[tid] = check<a[3]>();
+}