From 0a6828572aa4cc1f0f99993e77c321799eb88cca Mon Sep 17 00:00:00 2001 From: Darren Wihandi <65404740+fairywreath@users.noreply.github.com> Date: Sun, 2 Feb 2025 15:27:11 -0500 Subject: Add support for WGSL subgroup operations (#6213) * initial work * more work * more work on glsl intrinsics * add subgroup broadcast for glsl * wip add wgsl extension tracking * enable tests, enable extensions and added some todos * format and warning fixes * fix wgsl extension tracker --------- Co-authored-by: Yong He --- source/slang/slang-compiler.cpp | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) (limited to 'source/slang/slang-compiler.cpp') diff --git a/source/slang/slang-compiler.cpp b/source/slang/slang-compiler.cpp index 448534ce8..04ebb753c 100644 --- a/source/slang/slang-compiler.cpp +++ b/source/slang/slang-compiler.cpp @@ -28,7 +28,7 @@ // Artifact output #include "slang-artifact-output-util.h" #include "slang-emit-cuda.h" -#include "slang-glsl-extension-tracker.h" +#include "slang-extension-tracker.h" #include "slang-lower-to-ir.h" #include "slang-mangle.h" #include "slang-parameter-binding.h" @@ -658,7 +658,7 @@ static void _appendCodeWithPath( outCodeBuilder << fileContent << "\n"; } -void trackGLSLTargetCaps(GLSLExtensionTracker* extensionTracker, CapabilitySet const& caps) +void trackGLSLTargetCaps(ShaderExtensionTracker* extensionTracker, CapabilitySet const& caps) { for (auto& conjunctions : caps.getAtomSets()) { @@ -1037,8 +1037,11 @@ static RefPtr _newExtensionTracker(CodeGenTarget target) } case CodeGenTarget::SPIRV: case CodeGenTarget::GLSL: + case CodeGenTarget::WGSL: + case CodeGenTarget::WGSLSPIRV: + case CodeGenTarget::WGSLSPIRVAssembly: { - return new GLSLExtensionTracker; + return new ShaderExtensionTracker; } default: return nullptr; @@ -1261,7 +1264,7 @@ SlangResult CodeGenContext::emitWithDownstreamForEntryPoints(ComPtr& if (auto endToEndReq = isPassThroughEnabled()) { // If we are pass through, we may need to set extension tracker state. - if (GLSLExtensionTracker* glslTracker = as(extensionTracker)) + if (ShaderExtensionTracker* glslTracker = as(extensionTracker)) { trackGLSLTargetCaps(glslTracker, getTargetCaps()); } @@ -1400,7 +1403,7 @@ SlangResult CodeGenContext::emitWithDownstreamForEntryPoints(ComPtr& options.flags |= CompileOptions::Flag::EnableFloat16; } } - else if (GLSLExtensionTracker* glslTracker = as(extensionTracker)) + else if (ShaderExtensionTracker* glslTracker = as(extensionTracker)) { DownstreamCompileOptions::CapabilityVersion version; version.kind = DownstreamCompileOptions::CapabilityVersion::Kind::SPIRV; -- cgit v1.2.3