summaryrefslogtreecommitdiffstats
path: root/source/slang
diff options
context:
space:
mode:
authorJulius Ikkala <julius.ikkala@gmail.com>2025-01-14 20:32:29 +0200
committerGitHub <noreply@github.com>2025-01-14 10:32:29 -0800
commitcbdc7e1219e472fd74f7f559d7e417f233e7df39 (patch)
treee051b90e317a875e264c2c8d951668bf0b7d3ad0 /source/slang
parent971996b397711016d47fe961890d7001338c6f23 (diff)
Implement specialization constant support in numthreads / local_size (#5963)
* Allow using specialization constants in numthreads attribute * Add support for GLSL local_size_x_id syntax * Fix overeager specialization constant parsing * Add diagnostics for specialization constant numthreads * Remove unused variable * Fix local_size_x_id not finding existing specialization constant * Allow materializeGetWorkGroupSize to reference specialization constants * Use SpvOpExecutionModeId for modes that require it * Cleanup specialization constant numthreads code * Add tests for specialization constant work group sizes * Fix implicit Slang::Int -> int32_t cast * Fix querying thread group size in reflection API --------- Co-authored-by: Yong He <yonghe@outlook.com>
Diffstat (limited to 'source/slang')
-rw-r--r--source/slang/slang-ast-modifier.h20
-rw-r--r--source/slang/slang-check-impl.h2
-rw-r--r--source/slang/slang-check-modifier.cpp108
-rw-r--r--source/slang/slang-diagnostic-defs.h6
-rw-r--r--source/slang/slang-emit-c-like.cpp40
-rw-r--r--source/slang/slang-emit-c-like.h13
-rw-r--r--source/slang/slang-emit-glsl.cpp16
-rw-r--r--source/slang/slang-emit-spirv.cpp55
-rw-r--r--source/slang/slang-ir-collect-global-uniforms.cpp10
-rw-r--r--source/slang/slang-ir-insts.h11
-rw-r--r--source/slang/slang-ir-legalize-varying-params.cpp16
-rw-r--r--source/slang/slang-ir-metal-legalize.cpp26
-rw-r--r--source/slang/slang-ir-translate-glsl-global-var.cpp17
-rw-r--r--source/slang/slang-ir-util.cpp13
-rw-r--r--source/slang/slang-ir-util.h2
-rw-r--r--source/slang/slang-lower-to-ir.cpp46
-rw-r--r--source/slang/slang-parser.cpp11
-rw-r--r--source/slang/slang-reflection-api.cpp20
18 files changed, 346 insertions, 86 deletions
diff --git a/source/slang/slang-ast-modifier.h b/source/slang/slang-ast-modifier.h
index f5dd86df1..ee29750a6 100644
--- a/source/slang/slang-ast-modifier.h
+++ b/source/slang/slang-ast-modifier.h
@@ -973,9 +973,14 @@ class GLSLLayoutLocalSizeAttribute : public Attribute
//
// TODO: These should be accessors that use the
// ordinary `args` list, rather than side data.
- IntVal* x;
- IntVal* y;
- IntVal* z;
+ IntVal* extents[3];
+
+ bool axisIsSpecConstId[3];
+
+ // References to specialization constants, for defining the number of
+ // threads with them. If set, the corresponding axis is set to nullptr
+ // above.
+ DeclRef<VarDeclBase> specConstExtents[3];
};
class GLSLLayoutDerivativeGroupQuadAttribute : public Attribute
@@ -1038,9 +1043,12 @@ class NumThreadsAttribute : public Attribute
//
// TODO: These should be accessors that use the
// ordinary `args` list, rather than side data.
- IntVal* x;
- IntVal* y;
- IntVal* z;
+ IntVal* extents[3];
+
+ // References to specialization constants, for defining the number of
+ // threads with them. If set, the corresponding axis is set to nullptr
+ // above.
+ DeclRef<VarDeclBase> specConstExtents[3];
};
class WaveSizeAttribute : public Attribute
diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h
index b3e30dbc2..3ef1e8f3b 100644
--- a/source/slang/slang-check-impl.h
+++ b/source/slang/slang-check-impl.h
@@ -1656,6 +1656,8 @@ public:
void visitModifier(Modifier*);
+ DeclRef<VarDeclBase> tryGetIntSpecializationConstant(Expr* expr);
+
AttributeDecl* lookUpAttributeDecl(Name* attributeName, Scope* scope);
bool hasIntArgs(Attribute* attr, int numArgs);
diff --git a/source/slang/slang-check-modifier.cpp b/source/slang/slang-check-modifier.cpp
index 3723c98f8..6e451b5cf 100644
--- a/source/slang/slang-check-modifier.cpp
+++ b/source/slang/slang-check-modifier.cpp
@@ -114,6 +114,36 @@ void SemanticsVisitor::visitModifier(Modifier*)
// Do nothing with modifiers for now
}
+DeclRef<VarDeclBase> SemanticsVisitor::tryGetIntSpecializationConstant(Expr* expr)
+{
+ // First type-check the expression as normal
+ expr = CheckExpr(expr);
+
+ if (IsErrorExpr(expr))
+ return DeclRef<VarDeclBase>();
+
+ if (!isScalarIntegerType(expr->type))
+ return DeclRef<VarDeclBase>();
+
+ auto specConstVar = as<VarExpr>(expr);
+ if (!specConstVar || !specConstVar->declRef)
+ return DeclRef<VarDeclBase>();
+
+ auto decl = specConstVar->declRef.getDecl();
+ if (!decl)
+ return DeclRef<VarDeclBase>();
+
+ for (auto modifier : decl->modifiers)
+ {
+ if (as<SpecializationConstantAttribute>(modifier) || as<VkConstantIdAttribute>(modifier))
+ {
+ return specConstVar->declRef.as<VarDeclBase>();
+ }
+ }
+
+ return DeclRef<VarDeclBase>();
+}
+
static bool _isDeclAllowedAsAttribute(DeclRef<Decl> declRef)
{
if (as<AttributeDecl>(declRef.getDecl()))
@@ -350,8 +380,6 @@ Modifier* SemanticsVisitor::validateAttribute(
{
SLANG_ASSERT(attr->args.getCount() == 3);
- IntVal* values[3];
-
for (int i = 0; i < 3; ++i)
{
IntVal* value = nullptr;
@@ -359,6 +387,14 @@ Modifier* SemanticsVisitor::validateAttribute(
auto arg = attr->args[i];
if (arg)
{
+ auto specConstDecl = tryGetIntSpecializationConstant(arg);
+ if (specConstDecl)
+ {
+ numThreadsAttr->extents[i] = nullptr;
+ numThreadsAttr->specConstExtents[i] = specConstDecl;
+ continue;
+ }
+
auto intValue = checkLinkTimeConstantIntVal(arg);
if (!intValue)
{
@@ -390,12 +426,8 @@ Modifier* SemanticsVisitor::validateAttribute(
{
value = m_astBuilder->getIntVal(m_astBuilder->getIntType(), 1);
}
- values[i] = value;
+ numThreadsAttr->extents[i] = value;
}
-
- numThreadsAttr->x = values[0];
- numThreadsAttr->y = values[1];
- numThreadsAttr->z = values[2];
}
else if (auto waveSizeAttr = as<WaveSizeAttribute>(attr))
{
@@ -1831,15 +1863,24 @@ Modifier* SemanticsVisitor::checkModifier(
{
SLANG_ASSERT(attr->args.getCount() == 3);
- IntVal* values[3];
+ // GLSLLayoutLocalSizeAttribute is always attached to an EmptyDecl.
+ auto decl = as<EmptyDecl>(syntaxNode);
+ SLANG_ASSERT(decl);
for (int i = 0; i < 3; ++i)
{
- IntVal* value = nullptr;
+ attr->extents[i] = nullptr;
auto arg = attr->args[i];
if (arg)
{
+ auto specConstDecl = tryGetIntSpecializationConstant(arg);
+ if (specConstDecl)
+ {
+ attr->specConstExtents[i] = specConstDecl;
+ continue;
+ }
+
auto intValue = checkConstantIntVal(arg);
if (!intValue)
{
@@ -1847,7 +1888,45 @@ Modifier* SemanticsVisitor::checkModifier(
}
if (auto cintVal = as<ConstantIntVal>(intValue))
{
- if (cintVal->getValue() < 1)
+ if (attr->axisIsSpecConstId[i])
+ {
+ // This integer should actually be a reference to a
+ // specialization constant with this ID.
+ Int specConstId = cintVal->getValue();
+
+ for (auto member : decl->parentDecl->members)
+ {
+ auto constantId = member->findModifier<VkConstantIdAttribute>();
+ if (constantId)
+ {
+ SLANG_ASSERT(constantId->args.getCount() == 1);
+ auto id = checkConstantIntVal(constantId->args[0]);
+ if (id->getValue() == specConstId)
+ {
+ attr->specConstExtents[i] =
+ DeclRef<VarDeclBase>(member->getDefaultDeclRef());
+ break;
+ }
+ }
+ }
+
+ // If not found, we need to create a new specialization
+ // constant with this ID.
+ if (!attr->specConstExtents[i])
+ {
+ auto specConstVarDecl = getASTBuilder()->create<VarDecl>();
+ auto constantIdModifier =
+ getASTBuilder()->create<VkConstantIdAttribute>();
+ constantIdModifier->location = (int32_t)specConstId;
+ specConstVarDecl->type.type = getASTBuilder()->getIntType();
+ addModifier(specConstVarDecl, constantIdModifier);
+ decl->parentDecl->addMember(specConstVarDecl);
+ attr->specConstExtents[i] =
+ DeclRef<VarDeclBase>(specConstVarDecl->getDefaultDeclRef());
+ }
+ continue;
+ }
+ else if (cintVal->getValue() < 1)
{
getSink()->diagnose(
attr,
@@ -1856,18 +1935,13 @@ Modifier* SemanticsVisitor::checkModifier(
return nullptr;
}
}
- value = intValue;
+ attr->extents[i] = intValue;
}
else
{
- value = m_astBuilder->getIntVal(m_astBuilder->getIntType(), 1);
+ attr->extents[i] = m_astBuilder->getIntVal(m_astBuilder->getIntType(), 1);
}
- values[i] = value;
}
-
- attr->x = values[0];
- attr->y = values[1];
- attr->z = values[2];
}
// Default behavior is to leave things as they are,
diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h
index 1d09189cc..d86cd8be2 100644
--- a/source/slang/slang-diagnostic-defs.h
+++ b/source/slang/slang-diagnostic-defs.h
@@ -2460,6 +2460,12 @@ DIAGNOSTIC(
unsupportedTargetIntrinsic,
"intrinsic operation '$0' is not supported for the current target.")
DIAGNOSTIC(
+ 55205,
+ Error,
+ unsupportedSpecializationConstantForNumThreads,
+ "Specialization constants are not supported in the 'numthreads' attribute for the current "
+ "target.")
+DIAGNOSTIC(
56001,
Error,
unableToAutoMapCUDATypeToHostType,
diff --git a/source/slang/slang-emit-c-like.cpp b/source/slang/slang-emit-c-like.cpp
index 7b51495e2..d3a9359ff 100644
--- a/source/slang/slang-emit-c-like.cpp
+++ b/source/slang/slang-emit-c-like.cpp
@@ -295,14 +295,48 @@ void CLikeSourceEmitter::emitSimpleType(IRType* type)
}
-/* static */ IRNumThreadsDecoration* CLikeSourceEmitter::getComputeThreadGroupSize(
+IRNumThreadsDecoration* CLikeSourceEmitter::getComputeThreadGroupSize(
IRFunc* func,
Int outNumThreads[kThreadGroupAxisCount])
{
+ Int specializationConstantIds[kThreadGroupAxisCount];
+ IRNumThreadsDecoration* decor =
+ getComputeThreadGroupSize(func, outNumThreads, specializationConstantIds);
+
+ for (auto id : specializationConstantIds)
+ {
+ if (id >= 0)
+ {
+ getSink()->diagnose(decor, Diagnostics::unsupportedSpecializationConstantForNumThreads);
+ break;
+ }
+ }
+ return decor;
+}
+
+/* static */ IRNumThreadsDecoration* CLikeSourceEmitter::getComputeThreadGroupSize(
+ IRFunc* func,
+ Int outNumThreads[kThreadGroupAxisCount],
+ Int outSpecializationConstantIds[kThreadGroupAxisCount])
+{
IRNumThreadsDecoration* decor = func->findDecoration<IRNumThreadsDecoration>();
- for (int i = 0; i < 3; ++i)
+ for (int i = 0; i < kThreadGroupAxisCount; ++i)
{
- outNumThreads[i] = decor ? Int(getIntVal(decor->getOperand(i))) : 1;
+ if (!decor)
+ {
+ outNumThreads[i] = 1;
+ outSpecializationConstantIds[i] = -1;
+ }
+ else if (auto specConst = as<IRGlobalParam>(decor->getOperand(i)))
+ {
+ outNumThreads[i] = 1;
+ outSpecializationConstantIds[i] = getSpecializationConstantId(specConst);
+ }
+ else
+ {
+ outNumThreads[i] = Int(getIntVal(decor->getOperand(i)));
+ outSpecializationConstantIds[i] = -1;
+ }
}
return decor;
}
diff --git a/source/slang/slang-emit-c-like.h b/source/slang/slang-emit-c-like.h
index e5080f731..1354b7cbd 100644
--- a/source/slang/slang-emit-c-like.h
+++ b/source/slang/slang-emit-c-like.h
@@ -500,11 +500,20 @@ public:
/// different. Returns an empty slice if not a built in type
static UnownedStringSlice getDefaultBuiltinTypeName(IROp op);
- /// Finds the IRNumThreadsDecoration and gets the size from that or sets all dimensions to 1
- static IRNumThreadsDecoration* getComputeThreadGroupSize(
+ /// Finds the IRNumThreadsDecoration and gets the size from that or sets all
+ /// dimensions to 1
+ IRNumThreadsDecoration* getComputeThreadGroupSize(
IRFunc* func,
Int outNumThreads[kThreadGroupAxisCount]);
+ /// Finds the IRNumThreadsDecoration and gets the size from that or sets all
+ /// dimensions to 1. If specialization constants are used for an axis, their
+ /// IDs is reported in non-negative entries of outSpecializationConstantIds.
+ static IRNumThreadsDecoration* getComputeThreadGroupSize(
+ IRFunc* func,
+ Int outNumThreads[kThreadGroupAxisCount],
+ Int outSpecializationConstantIds[kThreadGroupAxisCount]);
+
/// Finds the IRWaveSizeDecoration and gets the size from that.
static IRWaveSizeDecoration* getComputeWaveSize(IRFunc* func, Int* outWaveSize);
diff --git a/source/slang/slang-emit-glsl.cpp b/source/slang/slang-emit-glsl.cpp
index 23fff37ac..0dab07cfc 100644
--- a/source/slang/slang-emit-glsl.cpp
+++ b/source/slang/slang-emit-glsl.cpp
@@ -1335,7 +1335,8 @@ void GLSLSourceEmitter::emitEntryPointAttributesImpl(
auto emitLocalSizeLayout = [&]()
{
Int sizeAlongAxis[kThreadGroupAxisCount];
- getComputeThreadGroupSize(irFunc, sizeAlongAxis);
+ Int specializationConstantIds[kThreadGroupAxisCount];
+ getComputeThreadGroupSize(irFunc, sizeAlongAxis, specializationConstantIds);
m_writer->emit("layout(");
char const* axes[] = {"x", "y", "z"};
@@ -1345,8 +1346,17 @@ void GLSLSourceEmitter::emitEntryPointAttributesImpl(
m_writer->emit(", ");
m_writer->emit("local_size_");
m_writer->emit(axes[ii]);
- m_writer->emit(" = ");
- m_writer->emit(sizeAlongAxis[ii]);
+
+ if (specializationConstantIds[ii] >= 0)
+ {
+ m_writer->emit("_id = ");
+ m_writer->emit(specializationConstantIds[ii]);
+ }
+ else
+ {
+ m_writer->emit(" = ");
+ m_writer->emit(sizeAlongAxis[ii]);
+ }
}
m_writer->emit(") in;\n");
};
diff --git a/source/slang/slang-emit-spirv.cpp b/source/slang/slang-emit-spirv.cpp
index 068e1563c..2cf84a854 100644
--- a/source/slang/slang-emit-spirv.cpp
+++ b/source/slang/slang-emit-spirv.cpp
@@ -4353,23 +4353,36 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
// [3.6. Execution Mode]: LocalSize
case kIROp_NumThreadsDecoration:
{
- // TODO: The `LocalSize` execution mode option requires
- // literal values for the X,Y,Z thread-group sizes.
- // There is a `LocalSizeId` variant that takes `<id>`s
- // for those sizes, and we should consider using that
- // and requiring the appropriate capabilities
- // if any of the operands to the decoration are not
- // literals (in a future where we support non-literals
- // in those positions in the Slang IR).
- //
auto numThreads = cast<IRNumThreadsDecoration>(decoration);
- requireSPIRVExecutionMode(
- decoration,
- dstID,
- SpvExecutionModeLocalSize,
- SpvLiteralInteger::from32(int32_t(numThreads->getX()->getValue())),
- SpvLiteralInteger::from32(int32_t(numThreads->getY()->getValue())),
- SpvLiteralInteger::from32(int32_t(numThreads->getZ()->getValue())));
+ if (numThreads->getXSpecConst() || numThreads->getYSpecConst() ||
+ numThreads->getZSpecConst())
+ {
+ // If any of the dimensions needs an ID, we need to emit
+ // all dimensions as an ID due to how LocalSizeId works.
+ int32_t ids[3];
+ for (int i = 0; i < 3; ++i)
+ ids[i] = ensureInst(numThreads->getOperand(i))->id;
+
+ // LocalSizeId is supported from SPIR-V 1.2 onwards without
+ // any extra capabilities.
+ requireSPIRVExecutionMode(
+ decoration,
+ dstID,
+ SpvExecutionModeLocalSizeId,
+ SpvLiteralInteger::from32(int32_t(ids[0])),
+ SpvLiteralInteger::from32(int32_t(ids[1])),
+ SpvLiteralInteger::from32(int32_t(ids[2])));
+ }
+ else
+ {
+ requireSPIRVExecutionMode(
+ decoration,
+ dstID,
+ SpvExecutionModeLocalSize,
+ SpvLiteralInteger::from32(int32_t(numThreads->getX()->getValue())),
+ SpvLiteralInteger::from32(int32_t(numThreads->getY()->getValue())),
+ SpvLiteralInteger::from32(int32_t(numThreads->getZ()->getValue())));
+ }
}
break;
case kIROp_MaxVertexCountDecoration:
@@ -7977,10 +7990,18 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
{
if (m_executionModes[entryPoint].add(executionMode))
{
+ SpvOp execModeOp = SpvOpExecutionMode;
+ if (executionMode == SpvExecutionModeLocalSizeId ||
+ executionMode == SpvExecutionModeLocalSizeHintId ||
+ executionMode == SpvExecutionModeSubgroupsPerWorkgroupId)
+ {
+ execModeOp = SpvOpExecutionModeId;
+ }
+
emitInst(
getSection(SpvLogicalSectionID::ExecutionModes),
parentInst,
- SpvOpExecutionMode,
+ execModeOp,
entryPoint,
executionMode,
ops...);
diff --git a/source/slang/slang-ir-collect-global-uniforms.cpp b/source/slang/slang-ir-collect-global-uniforms.cpp
index 1c833a294..372ef298e 100644
--- a/source/slang/slang-ir-collect-global-uniforms.cpp
+++ b/source/slang/slang-ir-collect-global-uniforms.cpp
@@ -279,6 +279,16 @@ struct CollectGlobalUniformParametersContext
continue;
}
+ // NumThreadsDecoration may sometimes be the user for a global
+ // parameter. This occurs when the parameter was supposed to be
+ // a specialization constant, but isn't due to that not being
+ // supported for the target. These can be skipped here and
+ // diagnosed later.
+ if (as<IRNumThreadsDecoration>(user))
+ {
+ continue;
+ }
+
// For each use site for the global parameter, we will
// insert new code right before the instruction that uses
// the parameter.
diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h
index a58c2e900..f46586aa2 100644
--- a/source/slang/slang-ir-insts.h
+++ b/source/slang/slang-ir-insts.h
@@ -570,6 +570,7 @@ struct IRInstanceDecoration : IRDecoration
IRIntLit* getCount() { return cast<IRIntLit>(getOperand(0)); }
};
+struct IRGlobalParam;
struct IRNumThreadsDecoration : IRDecoration
{
enum
@@ -578,11 +579,13 @@ struct IRNumThreadsDecoration : IRDecoration
};
IR_LEAF_ISA(NumThreadsDecoration)
- IRIntLit* getX() { return cast<IRIntLit>(getOperand(0)); }
- IRIntLit* getY() { return cast<IRIntLit>(getOperand(1)); }
- IRIntLit* getZ() { return cast<IRIntLit>(getOperand(2)); }
+ IRIntLit* getX() { return as<IRIntLit>(getOperand(0)); }
+ IRIntLit* getY() { return as<IRIntLit>(getOperand(1)); }
+ IRIntLit* getZ() { return as<IRIntLit>(getOperand(2)); }
- IRIntLit* getExtentAlongAxis(int axis) { return cast<IRIntLit>(getOperand(axis)); }
+ IRGlobalParam* getXSpecConst() { return as<IRGlobalParam>(getOperand(0)); }
+ IRGlobalParam* getYSpecConst() { return as<IRGlobalParam>(getOperand(1)); }
+ IRGlobalParam* getZSpecConst() { return as<IRGlobalParam>(getOperand(2)); }
};
struct IRWaveSizeDecoration : IRDecoration
diff --git a/source/slang/slang-ir-legalize-varying-params.cpp b/source/slang/slang-ir-legalize-varying-params.cpp
index 33f3944fd..5a18b533a 100644
--- a/source/slang/slang-ir-legalize-varying-params.cpp
+++ b/source/slang/slang-ir-legalize-varying-params.cpp
@@ -188,7 +188,7 @@ IRInst* emitCalcGroupExtents(IRBuilder& builder, IRFunc* entryPoint, IRVectorTyp
for (int axis = 0; axis < kAxisCount; axis++)
{
- auto litValue = as<IRIntLit>(numThreadsDecor->getExtentAlongAxis(axis));
+ auto litValue = as<IRIntLit>(numThreadsDecor->getOperand(axis));
if (!litValue)
return nullptr;
@@ -1432,6 +1432,20 @@ struct CPUEntryPointVaryingParamLegalizeContext : EntryPointVaryingParamLegalize
//
groupExtents = emitCalcGroupExtents(builder, m_entryPointFunc, uint3Type);
+ if (!groupExtents)
+ {
+ m_sink->diagnose(
+ m_entryPointFunc,
+ Diagnostics::unsupportedSpecializationConstantForNumThreads);
+
+ // Fill in placeholder values.
+ static const int kAxisCount = 3;
+ IRInst* groupExtentAlongAxis[kAxisCount] = {};
+ for (int axis = 0; axis < kAxisCount; axis++)
+ groupExtentAlongAxis[axis] = builder.getIntValue(uint3Type->getElementType(), 1);
+ groupExtents = builder.emitMakeVector(uint3Type, kAxisCount, groupExtentAlongAxis);
+ }
+
dispatchThreadID =
emitCalcDispatchThreadID(builder, uint3Type, groupID, groupThreadID, groupExtents);
diff --git a/source/slang/slang-ir-metal-legalize.cpp b/source/slang/slang-ir-metal-legalize.cpp
index 5bfa62e4a..3b47bd59e 100644
--- a/source/slang/slang-ir-metal-legalize.cpp
+++ b/source/slang/slang-ir-metal-legalize.cpp
@@ -1828,12 +1828,26 @@ struct LegalizeMetalEntryPointContext
IRBuilder svBuilder(builder.getModule());
svBuilder.setInsertBefore(entryPoint.entryPointFunc->getFirstOrdinaryInst());
- auto computeExtent = emitCalcGroupExtents(
- svBuilder,
- entryPoint.entryPointFunc,
- builder.getVectorType(
- builder.getUIntType(),
- builder.getIntValue(builder.getIntType(), 3)));
+ auto uint3Type = builder.getVectorType(
+ builder.getUIntType(),
+ builder.getIntValue(builder.getIntType(), 3));
+ auto computeExtent =
+ emitCalcGroupExtents(svBuilder, entryPoint.entryPointFunc, uint3Type);
+ if (!computeExtent)
+ {
+ m_sink->diagnose(
+ entryPoint.entryPointFunc,
+ Diagnostics::unsupportedSpecializationConstantForNumThreads);
+
+ // Fill in placeholder values.
+ static const int kAxisCount = 3;
+ IRInst* groupExtentAlongAxis[kAxisCount] = {};
+ for (int axis = 0; axis < kAxisCount; axis++)
+ groupExtentAlongAxis[axis] =
+ builder.getIntValue(uint3Type->getElementType(), 1);
+ computeExtent =
+ builder.emitMakeVector(uint3Type, kAxisCount, groupExtentAlongAxis);
+ }
auto groupIndexCalc = emitCalcGroupIndex(
svBuilder,
entryPointToGroupThreadId[entryPoint.entryPointFunc],
diff --git a/source/slang/slang-ir-translate-glsl-global-var.cpp b/source/slang/slang-ir-translate-glsl-global-var.cpp
index a44e16a7c..077cdb98d 100644
--- a/source/slang/slang-ir-translate-glsl-global-var.cpp
+++ b/source/slang/slang-ir-translate-glsl-global-var.cpp
@@ -282,10 +282,11 @@ struct GlobalVarTranslationContext
if (!numthreadsDecor)
return;
builder.setInsertBefore(use->getUser());
- IRInst* values[] = {
- numthreadsDecor->getExtentAlongAxis(0),
- numthreadsDecor->getExtentAlongAxis(1),
- numthreadsDecor->getExtentAlongAxis(2)};
+ IRInst* values[3] = {
+ numthreadsDecor->getOperand(0),
+ numthreadsDecor->getOperand(1),
+ numthreadsDecor->getOperand(2)};
+
auto workgroupSize = builder.emitMakeVector(
builder.getVectorType(builder.getIntType(), 3),
3,
@@ -328,10 +329,10 @@ struct GlobalVarTranslationContext
if (!firstBlock)
continue;
builder.setInsertBefore(firstBlock->getFirstOrdinaryInst());
- IRInst* args[] = {
- numthreadsDecor->getExtentAlongAxis(0),
- numthreadsDecor->getExtentAlongAxis(1),
- numthreadsDecor->getExtentAlongAxis(2)};
+ IRInst* args[3] = {
+ numthreadsDecor->getOperand(0),
+ numthreadsDecor->getOperand(1),
+ numthreadsDecor->getOperand(2)};
auto workgroupSize =
builder.emitMakeVector(workgroupSizeInst->getFullType(), 3, args);
builder.emitStore(globalVar, workgroupSize);
diff --git a/source/slang/slang-ir-util.cpp b/source/slang/slang-ir-util.cpp
index c753600a7..d05e1db7d 100644
--- a/source/slang/slang-ir-util.cpp
+++ b/source/slang/slang-ir-util.cpp
@@ -1973,4 +1973,17 @@ IRType* getIRVectorBaseType(IRType* type)
return as<IRVectorType>(type)->getElementType();
}
+Int getSpecializationConstantId(IRGlobalParam* param)
+{
+ auto layout = findVarLayout(param);
+ if (!layout)
+ return 0;
+
+ auto offset = layout->findOffsetAttr(LayoutResourceKind::SpecializationConstant);
+ if (!offset)
+ return 0;
+
+ return offset->getOffset();
+}
+
} // namespace Slang
diff --git a/source/slang/slang-ir-util.h b/source/slang/slang-ir-util.h
index e23aeb618..666ac71c0 100644
--- a/source/slang/slang-ir-util.h
+++ b/source/slang/slang-ir-util.h
@@ -373,6 +373,8 @@ inline bool isSPIRV(CodeGenTarget codeGenTarget)
int getIRVectorElementSize(IRType* type);
IRType* getIRVectorBaseType(IRType* type);
+Int getSpecializationConstantId(IRGlobalParam* param);
+
} // namespace Slang
#endif
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index e82fc03fd..086345719 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -7625,12 +7625,29 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
{
verifyComputeDerivativeGroupModifier = true;
getAllEntryPointsNoOverride(entryPoints);
+
+ LoweredValInfo extents[3];
+
+ for (int i = 0; i < 3; ++i)
+ {
+ extents[i] = layoutLocalSizeAttr->specConstExtents[i]
+ ? emitDeclRef(
+ context,
+ layoutLocalSizeAttr->specConstExtents[i],
+ lowerType(
+ context,
+ getType(
+ context->astBuilder,
+ layoutLocalSizeAttr->specConstExtents[i])))
+ : lowerVal(context, layoutLocalSizeAttr->extents[i]);
+ }
+
for (auto d : entryPoints)
as<IRNumThreadsDecoration>(getBuilder()->addNumThreadsDecoration(
d,
- getSimpleVal(context, lowerVal(context, layoutLocalSizeAttr->x)),
- getSimpleVal(context, lowerVal(context, layoutLocalSizeAttr->y)),
- getSimpleVal(context, lowerVal(context, layoutLocalSizeAttr->z))));
+ getSimpleVal(context, extents[0]),
+ getSimpleVal(context, extents[1]),
+ getSimpleVal(context, extents[2])));
}
else if (as<GLSLLayoutDerivativeGroupQuadAttribute>(modifier))
{
@@ -10336,11 +10353,28 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
}
else if (auto numThreadsAttr = as<NumThreadsAttribute>(modifier))
{
+ LoweredValInfo extents[3];
+
+ for (int i = 0; i < 3; ++i)
+ {
+ extents[i] = numThreadsAttr->specConstExtents[i]
+ ? emitDeclRef(
+ context,
+ numThreadsAttr->specConstExtents[i],
+ lowerType(
+ context,
+ getType(
+ context->astBuilder,
+ numThreadsAttr->specConstExtents[i])))
+ : lowerVal(context, numThreadsAttr->extents[i]);
+ }
+
numThreadsDecor = as<IRNumThreadsDecoration>(getBuilder()->addNumThreadsDecoration(
irFunc,
- getSimpleVal(context, lowerVal(context, numThreadsAttr->x)),
- getSimpleVal(context, lowerVal(context, numThreadsAttr->y)),
- getSimpleVal(context, lowerVal(context, numThreadsAttr->z))));
+ getSimpleVal(context, extents[0]),
+ getSimpleVal(context, extents[1]),
+ getSimpleVal(context, extents[2])));
+ numThreadsDecor->sourceLoc = numThreadsAttr->loc;
}
else if (auto waveSizeAttr = as<WaveSizeAttribute>(modifier))
{
diff --git a/source/slang/slang-parser.cpp b/source/slang/slang-parser.cpp
index c275a868b..6ae41a2eb 100644
--- a/source/slang/slang-parser.cpp
+++ b/source/slang/slang-parser.cpp
@@ -8437,7 +8437,9 @@ static NodeBase* parseLayoutModifier(Parser* parser, void* /*userData*/)
int localSizeIndex = -1;
if (nameText.startsWith(localSizePrefix) &&
- nameText.getLength() == SLANG_COUNT_OF(localSizePrefix) - 1 + 1)
+ (nameText.getLength() == SLANG_COUNT_OF(localSizePrefix) - 1 + 1 ||
+ (nameText.endsWith("_id") &&
+ (nameText.getLength() == SLANG_COUNT_OF(localSizePrefix) - 1 + 4))))
{
char lastChar = nameText[SLANG_COUNT_OF(localSizePrefix) - 1];
localSizeIndex = (lastChar >= 'x' && lastChar <= 'z') ? (lastChar - 'x') : -1;
@@ -8451,6 +8453,8 @@ static NodeBase* parseLayoutModifier(Parser* parser, void* /*userData*/)
numThreadsAttrib->args.setCount(3);
for (auto& i : numThreadsAttrib->args)
i = nullptr;
+ for (auto& b : numThreadsAttrib->axisIsSpecConstId)
+ b = false;
// Just mark the loc and name from the first in the list
numThreadsAttrib->keywordName = getName(parser, "numthreads");
@@ -8467,6 +8471,11 @@ static NodeBase* parseLayoutModifier(Parser* parser, void* /*userData*/)
}
numThreadsAttrib->args[localSizeIndex] = expr;
+
+ // We can't resolve the specialization constant declaration
+ // here, because it may not even exist. IDs pointing to unnamed
+ // specialization constants are allowed in GLSL.
+ numThreadsAttrib->axisIsSpecConstId[localSizeIndex] = nameText.endsWith("_id");
}
}
else if (nameText == "derivative_group_quadsNV")
diff --git a/source/slang/slang-reflection-api.cpp b/source/slang/slang-reflection-api.cpp
index d235c8270..d1adfedc0 100644
--- a/source/slang/slang-reflection-api.cpp
+++ b/source/slang/slang-reflection-api.cpp
@@ -4033,18 +4033,14 @@ SLANG_API void spReflectionEntryPoint_getComputeThreadGroupSize(
auto numThreadsAttribute = entryPointFunc.getDecl()->findModifier<NumThreadsAttribute>();
if (numThreadsAttribute)
{
- if (auto cint = entryPointLayout->program->tryFoldIntVal(numThreadsAttribute->x))
- sizeAlongAxis[0] = (SlangUInt)cint->getValue();
- else if (numThreadsAttribute->x)
- sizeAlongAxis[0] = 0;
- if (auto cint = entryPointLayout->program->tryFoldIntVal(numThreadsAttribute->y))
- sizeAlongAxis[1] = (SlangUInt)cint->getValue();
- else if (numThreadsAttribute->y)
- sizeAlongAxis[1] = 0;
- if (auto cint = entryPointLayout->program->tryFoldIntVal(numThreadsAttribute->z))
- sizeAlongAxis[2] = (SlangUInt)cint->getValue();
- else if (numThreadsAttribute->z)
- sizeAlongAxis[2] = 0;
+ for (int i = 0; i < 3; ++i)
+ {
+ if (auto cint =
+ entryPointLayout->program->tryFoldIntVal(numThreadsAttribute->extents[i]))
+ sizeAlongAxis[i] = (SlangUInt)cint->getValue();
+ else if (numThreadsAttribute->extents[i])
+ sizeAlongAxis[i] = 0;
+ }
}
//