From 66984eb856454d0a372e3b30643823af18612067 Mon Sep 17 00:00:00 2001 From: Darren Wihandi <65404740+fairywreath@users.noreply.github.com> Date: Fri, 28 Feb 2025 13:11:26 -0500 Subject: Add WaveGetLane* support for Metal and WGSL (#6371) * support WaveGetLane* for WGSL and Metal * update test and glsl support * address review comments and fix metal test * add missing pragma guard * update test * Revert "update test" This reverts commit f2b97e91c29de154190710580c343bd0764aedbb. * update failing glsl metal test and added new test * make hlsl and glsl outputs similar * update test * disable tests for Metal and cleanup * comment fix * add expected failures * correct expected failures list * remove expected failure * add tests to expected failure --------- Co-authored-by: Yong He --- source/slang/core.meta.slang | 4 +- source/slang/glsl.meta.slang | 56 +-- source/slang/hlsl.meta.slang | 88 +++-- source/slang/slang-core-module-textures.cpp | 2 +- source/slang/slang-emit-c-like.cpp | 9 +- source/slang/slang-emit-c-like.h | 2 + source/slang/slang-emit-glsl.cpp | 2 +- source/slang/slang-emit-wgsl.cpp | 5 + source/slang/slang-emit-wgsl.h | 2 + source/slang/slang-emit.cpp | 10 +- source/slang/slang-ir-call-graph.h | 3 + source/slang/slang-ir-inst-defs.h | 4 +- source/slang/slang-ir-insts.h | 4 +- source/slang/slang-ir-legalize-varying-params.cpp | 29 +- source/slang/slang-ir-legalize-varying-params.h | 2 + .../slang-ir-translate-global-varying-var.cpp | 382 +++++++++++++++++++++ .../slang/slang-ir-translate-global-varying-var.h | 14 + .../slang/slang-ir-translate-glsl-global-var.cpp | 382 --------------------- source/slang/slang-ir-translate-glsl-global-var.h | 17 - 19 files changed, 538 insertions(+), 479 deletions(-) create mode 100644 source/slang/slang-ir-translate-global-varying-var.cpp create mode 100644 source/slang/slang-ir-translate-global-varying-var.h delete mode 100644 source/slang/slang-ir-translate-glsl-global-var.cpp delete mode 100644 source/slang/slang-ir-translate-glsl-global-var.h (limited to 'source') diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang index da1b47e13..e2fb8bbf2 100644 --- a/source/slang/core.meta.slang +++ b/source/slang/core.meta.slang @@ -440,8 +440,8 @@ attribute_syntax [Differentiable(order:int = 0)] : BackwardDifferentiableAttribu __intrinsic_op($(kIROp_RequirePrelude)) void __requirePrelude(constexpr String preludeText); -__intrinsic_op($(kIROp_RequireGLSLExtension)) -void __requireGLSLExtension(constexpr String preludeText); +__intrinsic_op($(kIROp_RequireTargetExtension)) +void __requireTargetExtension(constexpr String preludeText); /// @experimetal /// Perform a compile-time condition check and emit a compile-time error if the condition is false. diff --git a/source/slang/glsl.meta.slang b/source/slang/glsl.meta.slang index eed6cc690..2a89f2b66 100644 --- a/source/slang/glsl.meta.slang +++ b/source/slang/glsl.meta.slang @@ -4296,7 +4296,7 @@ __generic case glsl: { if (__type_equals()) - __requireGLSLExtension("GL_EXT_shader_atomic_float"); + __requireTargetExtension("GL_EXT_shader_atomic_float"); } case spirv: if (__type_equals()) @@ -4318,7 +4318,7 @@ __generic case glsl: { if (__type_equals()) - __requireGLSLExtension("GL_EXT_shader_atomic_float2"); + __requireTargetExtension("GL_EXT_shader_atomic_float2"); } case spirv: if (__type_equals()) @@ -4758,7 +4758,7 @@ void requireGLSLExtForRayTracingBuiltin() __target_switch { case glsl: - __requireGLSLExtension("GL_EXT_ray_tracing"); + __requireTargetExtension("GL_EXT_ray_tracing"); __intrinsic_asm ""; default: return; @@ -6304,22 +6304,22 @@ public void traceRayMotionNV( __generic [ForceInline] void typeRequireChecks_shader_subgroup_GLSL() { - // the following is a seperate function call, since else the `__requireGLSLExtension` and associated __intrinsic_asm is ignored if the calling function also calls an __intrinsic_asm + // the following is a seperate function call, since else the `__requireTargetExtension` and associated __intrinsic_asm is ignored if the calling function also calls an __intrinsic_asm __target_switch { case glsl: if (__type_equals() || __type_equals() - ) __requireGLSLExtension("GL_EXT_shader_subgroup_extended_types_float16"); + ) __requireTargetExtension("GL_EXT_shader_subgroup_extended_types_float16"); else if (__type_equals() || __type_equals() - ) __requireGLSLExtension("GL_EXT_shader_subgroup_extended_types_int8"); + ) __requireTargetExtension("GL_EXT_shader_subgroup_extended_types_int8"); else if (__type_equals() || __type_equals() - ) __requireGLSLExtension("GL_EXT_shader_subgroup_extended_types_int16"); + ) __requireTargetExtension("GL_EXT_shader_subgroup_extended_types_int16"); else if (__type_equals() || __type_equals() - ) __requireGLSLExtension("GL_EXT_shader_subgroup_extended_types_int64"); + ) __requireTargetExtension("GL_EXT_shader_subgroup_extended_types_int64"); __intrinsic_asm ""; } @@ -6327,7 +6327,7 @@ void typeRequireChecks_shader_subgroup_GLSL() { __generic void shader_subgroup_preamble() { - // checks needed for shader_subgroup functions; __requireGLSLExtension does not work + // checks needed for shader_subgroup functions; __requireTargetExtension does not work // (does not add the ext specified correctly to the compile output; using extended type // will result in error for using the type) __target_switch @@ -6347,14 +6347,14 @@ void requireGLSLExtForSubgroupBasicBuiltin() { __target_switch { case glsl: - __requireGLSLExtension("GL_KHR_shader_subgroup_basic"); + __requireTargetExtension("GL_KHR_shader_subgroup_basic"); __intrinsic_asm ""; default: return; } } -[require(cpp_cuda_glsl_hlsl_spirv_wgsl, subgroup_basic)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, subgroup_basic)] void setupExtForSubgroupBasicBuiltIn() { __target_switch { @@ -6371,7 +6371,7 @@ void requireGLSLExtForSubgroupBallotBuiltin() { __target_switch { case glsl: - __requireGLSLExtension("GL_KHR_shader_subgroup_ballot"); + __requireTargetExtension("GL_KHR_shader_subgroup_ballot"); __intrinsic_asm ""; default: return; @@ -6429,7 +6429,8 @@ public property uint gl_SubgroupID public property uint gl_SubgroupSize { - [require(cpp_cuda_glsl_hlsl_spirv_wgsl, subgroup_basic)] + [ForceInline] + [require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, subgroup_basic)] get { setupExtForSubgroupBasicBuiltIn(); return WaveGetLaneCount(); @@ -6438,7 +6439,8 @@ public property uint gl_SubgroupSize public property uint gl_SubgroupInvocationID { - [require(cpp_cuda_glsl_hlsl_spirv_wgsl, subgroup_basic)] + [ForceInline] + [require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, subgroup_basic)] get { setupExtForSubgroupBasicBuiltIn(); return WaveGetLaneIndex(); @@ -8388,8 +8390,8 @@ void typeRequireChecks_atomic_using_float0_tier() { case glsl: { - if (__type_equals() || __type_equals()) - __requireGLSLExtension("GL_EXT_shader_atomic_int64"); + if (__type_equals() || __type_equals()) + __requireTargetExtension("GL_EXT_shader_atomic_int64"); } case spirv: return; @@ -8405,16 +8407,16 @@ void typeRequireChecks_atomic_using_float1_tier() case glsl: { if (__type_equals()) - __requireGLSLExtension("GL_EXT_shader_atomic_float"); + __requireTargetExtension("GL_EXT_shader_atomic_float"); else if (__type_equals() || __type_equals()) { - __requireGLSLExtension("GL_EXT_shader_atomic_float2"); - __requireGLSLExtension("GL_EXT_shader_explicit_arithmetic_types"); + __requireTargetExtension("GL_EXT_shader_atomic_float2"); + __requireTargetExtension("GL_EXT_shader_explicit_arithmetic_types"); } else if (__type_equals()) - __requireGLSLExtension("GL_EXT_shader_atomic_float"); + __requireTargetExtension("GL_EXT_shader_atomic_float"); else if (__type_equals() || __type_equals()) - __requireGLSLExtension("GL_EXT_shader_atomic_int64"); + __requireTargetExtension("GL_EXT_shader_atomic_int64"); } case spirv: return; @@ -8430,16 +8432,16 @@ void typeRequireChecks_atomic_using_float2_tier() case glsl: { if (__type_equals()) - __requireGLSLExtension("GL_EXT_shader_atomic_float2"); + __requireTargetExtension("GL_EXT_shader_atomic_float2"); else if (__type_equals() || __type_equals()) { - __requireGLSLExtension("GL_EXT_shader_atomic_float2"); - __requireGLSLExtension("GL_EXT_shader_explicit_arithmetic_types"); + __requireTargetExtension("GL_EXT_shader_atomic_float2"); + __requireTargetExtension("GL_EXT_shader_explicit_arithmetic_types"); } else if (__type_equals()) - __requireGLSLExtension("GL_EXT_shader_atomic_float2"); - else if (__type_equals() || __type_equals()) - __requireGLSLExtension("GL_EXT_shader_atomic_int64"); + __requireTargetExtension("GL_EXT_shader_atomic_float2"); + else if (__type_equals() || __type_equals()) + __requireTargetExtension("GL_EXT_shader_atomic_int64"); } case spirv: return; diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang index a2b685b69..c9f3fb533 100644 --- a/source/slang/hlsl.meta.slang +++ b/source/slang/hlsl.meta.slang @@ -3,8 +3,14 @@ typedef uint UINT; -__intrinsic_op($(kIROp_RequireGLSLExtension)) -void __requireGLSLExtension(String extensionName); +__intrinsic_op($(kIROp_RequireTargetExtension)) +void __requireTargetExtension(constexpr String extensionName); + +/// Built-in values or system value semantics represented as in/out global variables. +/// This allows the built-ins to be arbitrarily used from a global scope without being +/// explicitly passed as entry point parameters. +in uint __builtinWaveLaneIndex : SV_WaveLaneIndex; +in uint __builtinWaveLaneCount : SV_WaveLaneCount; //@public: /// Represents an interface for buffer data layout. @@ -3505,7 +3511,7 @@ extension _Texture __intrinsic_asm ""; case glsl: if (isCombined == 0) - __requireGLSLExtension("GL_EXT_samplerless_texture_functions"); + __requireTargetExtension("GL_EXT_samplerless_texture_functions"); __intrinsic_asm "$ctexelFetch($0, ($1).$w1b, ($1).$w1e)$z"; case spirv: const int lodLoc = Shape.dimensions+isArray; @@ -3569,7 +3575,7 @@ extension _Texture __intrinsic_asm ".Load"; case glsl: if (isCombined == 0) - __requireGLSLExtension("GL_EXT_samplerless_texture_functions"); + __requireTargetExtension("GL_EXT_samplerless_texture_functions"); __intrinsic_asm "$ctexelFetchOffset($0, ($1).$w1b, ($1).$w1e, ($2))$z"; case spirv: const int lodLoc = Shape.dimensions+isArray; @@ -3625,7 +3631,7 @@ extension _Texture return Load(__makeVector(location, 0)); case glsl: if (isCombined == 0) - __requireGLSLExtension("GL_EXT_samplerless_texture_functions"); + __requireTargetExtension("GL_EXT_samplerless_texture_functions"); return Load(__makeVector(location, 0)); case spirv: @@ -3702,7 +3708,7 @@ extension _Texture __intrinsic_asm ""; case glsl: if (isCombined == 0) - __requireGLSLExtension("GL_EXT_samplerless_texture_functions"); + __requireTargetExtension("GL_EXT_samplerless_texture_functions"); __intrinsic_asm "$ctexelFetch($0, $1, ($2))$z"; case spirv: if (isCombined != 0) @@ -3752,7 +3758,7 @@ extension _Texture __intrinsic_asm ".Load"; case glsl: if (isCombined == 0) - __requireGLSLExtension("GL_EXT_samplerless_texture_functions"); + __requireTargetExtension("GL_EXT_samplerless_texture_functions"); __intrinsic_asm "$ctexelFetchOffset($0, $1, ($2), ($3))$z"; case spirv: if (isCombined != 0) @@ -3807,7 +3813,7 @@ extension _Texture return Load(location, 0); case glsl: if (isCombined == 0) - __requireGLSLExtension("GL_EXT_samplerless_texture_functions"); + __requireTargetExtension("GL_EXT_samplerless_texture_functions"); return Load(location, 0); } } @@ -3830,7 +3836,7 @@ extension _Texture return Load(location, sampleIndex); case glsl: if (isCombined == 0) - __requireGLSLExtension("GL_EXT_samplerless_texture_functions"); + __requireTargetExtension("GL_EXT_samplerless_texture_functions"); return Load(location, sampleIndex); } } @@ -13913,7 +13919,7 @@ T WaveMaskSum(WaveMask mask, T expr) __target_switch { case glsl: - if (__isHalf()) __requireGLSLExtension("GL_EXT_shader_subgroup_extended_types_float16"); + if (__isHalf()) __requireTargetExtension("GL_EXT_shader_subgroup_extended_types_float16"); __intrinsic_asm "subgroupAdd($1)"; case cuda: __intrinsic_asm "_waveSum($0, $1)"; case hlsl: __intrinsic_asm "WaveActiveSum($1)"; @@ -13940,7 +13946,7 @@ vector WaveMaskSum(WaveMask mask, vector expr) __target_switch { case glsl: - if (__isHalf()) __requireGLSLExtension("GL_EXT_shader_subgroup_extended_types_float16"); + if (__isHalf()) __requireTargetExtension("GL_EXT_shader_subgroup_extended_types_float16"); __intrinsic_asm "subgroupAdd($1)"; case cuda: __intrinsic_asm "_waveSumMultiple($0, $1)"; case hlsl: __intrinsic_asm "WaveActiveSum($1)"; @@ -13979,7 +13985,7 @@ bool WaveMaskAllEqual(WaveMask mask, T value) __target_switch { case glsl: - if (__isHalf()) __requireGLSLExtension("GL_EXT_shader_subgroup_extended_types_float16"); + if (__isHalf()) __requireTargetExtension("GL_EXT_shader_subgroup_extended_types_float16"); __intrinsic_asm "subgroupAllEqual($1)"; case hlsl: __intrinsic_asm "WaveActiveAllEqual($1)"; @@ -14003,7 +14009,7 @@ bool WaveMaskAllEqual(WaveMask mask, vector value) __target_switch { case glsl: - if (__isHalf()) __requireGLSLExtension("GL_EXT_shader_subgroup_extended_types_float16"); + if (__isHalf()) __requireTargetExtension("GL_EXT_shader_subgroup_extended_types_float16"); __intrinsic_asm "subgroupAllEqual($1)"; case hlsl: __intrinsic_asm "WaveActiveAllEqual($1)"; @@ -14040,7 +14046,7 @@ T WaveMaskPrefixProduct(WaveMask mask, T expr) __target_switch { case glsl: - if (__isHalf()) __requireGLSLExtension("GL_EXT_shader_subgroup_extended_types_float16"); + if (__isHalf()) __requireTargetExtension("GL_EXT_shader_subgroup_extended_types_float16"); __intrinsic_asm "subgroupExclusiveMul($1)"; case cuda: __intrinsic_asm "_wavePrefixProduct($0, $1)"; case hlsl: __intrinsic_asm "WavePrefixProduct($1)"; @@ -14067,7 +14073,7 @@ vector WaveMaskPrefixProduct(WaveMask mask, vector expr) __target_switch { case glsl: - if (__isHalf()) __requireGLSLExtension("GL_EXT_shader_subgroup_extended_types_float16"); + if (__isHalf()) __requireTargetExtension("GL_EXT_shader_subgroup_extended_types_float16"); __intrinsic_asm "subgroupExclusiveMul($1)"; case cuda: __intrinsic_asm "_wavePrefixProductMultiple($0, $1)"; case hlsl: __intrinsic_asm "WavePrefixProduct($1)"; @@ -14105,7 +14111,7 @@ T WaveMaskPrefixSum(WaveMask mask, T expr) __target_switch { case glsl: - if (__isHalf()) __requireGLSLExtension("GL_EXT_shader_subgroup_extended_types_float16"); + if (__isHalf()) __requireTargetExtension("GL_EXT_shader_subgroup_extended_types_float16"); __intrinsic_asm "subgroupExclusiveAdd($1)"; case cuda: __intrinsic_asm "_wavePrefixSum($0, $1)"; case hlsl: __intrinsic_asm "WavePrefixSum($1)"; @@ -14133,7 +14139,7 @@ vector WaveMaskPrefixSum(WaveMask mask, vector expr) __target_switch { case glsl: - if (__isHalf()) __requireGLSLExtension("GL_EXT_shader_subgroup_extended_types_float16"); + if (__isHalf()) __requireTargetExtension("GL_EXT_shader_subgroup_extended_types_float16"); __intrinsic_asm "subgroupExclusiveAdd($1)"; case cuda: __intrinsic_asm "_wavePrefixSumMultiple($0, $1)"; case hlsl: __intrinsic_asm "WavePrefixSum($1)"; @@ -14761,7 +14767,7 @@ T WaveActive$(opName.hlslName)(T expr) __target_switch { case glsl: - if (__isHalf()) __requireGLSLExtension("GL_EXT_shader_subgroup_extended_types_float16"); + if (__isHalf()) __requireTargetExtension("GL_EXT_shader_subgroup_extended_types_float16"); __intrinsic_asm "subgroup$(opName.glslName)($0)"; case hlsl: __intrinsic_asm "WaveActive$(opName.hlslName)"; case metal: __intrinsic_asm "simd_$(opName.metalName)"; @@ -14796,7 +14802,7 @@ vector WaveActive$(opName.hlslName)(vector expr) __target_switch { case glsl: - if (__isHalf()) __requireGLSLExtension("GL_EXT_shader_subgroup_extended_types_float16"); + if (__isHalf()) __requireTargetExtension("GL_EXT_shader_subgroup_extended_types_float16"); __intrinsic_asm "subgroup$(opName.glslName)($0)"; case hlsl: __intrinsic_asm "WaveActive$(opName.hlslName)"; case metal: __intrinsic_asm "simd_$(opName.metalName)"; @@ -15018,7 +15024,8 @@ uint WaveActiveCountBits(bool value) __glsl_extension(GL_KHR_shader_subgroup_basic) __spirv_version(1.3) [NonUniformReturn] -[require(cuda_glsl_hlsl_spirv, subgroup_basic)] +[ForceInline] +[require(cuda_glsl_hlsl_metal_spirv_wgsl, subgroup_basic)] uint WaveGetLaneCount() { __target_switch @@ -15032,6 +15039,11 @@ uint WaveGetLaneCount() OpCapability GroupNonUniform; result:$$uint = OpLoad builtin(SubgroupSize:uint) }; + case metal: + return __builtinWaveLaneCount; + case wgsl: + __requireTargetExtension("subgroups"); + return __builtinWaveLaneCount; } } @@ -15039,7 +15051,8 @@ uint WaveGetLaneCount() __glsl_extension(GL_KHR_shader_subgroup_basic) __spirv_version(1.3) [NonUniformReturn] -[require(cuda_glsl_hlsl_spirv, subgroup_basic)] +[ForceInline] +[require(cuda_glsl_hlsl_metal_spirv_wgsl, subgroup_basic)] uint WaveGetLaneIndex() { __target_switch @@ -15053,6 +15066,11 @@ uint WaveGetLaneIndex() OpCapability GroupNonUniform; result:$$uint = OpLoad builtin(SubgroupLocalInvocationId:uint) }; + case metal: + return __builtinWaveLaneIndex; + case wgsl: + __requireTargetExtension("subgroups"); + return __builtinWaveLaneIndex; } } @@ -15122,7 +15140,7 @@ T WavePrefixProduct(T expr) __target_switch { case glsl: - if (__isHalf()) __requireGLSLExtension("GL_EXT_shader_subgroup_extended_types_float16"); + if (__isHalf()) __requireTargetExtension("GL_EXT_shader_subgroup_extended_types_float16"); __intrinsic_asm "subgroupExclusiveMul($0)"; case hlsl: __intrinsic_asm "WavePrefixProduct"; case metal: __intrinsic_asm "simd_prefix_exclusive_product"; @@ -15158,7 +15176,7 @@ vector WavePrefixProduct(vector expr) __target_switch { case glsl: - if (__isHalf()) __requireGLSLExtension("GL_EXT_shader_subgroup_extended_types_float16"); + if (__isHalf()) __requireTargetExtension("GL_EXT_shader_subgroup_extended_types_float16"); __intrinsic_asm "subgroupExclusiveMul($0)"; case hlsl: __intrinsic_asm "WavePrefixProduct"; case metal: __intrinsic_asm "simd_prefix_exclusive_product"; @@ -15209,7 +15227,7 @@ T WavePrefixSum(T expr) __target_switch { case glsl: - if (__isHalf()) __requireGLSLExtension("GL_EXT_shader_subgroup_extended_types_float16"); + if (__isHalf()) __requireTargetExtension("GL_EXT_shader_subgroup_extended_types_float16"); __intrinsic_asm "subgroupExclusiveAdd($0)"; case hlsl: __intrinsic_asm "WavePrefixSum"; case metal: __intrinsic_asm "simd_prefix_exclusive_sum"; @@ -15241,7 +15259,7 @@ vector WavePrefixSum(vector expr) __target_switch { case glsl: - if (__isHalf()) __requireGLSLExtension("GL_EXT_shader_subgroup_extended_types_float16"); + if (__isHalf()) __requireTargetExtension("GL_EXT_shader_subgroup_extended_types_float16"); __intrinsic_asm "subgroupExclusiveAdd($0)"; case hlsl: __intrinsic_asm "WavePrefixSum"; case metal: __intrinsic_asm "simd_prefix_exclusive_sum"; @@ -15292,7 +15310,7 @@ T WaveReadLaneFirst(T expr) __target_switch { case glsl: - if (__isHalf()) __requireGLSLExtension("GL_EXT_shader_subgroup_extended_types_float16"); + if (__isHalf()) __requireTargetExtension("GL_EXT_shader_subgroup_extended_types_float16"); __intrinsic_asm "subgroupBroadcastFirst($0)"; case hlsl: __intrinsic_asm "WaveReadLaneFirst"; case metal: __intrinsic_asm "simd_broadcast_first"; @@ -15314,7 +15332,7 @@ vector WaveReadLaneFirst(vector expr) __target_switch { case glsl: - if (__isHalf()) __requireGLSLExtension("GL_EXT_shader_subgroup_extended_types_float16"); + if (__isHalf()) __requireTargetExtension("GL_EXT_shader_subgroup_extended_types_float16"); __intrinsic_asm "subgroupBroadcastFirst($0)"; case hlsl: __intrinsic_asm "WaveReadLaneFirst"; case metal: __intrinsic_asm "simd_broadcast_first"; @@ -15360,7 +15378,7 @@ T WaveBroadcastLaneAt(T value, constexpr int lane) __target_switch { case glsl: - if (__isHalf()) __requireGLSLExtension("GL_EXT_shader_subgroup_extended_types_float16"); + if (__isHalf()) __requireTargetExtension("GL_EXT_shader_subgroup_extended_types_float16"); __intrinsic_asm "subgroupBroadcast($0, $1)"; case hlsl: __intrinsic_asm "WaveReadLaneAt"; case metal: __intrinsic_asm "simd_broadcast($0, ushort($1))"; @@ -15384,7 +15402,7 @@ vector WaveBroadcastLaneAt(vector value, constexpr int lane) __target_switch { case glsl: - if (__isHalf()) __requireGLSLExtension("GL_EXT_shader_subgroup_extended_types_float16"); + if (__isHalf()) __requireTargetExtension("GL_EXT_shader_subgroup_extended_types_float16"); __intrinsic_asm "subgroupBroadcast($0, $1)"; case hlsl: __intrinsic_asm "WaveReadLaneAt"; case metal: __intrinsic_asm "simd_broadcast($0, ushort($1))"; @@ -15426,7 +15444,7 @@ T WaveReadLaneAt(T value, int lane) __target_switch { case glsl: - if (__isHalf()) __requireGLSLExtension("GL_EXT_shader_subgroup_extended_types_float16"); + if (__isHalf()) __requireTargetExtension("GL_EXT_shader_subgroup_extended_types_float16"); __intrinsic_asm "subgroupShuffle($0, $1)"; case hlsl: __intrinsic_asm "WaveReadLaneAt"; case metal: __intrinsic_asm "simd_shuffle($0, ushort($1))"; @@ -15449,7 +15467,7 @@ vector WaveReadLaneAt(vector value, int lane) __target_switch { case glsl: - if (__isHalf()) __requireGLSLExtension("GL_EXT_shader_subgroup_extended_types_float16"); + if (__isHalf()) __requireTargetExtension("GL_EXT_shader_subgroup_extended_types_float16"); __intrinsic_asm "subgroupShuffle($0, $1)"; case hlsl: __intrinsic_asm "WaveReadLaneAt"; case metal: __intrinsic_asm "simd_shuffle($0, ushort($1))"; @@ -15492,7 +15510,7 @@ T WaveShuffle(T value, int lane) __target_switch { case glsl: - if (__isHalf()) __requireGLSLExtension("GL_EXT_shader_subgroup_extended_types_float16"); + if (__isHalf()) __requireTargetExtension("GL_EXT_shader_subgroup_extended_types_float16"); __intrinsic_asm "subgroupShuffle($0, $1)"; case hlsl: __intrinsic_asm "WaveReadLaneAt"; case metal: __intrinsic_asm "simd_shuffle($0, ushort($1))"; @@ -15516,7 +15534,7 @@ vector WaveShuffle(vector value, int lane) __target_switch { case glsl: - if (__isHalf()) __requireGLSLExtension("GL_EXT_shader_subgroup_extended_types_float16"); + if (__isHalf()) __requireTargetExtension("GL_EXT_shader_subgroup_extended_types_float16"); __intrinsic_asm "subgroupShuffle($0, $1)"; case hlsl: __intrinsic_asm "WaveReadLaneAt"; case metal: __intrinsic_asm "simd_shuffle($0, ushort($1))"; @@ -16158,7 +16176,7 @@ extension _Texture { case hlsl: __intrinsic_asm ".GetDimensions"; case glsl: - __requireGLSLExtension("GL_EXT_samplerless_texture_functions"); + __requireTargetExtension("GL_EXT_samplerless_texture_functions"); __intrinsic_asm "($1 = $(glslTextureSizeFunc)($0))"; case metal: __intrinsic_asm "(*($1) = $0.get_width())"; case spirv: @@ -16178,7 +16196,7 @@ extension _Texture case hlsl: __intrinsic_asm ".Load"; case metal: __intrinsic_asm "$c$0.read(uint($1))$z"; case glsl: - __requireGLSLExtension("GL_EXT_samplerless_texture_functions"); + __requireTargetExtension("GL_EXT_samplerless_texture_functions"); __intrinsic_asm "$(glslLoadFuncName)($0, $1)$z"; case spirv: return spirv_asm { %sampled:__sampledType(T) = $(spvLoadInstName) $this $location; diff --git a/source/slang/slang-core-module-textures.cpp b/source/slang/slang-core-module-textures.cpp index 22c1fc63f..f703a8a3b 100644 --- a/source/slang/slang-core-module-textures.cpp +++ b/source/slang/slang-core-module-textures.cpp @@ -439,7 +439,7 @@ void TextureTypeInfo::writeGetDimensionFunctions() } }; glsl << "if (isCombined == 0) { " - "__requireGLSLExtension(\"GL_EXT_samplerless_texture_functions\"); }\n"; + "__requireTargetExtension(\"GL_EXT_samplerless_texture_functions\"); }\n"; glsl << "if (access == " << kCoreModule_ResourceAccessReadOnly << ") __intrinsic_asm \""; emitIntrinsic(toSlice("textureSize"), !isMultisample); diff --git a/source/slang/slang-emit-c-like.cpp b/source/slang/slang-emit-c-like.cpp index 946e9c429..1c48d98ef 100644 --- a/source/slang/slang-emit-c-like.cpp +++ b/source/slang/slang-emit-c-like.cpp @@ -3061,10 +3061,6 @@ void CLikeSourceEmitter::defaultEmitInstExpr(IRInst* inst, const EmitOpInfo& inO m_requiredPreludes.add(preludeTextInst); break; } - case kIROp_RequireGLSLExtension: - { - break; // should already have set requirement; case covered for empty intrinsic block - } case kIROp_RequireComputeDerivative: { break; // should already have been parsed and used. @@ -3074,6 +3070,11 @@ void CLikeSourceEmitter::defaultEmitInstExpr(IRInst* inst, const EmitOpInfo& inO emitOperand(as(inst)->getOperand(0), getInfo(EmitOp::General)); break; } + case kIROp_RequireTargetExtension: + { + emitRequireExtension(as(inst)); + break; + } default: diagnoseUnhandledInst(inst); break; diff --git a/source/slang/slang-emit-c-like.h b/source/slang/slang-emit-c-like.h index 6fe7f5d34..ca915ab2d 100644 --- a/source/slang/slang-emit-c-like.h +++ b/source/slang/slang-emit-c-like.h @@ -678,6 +678,8 @@ protected: void _emitCallArgList(IRCall* call, int startingOperandIndex = 1); virtual void emitCallArg(IRInst* arg); + virtual void emitRequireExtension(IRRequireTargetExtension* inst) { SLANG_UNUSED(inst); } + String _generateUniqueName(const UnownedStringSlice& slice); // Sort witnessTable entries according to the order defined in the witnessed interface type. diff --git a/source/slang/slang-emit-glsl.cpp b/source/slang/slang-emit-glsl.cpp index 776c539b4..696830bf2 100644 --- a/source/slang/slang-emit-glsl.cpp +++ b/source/slang/slang-emit-glsl.cpp @@ -30,7 +30,7 @@ void GLSLSourceEmitter::_beforeComputeEmitProcessInstruction( IRInst* inst, IRBuilder& builder) { - if (auto requireGLSLExt = as(inst)) + if (auto requireGLSLExt = as(inst)) { _requireGLSLExtension(requireGLSLExt->getExtensionName()); return; diff --git a/source/slang/slang-emit-wgsl.cpp b/source/slang/slang-emit-wgsl.cpp index 13c79e9ac..7c83b194d 100644 --- a/source/slang/slang-emit-wgsl.cpp +++ b/source/slang/slang-emit-wgsl.cpp @@ -1696,4 +1696,9 @@ void WGSLSourceEmitter::handleRequiredCapabilitiesImpl(IRInst* inst) } } +void WGSLSourceEmitter::emitRequireExtension(IRRequireTargetExtension* inst) +{ + _requireExtension(inst->getExtensionName()); +} + } // namespace Slang diff --git a/source/slang/slang-emit-wgsl.h b/source/slang/slang-emit-wgsl.h index 441933b57..a29f39a1d 100644 --- a/source/slang/slang-emit-wgsl.h +++ b/source/slang/slang-emit-wgsl.h @@ -57,6 +57,8 @@ public: EmitOpInfo const& inOuterPrec) SLANG_OVERRIDE; virtual void emitGlobalParamDefaultVal(IRGlobalParam* varDecl) SLANG_OVERRIDE; + virtual void emitRequireExtension(IRRequireTargetExtension* inst) SLANG_OVERRIDE; + virtual void handleRequiredCapabilitiesImpl(IRInst* inst) SLANG_OVERRIDE; void emit(const AddressSpace addressSpace); diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp index 847c5b55c..ddb4ea67a 100644 --- a/source/slang/slang-emit.cpp +++ b/source/slang/slang-emit.cpp @@ -100,7 +100,7 @@ #include "slang-ir-strip-default-construct.h" #include "slang-ir-strip-legalization-insts.h" #include "slang-ir-synthesize-active-mask.h" -#include "slang-ir-translate-glsl-global-var.h" +#include "slang-ir-translate-global-varying-var.h" #include "slang-ir-uniformity.h" #include "slang-ir-user-type-hint.h" #include "slang-ir-validate.h" @@ -318,7 +318,7 @@ struct RequiredLoweringPassSet bool bindingQuery; bool meshOutput; bool higherOrderFunc; - bool glslGlobalVar; + bool globalVaryingVar; bool glslSSBO; bool byteAddressBuffer; bool dynamicResource; @@ -422,7 +422,7 @@ void calcRequiredLoweringPassSet( case kIROp_GlobalInputDecoration: case kIROp_GlobalOutputDecoration: case kIROp_GetWorkGroupSize: - result.glslGlobalVar = true; + result.globalVaryingVar = true; break; case kIROp_BindExistentialSlotsDecoration: result.bindExistential = true; @@ -667,8 +667,8 @@ Result linkAndOptimizeIR( if (!isKhronosTarget(targetRequest) && requiredLoweringPassSet.glslSSBO) lowerGLSLShaderStorageBufferObjectsToStructuredBuffers(irModule, sink); - if (requiredLoweringPassSet.glslGlobalVar) - translateGLSLGlobalVar(codeGenContext, irModule); + if (requiredLoweringPassSet.globalVaryingVar) + translateGlobalVaryingVar(codeGenContext, irModule); if (requiredLoweringPassSet.resolveVaryingInputRef) resolveVaryingInputRef(irModule); diff --git a/source/slang/slang-ir-call-graph.h b/source/slang/slang-ir-call-graph.h index 4ee642356..b7290ef79 100644 --- a/source/slang/slang-ir-call-graph.h +++ b/source/slang/slang-ir-call-graph.h @@ -1,3 +1,6 @@ +// slang-ir-call-graph.h +#pragma once + #include "slang-ir-clone.h" #include "slang-ir-insts.h" diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h index 3e2872cb7..714ba146d 100644 --- a/source/slang/slang-ir-inst-defs.h +++ b/source/slang/slang-ir-inst-defs.h @@ -407,7 +407,7 @@ INST(WitnessTableEntry, witness_table_entry, 2, 0) INST(InterfaceRequirementEntry, interface_req_entry, 2, GLOBAL) // An inst to represent the workgroup size of the calling entry point. -// We will materialize this inst during `translateGLSLGlobalVar`. +// We will materialize this inst during `translateGlobalVaryingVar`. INST(GetWorkGroupSize, GetWorkGroupSize, 0, HOISTABLE) // An inst that returns the current stage of the calling entry point. @@ -666,7 +666,7 @@ INST_RANGE(TerminatorInst, Return, Unreachable) INST(discard, discard, 0, 0) INST(RequirePrelude, RequirePrelude, 1, 0) -INST(RequireGLSLExtension, RequireGLSLExtension, 1, 0) +INST(RequireTargetExtension, RequireTargetExtension, 1, 0) INST(RequireComputeDerivative, RequireComputeDerivative, 0, 0) INST(StaticAssert, StaticAssert, 2, 0) INST(Printf, Printf, 1, 0) diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index 4efb7d671..5231592ca 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -3506,9 +3506,9 @@ struct IRRequirePrelude : IRInst UnownedStringSlice getPrelude() { return as(getOperand(0))->getStringSlice(); } }; -struct IRRequireGLSLExtension : IRInst +struct IRRequireTargetExtension : IRInst { - IR_LEAF_ISA(RequireGLSLExtension) + IR_LEAF_ISA(RequireTargetExtension) UnownedStringSlice getExtensionName() { return as(getOperand(0))->getStringSlice(); diff --git a/source/slang/slang-ir-legalize-varying-params.cpp b/source/slang/slang-ir-legalize-varying-params.cpp index 3b65ee59a..e744969db 100644 --- a/source/slang/slang-ir-legalize-varying-params.cpp +++ b/source/slang/slang-ir-legalize-varying-params.cpp @@ -3228,6 +3228,20 @@ protected: result.permittedTypes.add(builder.getBasicType(BaseType::UInt)); break; } + case SystemValueSemanticName::WaveLaneCount: + { + result.systemValueName = toSlice("threads_per_simdgroup"); + result.permittedTypes.add(builder.getUIntType()); + result.permittedTypes.add(builder.getUInt16Type()); + break; + } + case SystemValueSemanticName::WaveLaneIndex: + { + result.systemValueName = toSlice("thread_index_in_simdgroup"); + result.permittedTypes.add(builder.getUIntType()); + result.permittedTypes.add(builder.getUInt16Type()); + break; + } default: m_sink->diagnose( parentVar, @@ -3845,6 +3859,20 @@ protected: break; } + case SystemValueSemanticName::WaveLaneCount: + { + result.systemValueName = toSlice("subgroup_size"); + result.permittedTypes.add(builder.getUIntType()); + break; + } + + case SystemValueSemanticName::WaveLaneIndex: + { + result.systemValueName = toSlice("subgroup_invocation_id"); + result.permittedTypes.add(builder.getUIntType()); + break; + } + case SystemValueSemanticName::ViewID: case SystemValueSemanticName::ViewportArrayIndex: case SystemValueSemanticName::StartVertexLocation: @@ -3853,7 +3881,6 @@ protected: result.isUnsupported = true; break; } - default: { m_sink->diagnose( diff --git a/source/slang/slang-ir-legalize-varying-params.h b/source/slang/slang-ir-legalize-varying-params.h index e742f3093..0a7c3be8e 100644 --- a/source/slang/slang-ir-legalize-varying-params.h +++ b/source/slang/slang-ir-legalize-varying-params.h @@ -68,6 +68,8 @@ void depointerizeInputParams(IRFunc* entryPoint); M(Target, SV_Target) \ M(StartVertexLocation, SV_StartVertexLocation) \ M(StartInstanceLocation, SV_StartInstanceLocation) \ + M(WaveLaneCount, SV_WaveLaneCount) \ + M(WaveLaneIndex, SV_WaveLaneIndex) \ /* end */ /// A known system-value semantic name that can be applied to a parameter diff --git a/source/slang/slang-ir-translate-global-varying-var.cpp b/source/slang/slang-ir-translate-global-varying-var.cpp new file mode 100644 index 000000000..80f5c42c3 --- /dev/null +++ b/source/slang/slang-ir-translate-global-varying-var.cpp @@ -0,0 +1,382 @@ +#include "slang-ir-translate-global-varying-var.h" + +#include "slang-ir-call-graph.h" +#include "slang-ir-insts.h" +#include "slang-ir-util.h" +#include "slang-ir.h" + +namespace Slang +{ +struct GlobalVarTranslationContext +{ + CodeGenContext* context; + + void processModule(IRModule* module) + { + Dictionary> referencingEntryPoints; + buildEntryPointReferenceGraph(referencingEntryPoints, module); + + List entryPoints; + List getWorkGroupSizeInsts; + + // Traverse the module to find all entry points. + // If we see a `GetWorkGroupSize` instruction, we will materialize it. + // + for (auto inst : module->getGlobalInsts()) + { + if (inst->getOp() == kIROp_Func && inst->findDecoration()) + entryPoints.add(inst); + else if (inst->getOp() == kIROp_GetWorkGroupSize) + getWorkGroupSizeInsts.add(inst); + } + for (auto inst : getWorkGroupSizeInsts) + materializeGetWorkGroupSize(module, referencingEntryPoints, inst); + IRBuilder builder(module); + + for (auto entryPoint : entryPoints) + { + List outputVars; + List inputVars; + for (auto inst : module->getGlobalInsts()) + { + if (auto referencingEntryPointSet = referencingEntryPoints.tryGetValue(inst)) + { + if (referencingEntryPointSet->contains((IRFunc*)entryPoint)) + { + if (inst->findDecoration()) + { + outputVars.add(inst); + } + if (inst->findDecoration()) + { + inputVars.add(inst); + } + } + } + } + + bool hasInput = inputVars.getCount() != 0; + bool hasOutput = outputVars.getCount() != 0; + + if (!hasInput && !hasOutput) + continue; + + auto entryPointFunc = as(entryPoint); + if (!entryPointFunc) + continue; + + auto entryPointDecor = entryPointFunc->findDecoration(); + + IRVarLayout* resultVarLayout = nullptr; + IRVarLayout* paramLayout = nullptr; + IRType* resultType = entryPointFunc->getResultType(); + + // Create a struct type to receive all inputs. + builder.setInsertBefore(entryPointFunc); + auto inputStructType = builder.createStructType(); + IRStructTypeLayout::Builder inputStructTypeLayoutBuilder(&builder); + UInt inputVarIndex = 0; + List inputKeys; + for (auto input : inputVars) + { + auto inputType = cast(input->getDataType())->getValueType(); + auto key = builder.createStructKey(); + inputKeys.add(key); + builder.createStructField(inputStructType, key, inputType); + + IRTypeLayout::Builder fieldTypeLayoutBuilder(&builder); + IRTypeLayout* fieldTypeLayout = nullptr; + bool hasExistingLayout = false; + if (auto existingLayoutDecoration = input->findDecoration()) + { + if (auto existingVarLayout = + as(existingLayoutDecoration->getLayout())) + { + fieldTypeLayout = existingVarLayout->getTypeLayout(); + hasExistingLayout = true; + } + } + + if (!hasExistingLayout) + { + fieldTypeLayout = fieldTypeLayoutBuilder.build(); + } + + IRVarLayout::Builder varLayoutBuilder(&builder, fieldTypeLayout); + varLayoutBuilder.setStage(entryPointDecor->getProfile().getStage()); + if (auto semanticDecor = input->findDecoration()) + { + varLayoutBuilder.setSystemValueSemantic( + semanticDecor->getSemanticName(), + semanticDecor->getSemanticIndex()); + } + else + { + if (!hasExistingLayout) + { + fieldTypeLayoutBuilder.addResourceUsage( + LayoutResourceKind::VaryingInput, + LayoutSize(1)); + } + if (auto layoutDecor = findVarLayout(input)) + { + if (auto offsetAttr = + layoutDecor->findOffsetAttr(LayoutResourceKind::VaryingInput)) + { + varLayoutBuilder + .findOrAddResourceInfo(LayoutResourceKind::VaryingInput) + ->offset = (UInt)offsetAttr->getOffset(); + } + } + if (entryPointDecor->getProfile().getStage() == Stage::Fragment) + { + varLayoutBuilder.setUserSemantic("COLOR", inputVarIndex); + } + else if (entryPointDecor->getProfile().getStage() == Stage::Vertex) + { + varLayoutBuilder.setUserSemantic("VERTEX_IN_", inputVarIndex); + } + inputVarIndex++; + } + inputStructTypeLayoutBuilder.addField(key, varLayoutBuilder.build()); + input->transferDecorationsTo(key); + } + auto paramTypeLayout = inputStructTypeLayoutBuilder.build(); + IRVarLayout::Builder paramVarLayoutBuilder(&builder, paramTypeLayout); + paramLayout = paramVarLayoutBuilder.build(); + + // Add an entry point parameter for all the inputs. + auto firstBlock = entryPointFunc->getFirstBlock(); + builder.setInsertInto(firstBlock); + auto inputParam = builder.emitParam( + builder.getPtrType(kIROp_ConstRefType, inputStructType, AddressSpace::Input)); + builder.addLayoutDecoration(inputParam, paramLayout); + + // Initialize all global variables in the order of struct member declaration. + for (Index i = inputVars.getCount() - 1; i >= 0; i--) + { + auto input = inputVars[i]; + setInsertBeforeOrdinaryInst(&builder, firstBlock->getFirstOrdinaryInst()); + auto inputType = cast(input->getDataType())->getValueType(); + // TODO: This could be more efficient as a Load(FieldAddress(inputParam, i)) + // operation instead of a FieldExtract(Load(inputParam)). + builder.emitStore( + input, + builder + .emitFieldExtract(inputType, builder.emitLoad(inputParam), inputKeys[i])); + // Relate "global variable" to a "global parameter" for use later in compilation + // to resolve a "global variable" shadowing a "global parameter" relationship. + builder.addGlobalVariableShadowingGlobalParameterDecoration( + inputParam, + input, + inputKeys[i]); + } + + // For each entry point, introduce a new parameter to represent each input parameter, + // and return all outputs via a struct value. + if (hasOutput) + { + // If we have global outputs, the entry-point must not return anything itself. + if (as(entryPoint->getDataType())->getResultType()->getOp() != + kIROp_VoidType) + { + context->getSink()->diagnose( + entryPointFunc, + Diagnostics::entryPointMustReturnVoidWhenGlobalOutputPresent); + continue; + } + builder.setInsertBefore(entryPointFunc); + resultType = builder.createStructType(); + IRStructTypeLayout::Builder typeLayoutBuilder(&builder); + UInt outputVarIndex = 0; + for (auto output : outputVars) + { + auto key = builder.createStructKey(); + auto ptrType = as(output->getDataType()); + builder.createStructField(resultType, key, ptrType->getValueType()); + IRTypeLayout::Builder fieldTypeLayout(&builder); + IRVarLayout::Builder varLayoutBuilder(&builder, fieldTypeLayout.build()); + varLayoutBuilder.setStage(entryPointDecor->getProfile().getStage()); + if (auto semanticDecor = output->findDecoration()) + { + varLayoutBuilder.setSystemValueSemantic( + semanticDecor->getSemanticName(), + semanticDecor->getSemanticIndex()); + } + else + { + fieldTypeLayout.addResourceUsage( + LayoutResourceKind::VaryingOutput, + LayoutSize(1)); + if (auto layoutDecor = findVarLayout(output)) + { + if (auto offsetAttr = + layoutDecor->findOffsetAttr(LayoutResourceKind::VaryingOutput)) + { + varLayoutBuilder + .findOrAddResourceInfo(LayoutResourceKind::VaryingOutput) + ->offset = (UInt)offsetAttr->getOffset(); + } + } + if (entryPointDecor->getProfile().getStage() == Stage::Fragment) + { + varLayoutBuilder.setSystemValueSemantic("SV_TARGET", outputVarIndex); + } + else if (entryPointDecor->getProfile().getStage() == Stage::Vertex) + { + varLayoutBuilder.setUserSemantic("COLOR", outputVarIndex); + } + outputVarIndex++; + } + typeLayoutBuilder.addField(key, varLayoutBuilder.build()); + output->transferDecorationsTo(key); + } + auto resultTypeLayout = typeLayoutBuilder.build(); + IRVarLayout::Builder resultVarLayoutBuilder(&builder, resultTypeLayout); + resultVarLayout = resultVarLayoutBuilder.build(); + + for (auto block : entryPointFunc->getBlocks()) + { + if (auto returnInst = as(block->getTerminator())) + { + // Return the struct value. + builder.setInsertBefore(returnInst); + List fieldVals; + for (auto outputVar : outputVars) + { + auto load = builder.emitLoad(outputVar); + fieldVals.add(load); + } + auto resultVal = builder.emitMakeStruct( + resultType, + (UInt)fieldVals.getCount(), + fieldVals.getBuffer()); + builder.emitReturn(resultVal); + returnInst->removeAndDeallocate(); + } + } + } + if (auto entryPointLayoutDecor = entryPointFunc->findDecoration()) + { + if (auto entryPointLayout = + as(entryPointLayoutDecor->getLayout())) + { + if (paramLayout) + builder.replaceOperand(entryPointLayout->getOperands(), paramLayout); + if (resultVarLayout) + builder.replaceOperand( + entryPointLayout->getOperands() + 1, + resultVarLayout); + } + } + // Update func type for the entry point. + List paramTypes; + for (auto param : entryPointFunc->getParams()) + { + paramTypes.add(param->getDataType()); + } + IRType* newFuncType = builder.getFuncType(paramTypes, resultType); + entryPointFunc->setFullType(newFuncType); + } + } + + // If we see a `GetWorkGroupSize` instruction, we should materialize it by replacing its uses + // with a constant that represent the workgroup size of the calling entrypoint. This is trivial + // if the `GetWorkGroupSize` instruction is used from a function called by one entry point. If + // it is used in a place reachable from multiple entry points, we will introduce a global + // variable to represent the workgroup size, and replace the uses with a load from the global + // variable. We will assign the value of the global variable at the start of each entry point. + // + void materializeGetWorkGroupSize( + IRModule* module, + Dictionary>& referenceGraph, + IRInst* workgroupSizeInst) + { + IRBuilder builder(workgroupSizeInst); + traverseUses( + workgroupSizeInst, + [&](IRUse* use) + { + if (auto parentFunc = getParentFunc(use->getUser())) + { + auto referenceSet = referenceGraph.tryGetValue(parentFunc); + if (!referenceSet) + return; + if (referenceSet->getCount() == 1) + { + // If the function that uses the workgroup size is only used by one entry + // point, we can materialize the workgroup size by substituting the use with + // a constant. + auto entryPoint = *referenceSet->begin(); + auto numthreadsDecor = entryPoint->findDecoration(); + if (!numthreadsDecor) + return; + builder.setInsertBefore(use->getUser()); + IRInst* values[3] = { + numthreadsDecor->getOperand(0), + numthreadsDecor->getOperand(1), + numthreadsDecor->getOperand(2)}; + + auto workgroupSize = builder.emitMakeVector( + builder.getVectorType(builder.getIntType(), 3), + 3, + values); + builder.replaceOperand(use, workgroupSize); + } + } + }); + + // If workgroupSizeInst still has uses, it means it is used by multiple entry points. + // We need to introduce a global variable and assign value to it in each entry point. + + if (!workgroupSizeInst->hasUses()) + { + workgroupSizeInst->removeAndDeallocate(); + return; + } + builder.setInsertBefore(workgroupSizeInst); + auto globalVar = builder.createGlobalVar(workgroupSizeInst->getFullType()); + + // Replace all remaining uses of the workgroupSize inst of a load from globalVar. + traverseUses( + workgroupSizeInst, + [&](IRUse* use) + { + builder.setInsertBefore(use->getUser()); + auto load = builder.emitLoad(globalVar); + builder.replaceOperand(use, load); + }); + + // Now insert assignments from each entry point. + for (auto globalInst : module->getGlobalInsts()) + { + auto func = as(getResolvedInstForDecorations(globalInst)); + if (!func) + continue; + if (auto numthreadsDecor = func->findDecoration()) + { + auto firstBlock = func->getFirstBlock(); + if (!firstBlock) + continue; + builder.setInsertBefore(firstBlock->getFirstOrdinaryInst()); + IRInst* args[3] = { + numthreadsDecor->getOperand(0), + numthreadsDecor->getOperand(1), + numthreadsDecor->getOperand(2)}; + auto workgroupSize = + builder.emitMakeVector(workgroupSizeInst->getFullType(), 3, args); + builder.emitStore(globalVar, workgroupSize); + } + } + + workgroupSizeInst->removeAndDeallocate(); + } +}; + +void translateGlobalVaryingVar(CodeGenContext* context, IRModule* module) +{ + GlobalVarTranslationContext ctx; + ctx.context = context; + ctx.processModule(module); +} +} // namespace Slang diff --git a/source/slang/slang-ir-translate-global-varying-var.h b/source/slang/slang-ir-translate-global-varying-var.h new file mode 100644 index 000000000..f97683700 --- /dev/null +++ b/source/slang/slang-ir-translate-global-varying-var.h @@ -0,0 +1,14 @@ +// slang-ir-translate-global-varying-var.h +#pragma once + +namespace Slang +{ + +struct IRModule; +struct CodeGenContext; + +/// Translate GLSL-flavored global in/out variables into +/// entry point parameters with system value semantics. +void translateGlobalVaryingVar(CodeGenContext* context, IRModule* module); + +} // namespace Slang diff --git a/source/slang/slang-ir-translate-glsl-global-var.cpp b/source/slang/slang-ir-translate-glsl-global-var.cpp deleted file mode 100644 index 80ed3c3e4..000000000 --- a/source/slang/slang-ir-translate-glsl-global-var.cpp +++ /dev/null @@ -1,382 +0,0 @@ -#include "slang-ir-translate-glsl-global-var.h" - -#include "slang-ir-call-graph.h" -#include "slang-ir-insts.h" -#include "slang-ir-util.h" -#include "slang-ir.h" - -namespace Slang -{ -struct GlobalVarTranslationContext -{ - CodeGenContext* context; - - void processModule(IRModule* module) - { - Dictionary> referencingEntryPoints; - buildEntryPointReferenceGraph(referencingEntryPoints, module); - - List entryPoints; - List getWorkGroupSizeInsts; - - // Traverse the module to find all entry points. - // If we see a `GetWorkGroupSize` instruction, we will materialize it. - // - for (auto inst : module->getGlobalInsts()) - { - if (inst->getOp() == kIROp_Func && inst->findDecoration()) - entryPoints.add(inst); - else if (inst->getOp() == kIROp_GetWorkGroupSize) - getWorkGroupSizeInsts.add(inst); - } - for (auto inst : getWorkGroupSizeInsts) - materializeGetWorkGroupSize(module, referencingEntryPoints, inst); - IRBuilder builder(module); - - for (auto entryPoint : entryPoints) - { - List outputVars; - List inputVars; - for (auto inst : module->getGlobalInsts()) - { - if (auto referencingEntryPointSet = referencingEntryPoints.tryGetValue(inst)) - { - if (referencingEntryPointSet->contains((IRFunc*)entryPoint)) - { - if (inst->findDecoration()) - { - outputVars.add(inst); - } - if (inst->findDecoration()) - { - inputVars.add(inst); - } - } - } - } - - bool hasInput = inputVars.getCount() != 0; - bool hasOutput = outputVars.getCount() != 0; - - if (!hasInput && !hasOutput) - continue; - - auto entryPointFunc = as(entryPoint); - if (!entryPointFunc) - continue; - - auto entryPointDecor = entryPointFunc->findDecoration(); - - IRVarLayout* resultVarLayout = nullptr; - IRVarLayout* paramLayout = nullptr; - IRType* resultType = entryPointFunc->getResultType(); - - // Create a struct type to receive all inputs. - builder.setInsertBefore(entryPointFunc); - auto inputStructType = builder.createStructType(); - IRStructTypeLayout::Builder inputStructTypeLayoutBuilder(&builder); - UInt inputVarIndex = 0; - List inputKeys; - for (auto input : inputVars) - { - auto inputType = cast(input->getDataType())->getValueType(); - auto key = builder.createStructKey(); - inputKeys.add(key); - builder.createStructField(inputStructType, key, inputType); - - IRTypeLayout::Builder fieldTypeLayoutBuilder(&builder); - IRTypeLayout* fieldTypeLayout = nullptr; - bool hasExistingLayout = false; - if (auto existingLayoutDecoration = input->findDecoration()) - { - if (auto existingVarLayout = - as(existingLayoutDecoration->getLayout())) - { - fieldTypeLayout = existingVarLayout->getTypeLayout(); - hasExistingLayout = true; - } - } - - if (!hasExistingLayout) - { - fieldTypeLayout = fieldTypeLayoutBuilder.build(); - } - - IRVarLayout::Builder varLayoutBuilder(&builder, fieldTypeLayout); - varLayoutBuilder.setStage(entryPointDecor->getProfile().getStage()); - if (auto semanticDecor = input->findDecoration()) - { - varLayoutBuilder.setSystemValueSemantic( - semanticDecor->getSemanticName(), - semanticDecor->getSemanticIndex()); - } - else - { - if (!hasExistingLayout) - { - fieldTypeLayoutBuilder.addResourceUsage( - LayoutResourceKind::VaryingInput, - LayoutSize(1)); - } - if (auto layoutDecor = findVarLayout(input)) - { - if (auto offsetAttr = - layoutDecor->findOffsetAttr(LayoutResourceKind::VaryingInput)) - { - varLayoutBuilder - .findOrAddResourceInfo(LayoutResourceKind::VaryingInput) - ->offset = (UInt)offsetAttr->getOffset(); - } - } - if (entryPointDecor->getProfile().getStage() == Stage::Fragment) - { - varLayoutBuilder.setUserSemantic("COLOR", inputVarIndex); - } - else if (entryPointDecor->getProfile().getStage() == Stage::Vertex) - { - varLayoutBuilder.setUserSemantic("VERTEX_IN_", inputVarIndex); - } - inputVarIndex++; - } - inputStructTypeLayoutBuilder.addField(key, varLayoutBuilder.build()); - input->transferDecorationsTo(key); - } - auto paramTypeLayout = inputStructTypeLayoutBuilder.build(); - IRVarLayout::Builder paramVarLayoutBuilder(&builder, paramTypeLayout); - paramLayout = paramVarLayoutBuilder.build(); - - // Add an entry point parameter for all the inputs. - auto firstBlock = entryPointFunc->getFirstBlock(); - builder.setInsertInto(firstBlock); - auto inputParam = builder.emitParam( - builder.getPtrType(kIROp_ConstRefType, inputStructType, AddressSpace::Input)); - builder.addLayoutDecoration(inputParam, paramLayout); - - // Initialize all global variables. - for (Index i = 0; i < inputVars.getCount(); i++) - { - auto input = inputVars[i]; - setInsertBeforeOrdinaryInst(&builder, firstBlock->getFirstOrdinaryInst()); - auto inputType = cast(input->getDataType())->getValueType(); - // TODO: This could be more efficient as a Load(FieldAddress(inputParam, i)) - // operation instead of a FieldExtract(Load(inputParam)). - builder.emitStore( - input, - builder - .emitFieldExtract(inputType, builder.emitLoad(inputParam), inputKeys[i])); - // Relate "global variable" to a "global parameter" for use later in compilation - // to resolve a "global variable" shadowing a "global parameter" relationship. - builder.addGlobalVariableShadowingGlobalParameterDecoration( - inputParam, - input, - inputKeys[i]); - } - - // For each entry point, introduce a new parameter to represent each input parameter, - // and return all outputs via a struct value. - if (hasOutput) - { - // If we have global outputs, the entry-point must not return anything itself. - if (as(entryPoint->getDataType())->getResultType()->getOp() != - kIROp_VoidType) - { - context->getSink()->diagnose( - entryPointFunc, - Diagnostics::entryPointMustReturnVoidWhenGlobalOutputPresent); - continue; - } - builder.setInsertBefore(entryPointFunc); - resultType = builder.createStructType(); - IRStructTypeLayout::Builder typeLayoutBuilder(&builder); - UInt outputVarIndex = 0; - for (auto output : outputVars) - { - auto key = builder.createStructKey(); - auto ptrType = as(output->getDataType()); - builder.createStructField(resultType, key, ptrType->getValueType()); - IRTypeLayout::Builder fieldTypeLayout(&builder); - IRVarLayout::Builder varLayoutBuilder(&builder, fieldTypeLayout.build()); - varLayoutBuilder.setStage(entryPointDecor->getProfile().getStage()); - if (auto semanticDecor = output->findDecoration()) - { - varLayoutBuilder.setSystemValueSemantic( - semanticDecor->getSemanticName(), - semanticDecor->getSemanticIndex()); - } - else - { - fieldTypeLayout.addResourceUsage( - LayoutResourceKind::VaryingOutput, - LayoutSize(1)); - if (auto layoutDecor = findVarLayout(output)) - { - if (auto offsetAttr = - layoutDecor->findOffsetAttr(LayoutResourceKind::VaryingOutput)) - { - varLayoutBuilder - .findOrAddResourceInfo(LayoutResourceKind::VaryingOutput) - ->offset = (UInt)offsetAttr->getOffset(); - } - } - if (entryPointDecor->getProfile().getStage() == Stage::Fragment) - { - varLayoutBuilder.setSystemValueSemantic("SV_TARGET", outputVarIndex); - } - else if (entryPointDecor->getProfile().getStage() == Stage::Vertex) - { - varLayoutBuilder.setUserSemantic("COLOR", outputVarIndex); - } - outputVarIndex++; - } - typeLayoutBuilder.addField(key, varLayoutBuilder.build()); - output->transferDecorationsTo(key); - } - auto resultTypeLayout = typeLayoutBuilder.build(); - IRVarLayout::Builder resultVarLayoutBuilder(&builder, resultTypeLayout); - resultVarLayout = resultVarLayoutBuilder.build(); - - for (auto block : entryPointFunc->getBlocks()) - { - if (auto returnInst = as(block->getTerminator())) - { - // Return the struct value. - builder.setInsertBefore(returnInst); - List fieldVals; - for (auto outputVar : outputVars) - { - auto load = builder.emitLoad(outputVar); - fieldVals.add(load); - } - auto resultVal = builder.emitMakeStruct( - resultType, - (UInt)fieldVals.getCount(), - fieldVals.getBuffer()); - builder.emitReturn(resultVal); - returnInst->removeAndDeallocate(); - } - } - } - if (auto entryPointLayoutDecor = entryPointFunc->findDecoration()) - { - if (auto entryPointLayout = - as(entryPointLayoutDecor->getLayout())) - { - if (paramLayout) - builder.replaceOperand(entryPointLayout->getOperands(), paramLayout); - if (resultVarLayout) - builder.replaceOperand( - entryPointLayout->getOperands() + 1, - resultVarLayout); - } - } - // Update func type for the entry point. - List paramTypes; - for (auto param : entryPointFunc->getParams()) - { - paramTypes.add(param->getDataType()); - } - IRType* newFuncType = builder.getFuncType(paramTypes, resultType); - entryPointFunc->setFullType(newFuncType); - } - } - - // If we see a `GetWorkGroupSize` instruction, we should materialize it by replacing its uses - // with a constant that represent the workgroup size of the calling entrypoint. This is trivial - // if the `GetWorkGroupSize` instruction is used from a function called by one entry point. If - // it is used in a place reachable from multiple entry points, we will introduce a global - // variable to represent the workgroup size, and replace the uses with a load from the global - // variable. We will assign the value of the global variable at the start of each entry point. - // - void materializeGetWorkGroupSize( - IRModule* module, - Dictionary>& referenceGraph, - IRInst* workgroupSizeInst) - { - IRBuilder builder(workgroupSizeInst); - traverseUses( - workgroupSizeInst, - [&](IRUse* use) - { - if (auto parentFunc = getParentFunc(use->getUser())) - { - auto referenceSet = referenceGraph.tryGetValue(parentFunc); - if (!referenceSet) - return; - if (referenceSet->getCount() == 1) - { - // If the function that uses the workgroup size is only used by one entry - // point, we can materialize the workgroup size by substituting the use with - // a constant. - auto entryPoint = *referenceSet->begin(); - auto numthreadsDecor = entryPoint->findDecoration(); - if (!numthreadsDecor) - return; - builder.setInsertBefore(use->getUser()); - IRInst* values[3] = { - numthreadsDecor->getOperand(0), - numthreadsDecor->getOperand(1), - numthreadsDecor->getOperand(2)}; - - auto workgroupSize = builder.emitMakeVector( - builder.getVectorType(builder.getIntType(), 3), - 3, - values); - builder.replaceOperand(use, workgroupSize); - } - } - }); - - // If workgroupSizeInst still has uses, it means it is used by multiple entry points. - // We need to introduce a global variable and assign value to it in each entry point. - - if (!workgroupSizeInst->hasUses()) - { - workgroupSizeInst->removeAndDeallocate(); - return; - } - builder.setInsertBefore(workgroupSizeInst); - auto globalVar = builder.createGlobalVar(workgroupSizeInst->getFullType()); - - // Replace all remaining uses of the workgroupSize inst of a load from globalVar. - traverseUses( - workgroupSizeInst, - [&](IRUse* use) - { - builder.setInsertBefore(use->getUser()); - auto load = builder.emitLoad(globalVar); - builder.replaceOperand(use, load); - }); - - // Now insert assignments from each entry point. - for (auto globalInst : module->getGlobalInsts()) - { - auto func = as(getResolvedInstForDecorations(globalInst)); - if (!func) - continue; - if (auto numthreadsDecor = func->findDecoration()) - { - auto firstBlock = func->getFirstBlock(); - if (!firstBlock) - continue; - builder.setInsertBefore(firstBlock->getFirstOrdinaryInst()); - IRInst* args[3] = { - numthreadsDecor->getOperand(0), - numthreadsDecor->getOperand(1), - numthreadsDecor->getOperand(2)}; - auto workgroupSize = - builder.emitMakeVector(workgroupSizeInst->getFullType(), 3, args); - builder.emitStore(globalVar, workgroupSize); - } - } - - workgroupSizeInst->removeAndDeallocate(); - } -}; - -void translateGLSLGlobalVar(CodeGenContext* context, IRModule* module) -{ - GlobalVarTranslationContext ctx; - ctx.context = context; - ctx.processModule(module); -} -} // namespace Slang diff --git a/source/slang/slang-ir-translate-glsl-global-var.h b/source/slang/slang-ir-translate-glsl-global-var.h deleted file mode 100644 index 5821ba5c5..000000000 --- a/source/slang/slang-ir-translate-glsl-global-var.h +++ /dev/null @@ -1,17 +0,0 @@ -// slang-ir-translate-glsl-global-var.h -#ifndef SLANG_IR_TRANSLATE_GLSL_GLOBAL_VAR_H -#define SLANG_IR_TRANSLATE_GLSL_GLOBAL_VAR_H - -namespace Slang -{ - -struct IRModule; -struct CodeGenContext; - -/// Translate global in/out variables defined in GLSL-flavored code -/// into entry point parameters with system value semantics. -void translateGLSLGlobalVar(CodeGenContext* context, IRModule* module); - -} // namespace Slang - -#endif // SLANG_IR_TRANSLATE_GLSL_GLOBAL_VAR_H -- cgit v1.2.3