summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--source/slang/glsl.meta.slang23
-rw-r--r--source/slang/hlsl.meta.slang5
-rw-r--r--source/slang/slang-check-modifier.cpp7
-rw-r--r--source/slang/slang-ir-inst-defs.h4
-rw-r--r--source/slang/slang-ir-translate-glsl-global-var.cpp83
-rw-r--r--tests/spirv/subgroup-size-2.slang33
-rw-r--r--tests/spirv/subgroup-size.slang16
7 files changed, 156 insertions, 15 deletions
diff --git a/source/slang/glsl.meta.slang b/source/slang/glsl.meta.slang
index e915442c7..54f62308f 100644
--- a/source/slang/glsl.meta.slang
+++ b/source/slang/glsl.meta.slang
@@ -124,20 +124,13 @@ public property uint3 gl_NumWorkGroups {
}
}
-[require(glsl)]
-[require(spirv)]
-public property uint3 gl_WorkGroupSize {
-
- get {
- __target_switch
- {
- case glsl:
- __intrinsic_asm "(gl_WorkGroupSize)";
- case spirv:
- return spirv_asm {
- result:$$uint3 = OpLoad builtin(WorkgroupSize:uint3);
- };
- }
+[require(compute)]
+public property uint3 gl_WorkGroupSize
+{
+ [__unsafeForceInlineEarly]
+ get
+ {
+ return WorkgroupSize();
}
}
@@ -7902,4 +7895,4 @@ __spirv_version(1.3)
public bool allInvocationsEqual(bool value)
{
return WaveActiveAllEqual(value);
-} \ No newline at end of file
+}
diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang
index fa2a7ccb1..ef5c3ae5d 100644
--- a/source/slang/hlsl.meta.slang
+++ b/source/slang/hlsl.meta.slang
@@ -3483,6 +3483,11 @@ void AllMemoryBarrierWithGroupSync()
}
}
+// Returns the workgroup size of the calling entry point.
+[require(compute)]
+__intrinsic_op($(kIROp_GetWorkGroupSize))
+int3 WorkgroupSize();
+
// Test if any components is non-zero (HLSL SM 1.0)
__generic<T : __BuiltinType>
diff --git a/source/slang/slang-check-modifier.cpp b/source/slang/slang-check-modifier.cpp
index af47eaedb..6d39f977c 100644
--- a/source/slang/slang-check-modifier.cpp
+++ b/source/slang/slang-check-modifier.cpp
@@ -348,8 +348,15 @@ namespace Slang
getSink()->diagnose(attr, Diagnostics::nonPositiveNumThreads, constIntVal->getValue());
return false;
}
+ if (intValue->getType() != m_astBuilder->getIntType())
+ {
+ intValue = m_astBuilder->getIntVal(m_astBuilder->getIntType(), constIntVal->getValue());
+ }
}
+ // Make sure we always canonicalize the type to int.
value = intValue;
+ if (value->getType() != m_astBuilder->getIntType())
+ value = m_astBuilder->getTypeCastIntVal(m_astBuilder->getIntType(), value);
}
else
{
diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h
index 25f331708..5acb22674 100644
--- a/source/slang/slang-ir-inst-defs.h
+++ b/source/slang/slang-ir-inst-defs.h
@@ -359,6 +359,10 @@ INST(UnpackAnyValue, unpackAnyValue, 1, 0)
INST(WitnessTableEntry, witness_table_entry, 2, 0)
INST(InterfaceRequirementEntry, interface_req_entry, 2, GLOBAL)
+// An inst to represent the workgroup size of the calling entry point.
+// We will materialize this inst during `translateGLSLGlobalVar`.
+INST(GetWorkGroupSize, kIROp_GetWorkGroupSize, 0, HOISTABLE)
+
INST(Param, param, 0, 0)
INST(StructField, field, 2, 0)
INST(Var, var, 0, 0)
diff --git a/source/slang/slang-ir-translate-glsl-global-var.cpp b/source/slang/slang-ir-translate-glsl-global-var.cpp
index 57bd71418..575a66457 100644
--- a/source/slang/slang-ir-translate-glsl-global-var.cpp
+++ b/source/slang/slang-ir-translate-glsl-global-var.cpp
@@ -18,10 +18,15 @@ namespace Slang
buildEntryPointReferenceGraph(referencingEntryPoints, module);
List<IRInst*> entryPoints;
+ // Traverse the module to find all entry points.
+ // If we see a `GetWorkGroupSize` instruction, we will materialize it.
+ //
for (auto inst : module->getGlobalInsts())
{
if (inst->getOp() == kIROp_Func && inst->findDecoration<IREntryPointDecoration>())
entryPoints.add(inst);
+ else if (inst->getOp() == kIROp_GetWorkGroupSize)
+ materializeGetWorkGroupSize(module, referencingEntryPoints, inst);
}
IRBuilder builder(module);
@@ -223,6 +228,84 @@ namespace Slang
entryPointFunc->setFullType(newFuncType);
}
}
+
+ // Materializes a `GetWorkGroupSize` instruction by replacing each of its uses with a vector built from the
+ // `numthreads` extents of the entry point that references the using function.
+ // This is trivial if the `GetWorkGroupSize` instruction is used from a function referenced by exactly one entry point.
+ // If a use cannot be resolved that way (e.g. it is referenced by multiple entry points), we introduce a global variable
+ // to represent the workgroup size, and replace the remaining uses with a load from that global variable.
+ // We then store the workgroup size into the global variable at the start of each function that has a `numthreads` decoration.
+ //
+ void materializeGetWorkGroupSize(IRModule* module, Dictionary<IRInst*, HashSet<IRFunc*>>& referenceGraph, IRInst* workgroupSizeInst)
+ {
+ IRBuilder builder(workgroupSizeInst);
+ traverseUses(workgroupSizeInst, [&](IRUse* use)
+ {
+ if (auto parentFunc = getParentFunc(use->getUser()))
+ {
+ auto referenceSet = referenceGraph.tryGetValue(parentFunc);
+ if (!referenceSet)
+ return;
+ if (referenceSet->getCount() == 1)
+ {
+ // If the function that uses the workgroup size is referenced by exactly one entry point,
+ // we can materialize the workgroup size by substituting the use with that entry point's `numthreads` extents.
+ auto entryPoint = *referenceSet->begin();
+ auto numthreadsDecor = entryPoint->findDecoration<IRNumThreadsDecoration>();
+ if (!numthreadsDecor)
+ return;
+ builder.setInsertBefore(use->getUser());
+ IRInst* values[] = {
+ numthreadsDecor->getExtentAlongAxis(0),
+ numthreadsDecor->getExtentAlongAxis(1),
+ numthreadsDecor->getExtentAlongAxis(2) };
+ auto workgroupSize = builder.emitMakeVector(builder.getVectorType(builder.getIntType(), 3),
+ 3, values);
+ builder.replaceOperand(use, workgroupSize);
+ }
+ }
+ });
+
+ // If workgroupSizeInst still has uses, they were not resolved above (e.g. the user is referenced by multiple entry points).
+ // We need to introduce a global variable and assign a value to it in each entry point.
+
+ if (!workgroupSizeInst->hasUses())
+ return;
+ builder.setInsertBefore(workgroupSizeInst);
+ auto globalVar = builder.createGlobalVar(workgroupSizeInst->getFullType());
+
+ // Replace all remaining uses of the workgroupSize inst with a load from globalVar.
+ traverseUses(workgroupSizeInst, [&](IRUse* use)
+ {
+ builder.setInsertBefore(use->getUser());
+ auto load = builder.emitLoad(globalVar);
+ builder.replaceOperand(use, load);
+ });
+
+ // Now insert a store to globalVar at the start of every function that carries a `numthreads` decoration.
+ for (auto globalInst : module->getGlobalInsts())
+ {
+ auto func = as<IRFunc>(getResolvedInstForDecorations(globalInst));
+ if (!func)
+ continue;
+ if (auto numthreadsDecor = func->findDecoration<IRNumThreadsDecoration>())
+ {
+ auto firstBlock = func->getFirstBlock();
+ if (!firstBlock)
+ continue;
+ builder.setInsertBefore(firstBlock->getFirstOrdinaryInst());
+ IRInst* args[] = {
+ numthreadsDecor->getExtentAlongAxis(0),
+ numthreadsDecor->getExtentAlongAxis(1),
+ numthreadsDecor->getExtentAlongAxis(2) };
+ auto workgroupSize = builder.emitMakeVector(
+ workgroupSizeInst->getFullType(), 3, args);
+ builder.emitStore(globalVar, workgroupSize);
+ }
+ }
+
+ workgroupSizeInst->removeAndDeallocate();
+ }
};
void translateGLSLGlobalVar(CodeGenContext* context, IRModule* module)
diff --git a/tests/spirv/subgroup-size-2.slang b/tests/spirv/subgroup-size-2.slang
new file mode 100644
index 000000000..68fee6fe6
--- /dev/null
+++ b/tests/spirv/subgroup-size-2.slang
@@ -0,0 +1,33 @@
+// Test that using the workgroup size from more than one entry point results in
+// correct lowering into a global variable.
+
+//TEST:SIMPLE(filecheck=CHECK): -target spirv -emit-spirv-directly -fvk-use-entrypoint-name -O0
+
+RWStructuredBuffer<int> outputBuffer;
+
+uint3 f() { return WorkgroupSize(); }
+
+[shader("compute")]
+[numthreads(1u, 2u, 3)]
+void compute1()
+{
+ // CHECK-DAG: %[[VAR:[A-Za-z0-9_]+]] = OpVariable %_ptr_Private_v3int Private
+ // CHECK: OpStore %[[VAR]]
+
+ // CHECK-DAG: %[[CALL_RS:[A-Za-z0-9_]+]] = OpFunctionCall %v3uint %f
+ // CHECK: OpCompositeExtract %uint %[[CALL_RS]] 0
+ const int x = f().x;
+ outputBuffer[0] = x;
+
+ // CHECK-DAG: %[[PTR:[A-Za-z0-9_]+]] = OpAccessChain %_ptr_StorageBuffer_int %outputBuffer %int_0 %uint_1
+ // CHECK: OpStore %[[PTR]] %int_2
+ outputBuffer[1] = WorkgroupSize().y;
+}
+
+[shader("compute")]
+[numthreads(4, 5, 6)]
+void compute2()
+{
+ const int x = f().x;
+ outputBuffer[0] = x;
+} \ No newline at end of file
diff --git a/tests/spirv/subgroup-size.slang b/tests/spirv/subgroup-size.slang
new file mode 100644
index 000000000..c2ed4a3d8
--- /dev/null
+++ b/tests/spirv/subgroup-size.slang
@@ -0,0 +1,16 @@
+import "glsl";
+
+//TEST:SIMPLE(filecheck=CHECK): -target spirv -emit-spirv-directly -O0
+
+// CHECK-DAG: %[[CONST:[A-Za-z0-9_]+]] = OpConstantComposite %v3int %int_1 %int_2 %int_3
+// CHECK: OpBitcast %v3uint %[[CONST]]
+
+RWStructuredBuffer<int> outputBuffer;
+
+[shader("compute")]
+[numthreads(1u, 2u, 3)]
+void compute()
+{
+ const int x = gl_WorkGroupSize.x;
+ outputBuffer[0] = x;
+} \ No newline at end of file