diff options
| -rw-r--r-- | source/slang/glsl.meta.slang | 23 | ||||
| -rw-r--r-- | source/slang/hlsl.meta.slang | 5 | ||||
| -rw-r--r-- | source/slang/slang-check-modifier.cpp | 7 | ||||
| -rw-r--r-- | source/slang/slang-ir-inst-defs.h | 4 | ||||
| -rw-r--r-- | source/slang/slang-ir-translate-glsl-global-var.cpp | 83 | ||||
| -rw-r--r-- | tests/spirv/subgroup-size-2.slang | 33 | ||||
| -rw-r--r-- | tests/spirv/subgroup-size.slang | 16 |
7 files changed, 156 insertions, 15 deletions
diff --git a/source/slang/glsl.meta.slang b/source/slang/glsl.meta.slang index e915442c7..54f62308f 100644 --- a/source/slang/glsl.meta.slang +++ b/source/slang/glsl.meta.slang @@ -124,20 +124,13 @@ public property uint3 gl_NumWorkGroups { } } -[require(glsl)] -[require(spirv)] -public property uint3 gl_WorkGroupSize { - - get { - __target_switch - { - case glsl: - __intrinsic_asm "(gl_WorkGroupSize)"; - case spirv: - return spirv_asm { - result:$$uint3 = OpLoad builtin(WorkgroupSize:uint3); - }; - } +[require(compute)] +public property uint3 gl_WorkGroupSize +{ + [__unsafeForceInlineEarly] + get + { + return WorkgroupSize(); } } @@ -7902,4 +7895,4 @@ __spirv_version(1.3) public bool allInvocationsEqual(bool value) { return WaveActiveAllEqual(value); -}
\ No newline at end of file +} diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang index fa2a7ccb1..ef5c3ae5d 100644 --- a/source/slang/hlsl.meta.slang +++ b/source/slang/hlsl.meta.slang @@ -3483,6 +3483,11 @@ void AllMemoryBarrierWithGroupSync() } } +// Returns the workgroup size of the calling entry point. +[require(compute)] +__intrinsic_op($(kIROp_GetWorkGroupSize)) +int3 WorkgroupSize(); + // Test if any components is non-zero (HLSL SM 1.0) __generic<T : __BuiltinType> diff --git a/source/slang/slang-check-modifier.cpp b/source/slang/slang-check-modifier.cpp index af47eaedb..6d39f977c 100644 --- a/source/slang/slang-check-modifier.cpp +++ b/source/slang/slang-check-modifier.cpp @@ -348,8 +348,15 @@ namespace Slang getSink()->diagnose(attr, Diagnostics::nonPositiveNumThreads, constIntVal->getValue()); return false; } + if (intValue->getType() != m_astBuilder->getIntType()) + { + intValue = m_astBuilder->getIntVal(m_astBuilder->getIntType(), constIntVal->getValue()); + } } + // Make sure we always canonicalize the type to int. value = intValue; + if (value->getType() != m_astBuilder->getIntType()) + value = m_astBuilder->getTypeCastIntVal(m_astBuilder->getIntType(), value); } else { diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h index 25f331708..5acb22674 100644 --- a/source/slang/slang-ir-inst-defs.h +++ b/source/slang/slang-ir-inst-defs.h @@ -359,6 +359,10 @@ INST(UnpackAnyValue, unpackAnyValue, 1, 0) INST(WitnessTableEntry, witness_table_entry, 2, 0) INST(InterfaceRequirementEntry, interface_req_entry, 2, GLOBAL) +// An inst to represent the workgroup size of the calling entry point. +// We will materialize this inst during `translateGLSLGlobalVar`. +INST(GetWorkGroupSize, kIROp_GetWorkGroupSize, 0, HOISTABLE) + INST(Param, param, 0, 0) INST(StructField, field, 2, 0) INST(Var, var, 0, 0) diff --git a/source/slang/slang-ir-translate-glsl-global-var.cpp b/source/slang/slang-ir-translate-glsl-global-var.cpp index 57bd71418..575a66457 100644 --- a/source/slang/slang-ir-translate-glsl-global-var.cpp +++ b/source/slang/slang-ir-translate-glsl-global-var.cpp @@ -18,10 +18,15 @@ namespace Slang buildEntryPointReferenceGraph(referencingEntryPoints, module); List<IRInst*> entryPoints; + // Traverse the module to find all entry points. + // If we see a `GetWorkGroupSize` instruction, we will materialize it. + // for (auto inst : module->getGlobalInsts()) { if (inst->getOp() == kIROp_Func && inst->findDecoration<IREntryPointDecoration>()) entryPoints.add(inst); + else if (inst->getOp() == kIROp_GetWorkGroupSize) + materializeGetWorkGroupSize(module, referencingEntryPoints, inst); } IRBuilder builder(module); @@ -223,6 +228,84 @@ namespace Slang entryPointFunc->setFullType(newFuncType); } } + + // If we see a `GetWorkGroupSize` instruction, we should materialize it by replacing its uses with a constant + // that represent the workgroup size of the calling entrypoint. + // This is trivial if the `GetWorkGroupSize` instruction is used from a function called by one entry point. + // If it is used in a place reachable from multiple entry points, we will introduce a global variable to represent + // the workgroup size, and replace the uses with a load from the global variable. + // We will assign the value of the global variable at the start of each entry point. + // + void materializeGetWorkGroupSize(IRModule* module, Dictionary<IRInst*, HashSet<IRFunc*>>& referenceGraph, IRInst* workgroupSizeInst) + { + IRBuilder builder(workgroupSizeInst); + traverseUses(workgroupSizeInst, [&](IRUse* use) + { + if (auto parentFunc = getParentFunc(use->getUser())) + { + auto referenceSet = referenceGraph.tryGetValue(parentFunc); + if (!referenceSet) + return; + if (referenceSet->getCount() == 1) + { + // If the function that uses the workgroup size is only used by one entry point, + // we can materialize the workgroup size by substituting the use with a constant. + auto entryPoint = *referenceSet->begin(); + auto numthreadsDecor = entryPoint->findDecoration<IRNumThreadsDecoration>(); + if (!numthreadsDecor) + return; + builder.setInsertBefore(use->getUser()); + IRInst* values[] = { + numthreadsDecor->getExtentAlongAxis(0), + numthreadsDecor->getExtentAlongAxis(1), + numthreadsDecor->getExtentAlongAxis(2) }; + auto workgroupSize = builder.emitMakeVector(builder.getVectorType(builder.getIntType(), 3), + 3, values); + builder.replaceOperand(use, workgroupSize); + } + } + }); + + // If workgroupSizeInst still has uses, it means it is used by multiple entry points. + // We need to introduce a global variable and assign value to it in each entry point. + + if (!workgroupSizeInst->hasUses()) + return; + builder.setInsertBefore(workgroupSizeInst); + auto globalVar = builder.createGlobalVar(workgroupSizeInst->getFullType()); + + // Replace all remaining uses of the workgroupSize inst of a load from globalVar. + traverseUses(workgroupSizeInst, [&](IRUse* use) + { + builder.setInsertBefore(use->getUser()); + auto load = builder.emitLoad(globalVar); + builder.replaceOperand(use, load); + }); + + // Now insert assignments from each entry point. + for (auto globalInst : module->getGlobalInsts()) + { + auto func = as<IRFunc>(getResolvedInstForDecorations(globalInst)); + if (!func) + continue; + if (auto numthreadsDecor = func->findDecoration<IRNumThreadsDecoration>()) + { + auto firstBlock = func->getFirstBlock(); + if (!firstBlock) + continue; + builder.setInsertBefore(firstBlock->getFirstOrdinaryInst()); + IRInst* args[] = { + numthreadsDecor->getExtentAlongAxis(0), + numthreadsDecor->getExtentAlongAxis(1), + numthreadsDecor->getExtentAlongAxis(2) }; + auto workgroupSize = builder.emitMakeVector( + workgroupSizeInst->getFullType(), 3, args); + builder.emitStore(globalVar, workgroupSize); + } + } + + workgroupSizeInst->removeAndDeallocate(); + } }; void translateGLSLGlobalVar(CodeGenContext* context, IRModule* module) diff --git a/tests/spirv/subgroup-size-2.slang b/tests/spirv/subgroup-size-2.slang new file mode 100644 index 000000000..68fee6fe6 --- /dev/null +++ b/tests/spirv/subgroup-size-2.slang @@ -0,0 +1,33 @@ +// Test that using workgroup size from more than one entrypoint result in +// correct lowering into global variable. + +//TEST:SIMPLE(filecheck=CHECK): -target spirv -emit-spirv-directly -fvk-use-entrypoint-name -O0 + +RWStructuredBuffer<int> outputBuffer; + +uint3 f() { return WorkgroupSize(); } + +[shader("compute")] +[numthreads(1u, 2u, 3)] +void compute1() +{ + // CHECK-DAG: %[[VAR:[A-Za-z0-9_]+]] = OpVariable %_ptr_Private_v3int Private + // CHECK: OpStore %[[VAR]] + + // CHECK-DAG: %[[CALL_RS:[A-Za-z0-9_]+]] = OpFunctionCall %v3uint %f + // CHECK: OpCompositeExtract %uint %[[CALL_RS]] 0 + const int x = f().x; + outputBuffer[0] = x; + + // CHECK-DAG: %[[PTR:[A-Za-z0-9_]+]] = OpAccessChain %_ptr_StorageBuffer_int %outputBuffer %int_0 %uint_1 + // CHECK: OpStore %[[PTR]] %int_2 + outputBuffer[1] = WorkgroupSize().y; +} + +[shader("compute")] +[numthreads(4, 5, 6)] +void compute2() +{ + const int x = f().x; + outputBuffer[0] = x; +}
\ No newline at end of file diff --git a/tests/spirv/subgroup-size.slang b/tests/spirv/subgroup-size.slang new file mode 100644 index 000000000..c2ed4a3d8 --- /dev/null +++ b/tests/spirv/subgroup-size.slang @@ -0,0 +1,16 @@ +import "glsl"; + +//TEST:SIMPLE(filecheck=CHECK): -target spirv -emit-spirv-directly -O0 + +// CHECK-DAG: %[[CONST:[A-Za-z0-9_]+]] = OpConstantComposite %v3int %int_1 %int_2 %int_3 +// CHECK: OpBitcast %v3uint %[[CONST]] + +RWStructuredBuffer<int> outputBuffer; + +[shader("compute")] +[numthreads(1u, 2u, 3)] +void compute() +{ + const int x = gl_WorkGroupSize.x; + outputBuffer[0] = x; +}
\ No newline at end of file |
