From 2d0912bfe2de7799b32e80722fa5c8dc279a339b Mon Sep 17 00:00:00 2001 From: Ellie Hermaszewska Date: Wed, 31 Jan 2024 03:28:04 +0800 Subject: Correctly apply glsl local size layout to entry points during lowering (#3528) * Correctly apply glsl local size layout to entry points during lowering * Test for glsl layout correctness --- source/slang/slang-ast-modifier.h | 8 +++++++ source/slang/slang-check-modifier.cpp | 33 ++++++++++++++++++++++++++++ source/slang/slang-ir-insts.h | 2 ++ source/slang/slang-ir.cpp | 13 +++++++++++ source/slang/slang-lower-to-ir.cpp | 40 ++++++++++++++++++++++++---------- source/slang/slang-parser.cpp | 2 ++ tests/glsl/compute-shader-layout.slang | 22 +++++++++++++++++++ 7 files changed, 109 insertions(+), 11 deletions(-) create mode 100644 tests/glsl/compute-shader-layout.slang diff --git a/source/slang/slang-ast-modifier.h b/source/slang/slang-ast-modifier.h index 2cf4c93cf..97487f131 100644 --- a/source/slang/slang-ast-modifier.h +++ b/source/slang/slang-ast-modifier.h @@ -748,6 +748,14 @@ class DisableArrayFlatteningAttribute : public Attribute class GLSLLayoutLocalSizeAttribute : public Attribute { SLANG_AST_CLASS(GLSLLayoutLocalSizeAttribute) + + // The number of threads to use along each axis + // + // TODO: These should be accessors that use the + // ordinary `args` list, rather than side data. + int32_t x; + int32_t y; + int32_t z; }; // TODO: for attributes that take arguments, the syntax node diff --git a/source/slang/slang-check-modifier.cpp b/source/slang/slang-check-modifier.cpp index 413e1c157..96ec0acc8 100644 --- a/source/slang/slang-check-modifier.cpp +++ b/source/slang/slang-check-modifier.cpp @@ -1272,6 +1272,39 @@ namespace Slang } } + if (auto attr = as(m)) + { + SLANG_ASSERT(attr->args.getCount() == 3); + + int32_t values[3]; + + for (int i = 0; i < 3; ++i) + { + int32_t value = 1; + + auto arg = attr->args[i]; + if (arg) + { + auto intValue = checkConstantIntVal(arg); + if (!intValue) + { + return nullptr; + } + if (intValue->getValue() < 1) + { + getSink()->diagnose(attr, Diagnostics::nonPositiveNumThreads, intValue->getValue()); + return nullptr; + } + value = int32_t(intValue->getValue()); + } + values[i] = value; + } + + attr->x = values[0]; + attr->y = values[1]; + attr->z = values[2]; + } + // Default behavior is to leave things as they are, // and assume that modifiers are mostly already checked. // diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index 2d24efb43..c8309e3df 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -3483,6 +3483,8 @@ public: IRInst* addFloatingModeOverrideDecoration(IRInst* dest, FloatingPointMode mode); + IRInst* addNumThreadsDecoration(IRInst* inst, Int x, Int y, Int z); + IRInst* emitSpecializeInst( IRType* type, IRInst* genericVal, diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index 1dbb21a0d..94de28089 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -5032,6 +5032,19 @@ namespace Slang getIntValue(getIntType(), (IRIntegerValue)mode)); } + IRInst* IRBuilder::addNumThreadsDecoration(IRInst* inst, Int x, Int y, Int z) + { + IRType* intType = getIntType(); + + IRInst* operands[3] = { + getIntValue(intType, x), + getIntValue(intType, y), + getIntValue(intType, z) + }; + + return addDecoration(inst, kIROp_NumThreadsDecoration, operands, 3); + } + IRInst* IRBuilder::emitSwizzle( IRType* type, IRInst* base, diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index 00db77511..e0bf72f7d 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -6938,7 +6938,6 @@ struct DeclLoweringVisitor : DeclVisitor IGNORED_CASE(IncludeDecl) IGNORED_CASE(ImplementingDecl) IGNORED_CASE(UsingDecl) - IGNORED_CASE(EmptyDecl) IGNORED_CASE(SyntaxDecl) IGNORED_CASE(AttributeDecl) IGNORED_CASE(NamespaceDecl) @@ -6947,6 +6946,29 @@ struct DeclLoweringVisitor : DeclVisitor #undef IGNORED_CASE + LoweredValInfo visitEmptyDecl(EmptyDecl* decl) + { + for(const auto modifier : decl->modifiers) + { + if(const auto layoutLocalSizeAttr = as(modifier)) + { + for(const auto d : context->irBuilder->getModule()->getModuleInst()->getGlobalInsts()) + { + if(d->findDecoration()) + { + getBuilder()->addNumThreadsDecoration( + d, + layoutLocalSizeAttr->x, + layoutLocalSizeAttr->y, + layoutLocalSizeAttr->z + ); + } + } + } + } + return LoweredValInfo(); + } + void ensureInsertAtGlobalScope(IRBuilder* builder) { auto inst = builder->getInsertLoc().getInst(); @@ -9325,16 +9347,12 @@ struct DeclLoweringVisitor : DeclVisitor } else if (auto numThreadsAttr = as(modifier)) { - auto builder = getBuilder(); - IRType* intType = builder->getIntType(); - - IRInst* operands[3] = { - builder->getIntValue(intType, numThreadsAttr->x), - builder->getIntValue(intType, numThreadsAttr->y), - builder->getIntValue(intType, numThreadsAttr->z) - }; - - builder->addDecoration(irFunc, kIROp_NumThreadsDecoration, operands, 3); + getBuilder()->addNumThreadsDecoration( + irFunc, + numThreadsAttr->x, + numThreadsAttr->y, + numThreadsAttr->z + ); } else if (as(modifier)) { diff --git a/source/slang/slang-parser.cpp b/source/slang/slang-parser.cpp index f0c9e175f..f5312e645 100644 --- a/source/slang/slang-parser.cpp +++ b/source/slang/slang-parser.cpp @@ -7630,6 +7630,8 @@ namespace Slang { numThreadsAttrib = parser->astBuilder->create(); numThreadsAttrib->args.setCount(3); + for (auto& i : numThreadsAttrib->args) + i = nullptr; // Just mark the loc and name from the first in the list numThreadsAttrib->keywordName = getName(parser, "numthreads"); diff --git a/tests/glsl/compute-shader-layout.slang b/tests/glsl/compute-shader-layout.slang new file mode 100644 index 000000000..b81a87aed --- /dev/null +++ b/tests/glsl/compute-shader-layout.slang @@ -0,0 +1,22 @@ +//TEST:SIMPLE(filecheck=CHECKGLSLANG): -target spirv -stage compute -entry main -allow-glsl +//TEST:SIMPLE(filecheck=CHECKDIRECT): -target spirv -stage compute -entry main -allow-glsl -emit-spirv-directly +#version 430 +precision highp float; +precision highp int; + +layout(binding = 0) buffer MyBlockName +{ + vec4 data[]; +} output_data; + +// CHECKGLSLANG-DAG: [[x:%[^ ]+]] = OpConstant {{%[^ ]+}} 44 +// CHECKGLSLANG-DAG: [[y:%[^ ]+]] = OpConstant {{%[^ ]+}} 45 +// CHECKGLSLANG-DAG: [[z:%[^ ]+]] = OpConstant {{%[^ ]+}} 46 +// CHECKGLSLANG: %gl_WorkGroupSize = OpConstantComposite {{%[^ ]+}} [[x]] [[y]] [[z]] + +// CHECKDIRECT: OpExecutionMode %main LocalSize 44 45 46 +layout(local_size_x = 44, local_size_y = 45, local_size_z = 46) in; +void main() +{ + output_data.data[gl_GlobalInvocationID.x] = vec4(gl_GlobalInvocationID, 1); +} -- cgit v1.2.3