summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorEllie Hermaszewska <ellieh@nvidia.com>2024-01-31 03:28:04 +0800
committerGitHub <noreply@github.com>2024-01-30 11:28:04 -0800
commit2d0912bfe2de7799b32e80722fa5c8dc279a339b (patch)
tree152bfe7c054f035090c84fabd7d9d12e9f5fc362
parent470c5a28f5b84353f077c2d871db65cddd5f923a (diff)
Correctly apply glsl local size layout to entry points during lowering (#3528)
* Correctly apply glsl local size layout to entry points during lowering * Test for glsl layout correctness
-rw-r--r--source/slang/slang-ast-modifier.h8
-rw-r--r--source/slang/slang-check-modifier.cpp33
-rw-r--r--source/slang/slang-ir-insts.h2
-rw-r--r--source/slang/slang-ir.cpp13
-rw-r--r--source/slang/slang-lower-to-ir.cpp40
-rw-r--r--source/slang/slang-parser.cpp2
-rw-r--r--tests/glsl/compute-shader-layout.slang22
7 files changed, 109 insertions, 11 deletions
diff --git a/source/slang/slang-ast-modifier.h b/source/slang/slang-ast-modifier.h
index 2cf4c93cf..97487f131 100644
--- a/source/slang/slang-ast-modifier.h
+++ b/source/slang/slang-ast-modifier.h
@@ -748,6 +748,14 @@ class DisableArrayFlatteningAttribute : public Attribute
class GLSLLayoutLocalSizeAttribute : public Attribute
{
SLANG_AST_CLASS(GLSLLayoutLocalSizeAttribute)
+
+ // The number of threads to use along each axis
+ //
+ // TODO: These should be accessors that use the
+ // ordinary `args` list, rather than side data.
+ int32_t x;
+ int32_t y;
+ int32_t z;
};
// TODO: for attributes that take arguments, the syntax node
diff --git a/source/slang/slang-check-modifier.cpp b/source/slang/slang-check-modifier.cpp
index 413e1c157..96ec0acc8 100644
--- a/source/slang/slang-check-modifier.cpp
+++ b/source/slang/slang-check-modifier.cpp
@@ -1272,6 +1272,39 @@ namespace Slang
}
}
+ if (auto attr = as<GLSLLayoutLocalSizeAttribute>(m))
+ {
+ SLANG_ASSERT(attr->args.getCount() == 3);
+
+ int32_t values[3];
+
+ for (int i = 0; i < 3; ++i)
+ {
+ int32_t value = 1;
+
+ auto arg = attr->args[i];
+ if (arg)
+ {
+ auto intValue = checkConstantIntVal(arg);
+ if (!intValue)
+ {
+ return nullptr;
+ }
+ if (intValue->getValue() < 1)
+ {
+ getSink()->diagnose(attr, Diagnostics::nonPositiveNumThreads, intValue->getValue());
+ return nullptr;
+ }
+ value = int32_t(intValue->getValue());
+ }
+ values[i] = value;
+ }
+
+ attr->x = values[0];
+ attr->y = values[1];
+ attr->z = values[2];
+ }
+
// Default behavior is to leave things as they are,
// and assume that modifiers are mostly already checked.
//
diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h
index 2d24efb43..c8309e3df 100644
--- a/source/slang/slang-ir-insts.h
+++ b/source/slang/slang-ir-insts.h
@@ -3483,6 +3483,8 @@ public:
IRInst* addFloatingModeOverrideDecoration(IRInst* dest, FloatingPointMode mode);
+ IRInst* addNumThreadsDecoration(IRInst* inst, Int x, Int y, Int z);
+
IRInst* emitSpecializeInst(
IRType* type,
IRInst* genericVal,
diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp
index 1dbb21a0d..94de28089 100644
--- a/source/slang/slang-ir.cpp
+++ b/source/slang/slang-ir.cpp
@@ -5032,6 +5032,19 @@ namespace Slang
getIntValue(getIntType(), (IRIntegerValue)mode));
}
+ IRInst* IRBuilder::addNumThreadsDecoration(IRInst* inst, Int x, Int y, Int z)
+ {
+ IRType* intType = getIntType();
+
+ IRInst* operands[3] = {
+ getIntValue(intType, x),
+ getIntValue(intType, y),
+ getIntValue(intType, z)
+ };
+
+ return addDecoration(inst, kIROp_NumThreadsDecoration, operands, 3);
+ }
+
IRInst* IRBuilder::emitSwizzle(
IRType* type,
IRInst* base,
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index 00db77511..e0bf72f7d 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -6938,7 +6938,6 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
IGNORED_CASE(IncludeDecl)
IGNORED_CASE(ImplementingDecl)
IGNORED_CASE(UsingDecl)
- IGNORED_CASE(EmptyDecl)
IGNORED_CASE(SyntaxDecl)
IGNORED_CASE(AttributeDecl)
IGNORED_CASE(NamespaceDecl)
@@ -6947,6 +6946,29 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
#undef IGNORED_CASE
+ LoweredValInfo visitEmptyDecl(EmptyDecl* decl)
+ {
+ for(const auto modifier : decl->modifiers)
+ {
+ if(const auto layoutLocalSizeAttr = as<GLSLLayoutLocalSizeAttribute>(modifier))
+ {
+ for(const auto d : context->irBuilder->getModule()->getModuleInst()->getGlobalInsts())
+ {
+ if(d->findDecoration<IREntryPointDecoration>())
+ {
+ getBuilder()->addNumThreadsDecoration(
+ d,
+ layoutLocalSizeAttr->x,
+ layoutLocalSizeAttr->y,
+ layoutLocalSizeAttr->z
+ );
+ }
+ }
+ }
+ }
+ return LoweredValInfo();
+ }
+
void ensureInsertAtGlobalScope(IRBuilder* builder)
{
auto inst = builder->getInsertLoc().getInst();
@@ -9325,16 +9347,12 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
}
else if (auto numThreadsAttr = as<NumThreadsAttribute>(modifier))
{
- auto builder = getBuilder();
- IRType* intType = builder->getIntType();
-
- IRInst* operands[3] = {
- builder->getIntValue(intType, numThreadsAttr->x),
- builder->getIntValue(intType, numThreadsAttr->y),
- builder->getIntValue(intType, numThreadsAttr->z)
- };
-
- builder->addDecoration(irFunc, kIROp_NumThreadsDecoration, operands, 3);
+ getBuilder()->addNumThreadsDecoration(
+ irFunc,
+ numThreadsAttr->x,
+ numThreadsAttr->y,
+ numThreadsAttr->z
+ );
}
else if (as<ReadNoneAttribute>(modifier))
{
diff --git a/source/slang/slang-parser.cpp b/source/slang/slang-parser.cpp
index f0c9e175f..f5312e645 100644
--- a/source/slang/slang-parser.cpp
+++ b/source/slang/slang-parser.cpp
@@ -7630,6 +7630,8 @@ namespace Slang
{
numThreadsAttrib = parser->astBuilder->create<GLSLLayoutLocalSizeAttribute>();
numThreadsAttrib->args.setCount(3);
+ for (auto& i : numThreadsAttrib->args)
+ i = nullptr;
// Just mark the loc and name from the first in the list
numThreadsAttrib->keywordName = getName(parser, "numthreads");
diff --git a/tests/glsl/compute-shader-layout.slang b/tests/glsl/compute-shader-layout.slang
new file mode 100644
index 000000000..b81a87aed
--- /dev/null
+++ b/tests/glsl/compute-shader-layout.slang
@@ -0,0 +1,22 @@
+//TEST:SIMPLE(filecheck=CHECKGLSLANG): -target spirv -stage compute -entry main -allow-glsl
+//TEST:SIMPLE(filecheck=CHECKDIRECT): -target spirv -stage compute -entry main -allow-glsl -emit-spirv-directly
+#version 430
+precision highp float;
+precision highp int;
+
+layout(binding = 0) buffer MyBlockName
+{
+ vec4 data[];
+} output_data;
+
+// CHECKGLSLANG-DAG: [[x:%[^ ]+]] = OpConstant {{%[^ ]+}} 44
+// CHECKGLSLANG-DAG: [[y:%[^ ]+]] = OpConstant {{%[^ ]+}} 45
+// CHECKGLSLANG-DAG: [[z:%[^ ]+]] = OpConstant {{%[^ ]+}} 46
+// CHECKGLSLANG: %gl_WorkGroupSize = OpConstantComposite {{%[^ ]+}} [[x]] [[y]] [[z]]
+
+// CHECKDIRECT: OpExecutionMode %main LocalSize 44 45 46
+layout(local_size_x = 44, local_size_y = 45, local_size_z = 46) in;
+void main()
+{
+ output_data.data[gl_GlobalInvocationID.x] = vec4(gl_GlobalInvocationID, 1);
+}