summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorcheneym2 <acheney@nvidia.com>2025-03-05 16:45:03 -0500
committerGitHub <noreply@github.com>2025-03-05 13:45:03 -0800
commit0634684495f709fe3594fdcd483cfce7933e54eb (patch)
tree7a5d99705475a885b0d22169a56678a399133d12
parent5248a0254a48382d06ecb190c9f87c0ab62ff534 (diff)
Support SPIR-V deferred linking option (#6500)
The new option "SkipDownstreamLinking" will defer final downstream IR linking to the user application. This option only has an effect if there are modules that were precompiled to the target IR using precompileForTarget(). Until now, the default behavior for SPIR-V was to use deferred linking, and the default behavior for DXIL was to use immediate/internal linking in Slang. This change only affects the SPIR-V behavior such that both deferred and non-deferred linking is supported based on the new option. To support the non-deferred option, Slang will internally call into SPIRV-Tools-link to reconstitute a complete SPIR-V shader program when necessary (due to modules having been precompiled to target IR). Otherwise, if SkipDownstreamLinking is enabled, the shader returned by e.g. getTargetCode() or getEntryPointCode() may have import linkage to the SPIR-V embedded in the constituent modules. Closes #4994 Co-authored-by: slangbot <186143334+slangbot@users.noreply.github.com>
-rw-r--r--include/slang-gfx.h9
-rw-r--r--include/slang.h3
-rw-r--r--source/compiler-core/slang-downstream-compiler.h13
-rw-r--r--source/compiler-core/slang-glslang-compiler.cpp41
-rw-r--r--source/compiler-core/slang-glslang-compiler.h8
-rw-r--r--source/slang-glslang/slang-glslang.cpp2
-rw-r--r--source/slang/slang-compiler.cpp6
-rw-r--r--source/slang/slang-compiler.h9
-rw-r--r--source/slang/slang-emit.cpp63
-rw-r--r--tools/gfx-unit-test/gfx-test-util.cpp11
-rw-r--r--tools/gfx-unit-test/gfx-test-util.h10
-rw-r--r--tools/gfx-unit-test/precompiled-module-2.cpp65
-rw-r--r--tools/gfx/renderer-shared.cpp10
-rw-r--r--tools/gfx/vulkan/vk-device.h3
-rw-r--r--tools/gfx/vulkan/vk-shader-program.cpp16
15 files changed, 245 insertions, 24 deletions
diff --git a/include/slang-gfx.h b/include/slang-gfx.h
index 6f46fed33..db9dcbacb 100644
--- a/include/slang-gfx.h
+++ b/include/slang-gfx.h
@@ -163,6 +163,12 @@ public:
SeparateEntryPointCompilation
};
+ enum class DownstreamLinkMode
+ {
+ None,
+ Deferred,
+ };
+
struct Desc
{
// TODO: Tess doesn't like this but doesn't know what to do about it
@@ -180,6 +186,9 @@ public:
// An array of Slang entry points. The size of the array must be `entryPointCount`.
// Each element must define only 1 Slang EntryPoint.
slang::IComponentType** slangEntryPoints = nullptr;
+
+ // Indicates whether the app is responsible for final downstream linking.
+ DownstreamLinkMode downstreamLinkMode = DownstreamLinkMode::None;
};
struct CreateDesc2
diff --git a/include/slang.h b/include/slang.h
index 66fd317c6..d000dab9f 100644
--- a/include/slang.h
+++ b/include/slang.h
@@ -653,6 +653,7 @@ typedef uint32_t SlangSizeT;
SLANG_PASS_THROUGH_SPIRV_OPT, ///< SPIRV-opt
SLANG_PASS_THROUGH_METAL, ///< Metal compiler
SLANG_PASS_THROUGH_TINT, ///< Tint WGSL compiler
+ SLANG_PASS_THROUGH_SPIRV_LINK, ///< SPIRV-link
SLANG_PASS_THROUGH_COUNT_OF,
};
@@ -1008,6 +1009,8 @@ typedef uint32_t SlangSizeT;
EmitReflectionJSON, // bool
SaveGLSLModuleBinSource,
+
+ SkipDownstreamLinking, // bool, experimental
CountOf,
};
diff --git a/source/compiler-core/slang-downstream-compiler.h b/source/compiler-core/slang-downstream-compiler.h
index 82aaef107..c96003cc4 100644
--- a/source/compiler-core/slang-downstream-compiler.h
+++ b/source/compiler-core/slang-downstream-compiler.h
@@ -343,6 +343,19 @@ public:
/// True if underlying compiler uses file system to communicate source
virtual SLANG_NO_THROW bool SLANG_MCALL isFileBased() = 0;
+
+ virtual SLANG_NO_THROW int SLANG_MCALL link(
+ const uint32_t** modules,
+ const uint32_t* moduleSizes,
+ const uint32_t moduleCount,
+ IArtifact** outArtifact)
+ {
+ SLANG_UNREFERENCED_PARAMETER(modules);
+ SLANG_UNREFERENCED_PARAMETER(moduleSizes);
+ SLANG_UNREFERENCED_PARAMETER(moduleCount);
+ SLANG_UNREFERENCED_PARAMETER(outArtifact);
+ return 0;
+ }
};
class DownstreamCompilerBase : public ComBaseObject, public IDownstreamCompiler
diff --git a/source/compiler-core/slang-glslang-compiler.cpp b/source/compiler-core/slang-glslang-compiler.cpp
index b619f468f..540b437c5 100644
--- a/source/compiler-core/slang-glslang-compiler.cpp
+++ b/source/compiler-core/slang-glslang-compiler.cpp
@@ -49,6 +49,11 @@ public:
validate(const uint32_t* contents, int contentsSize) SLANG_OVERRIDE;
virtual SLANG_NO_THROW SlangResult SLANG_MCALL
disassemble(const uint32_t* contents, int contentsSize) SLANG_OVERRIDE;
+ int link(
+ const uint32_t** modules,
+ const uint32_t* moduleSizes,
+ const uint32_t moduleCount,
+ IArtifact** outArtifact) SLANG_OVERRIDE;
/// Must be called before use
SlangResult init(ISlangSharedLibrary* library);
@@ -66,6 +71,7 @@ protected:
glslang_CompileFunc_1_2 m_compile_1_2 = nullptr;
glslang_ValidateSPIRVFunc m_validate = nullptr;
glslang_DisassembleSPIRVFunc m_disassemble = nullptr;
+ glslang_LinkSPIRVFunc m_link = nullptr;
ComPtr<ISlangSharedLibrary> m_sharedLibrary;
@@ -80,6 +86,7 @@ SlangResult GlslangDownstreamCompiler::init(ISlangSharedLibrary* library)
m_validate = (glslang_ValidateSPIRVFunc)library->findFuncByName("glslang_validateSPIRV");
m_disassemble =
(glslang_DisassembleSPIRVFunc)library->findFuncByName("glslang_disassembleSPIRV");
+ m_link = (glslang_LinkSPIRVFunc)library->findFuncByName("glslang_linkSPIRV");
if (m_compile_1_0 == nullptr && m_compile_1_1 == nullptr && m_compile_1_2 == nullptr)
{
@@ -323,6 +330,32 @@ SlangResult GlslangDownstreamCompiler::disassemble(const uint32_t* contents, int
return SLANG_FAIL;
}
+SlangResult GlslangDownstreamCompiler::link(
+ const uint32_t** modules,
+ const uint32_t* moduleSizes,
+ const uint32_t moduleCount,
+ IArtifact** outArtifact)
+{
+ glslang_LinkRequest request;
+ memset(&request, 0, sizeof(request));
+
+ request.modules = modules;
+ request.moduleSizes = moduleSizes;
+ request.moduleCount = moduleCount;
+
+ if (!m_link(&request))
+ {
+ return SLANG_FAIL;
+ }
+
+ auto artifact = ArtifactUtil::createArtifactForCompileTarget(SLANG_SPIRV);
+ artifact->addRepresentationUnknown(
+ Slang::RawBlob::create(request.linkResult, request.linkResultSize * sizeof(uint32_t)));
+
+ *outArtifact = artifact.detach();
+ return SLANG_OK;
+}
+
bool GlslangDownstreamCompiler::canConvert(const ArtifactDesc& from, const ArtifactDesc& to)
{
// Can only disassemble blobs that are SPIR-V
@@ -467,6 +500,14 @@ SlangResult SpirvDisDownstreamCompilerUtil::locateCompilers(
return locateGlslangSpirvDownstreamCompiler(path, loader, set, SLANG_PASS_THROUGH_SPIRV_DIS);
}
+SlangResult SpirvLinkDownstreamCompilerUtil::locateCompilers(
+ const String& path,
+ ISlangSharedLibraryLoader* loader,
+ DownstreamCompilerSet* set)
+{
+ return locateGlslangSpirvDownstreamCompiler(path, loader, set, SLANG_PASS_THROUGH_SPIRV_LINK);
+}
+
#else // SLANG_ENABLE_GLSLANG_SUPPORT
/* static */ SlangResult GlslangDownstreamCompilerUtil::locateCompilers(
diff --git a/source/compiler-core/slang-glslang-compiler.h b/source/compiler-core/slang-glslang-compiler.h
index 73cc61135..d56ad7114 100644
--- a/source/compiler-core/slang-glslang-compiler.h
+++ b/source/compiler-core/slang-glslang-compiler.h
@@ -32,6 +32,14 @@ struct SpirvDisDownstreamCompilerUtil
DownstreamCompilerSet* set);
};
+struct SpirvLinkDownstreamCompilerUtil
+{
+ static SlangResult locateCompilers(
+ const String& path,
+ ISlangSharedLibraryLoader* loader,
+ DownstreamCompilerSet* set);
+};
+
} // namespace Slang
#endif
diff --git a/source/slang-glslang/slang-glslang.cpp b/source/slang-glslang/slang-glslang.cpp
index bbb3f6afc..b3370d803 100644
--- a/source/slang-glslang/slang-glslang.cpp
+++ b/source/slang-glslang/slang-glslang.cpp
@@ -1037,7 +1037,7 @@ extern "C"
request->linkResultSize = linkedBinary.size();
}
- return success;
+ return success == SPV_SUCCESS;
}
catch (...)
{
diff --git a/source/slang/slang-compiler.cpp b/source/slang/slang-compiler.cpp
index 58cc55e71..3839e0722 100644
--- a/source/slang/slang-compiler.cpp
+++ b/source/slang/slang-compiler.cpp
@@ -2669,6 +2669,12 @@ bool CodeGenContext::shouldDumpIR()
return getTargetProgram()->getOptionSet().getBoolOption(CompilerOptionName::DumpIr);
}
+bool CodeGenContext::shouldSkipDownstreamLinking()
+{
+ return getTargetProgram()->getOptionSet().getBoolOption(
+ CompilerOptionName::SkipDownstreamLinking);
+}
+
bool CodeGenContext::shouldReportCheckpointIntermediates()
{
return getTargetProgram()->getOptionSet().getBoolOption(
diff --git a/source/slang/slang-compiler.h b/source/slang/slang-compiler.h
index 10da32400..cfcbe816f 100644
--- a/source/slang/slang-compiler.h
+++ b/source/slang/slang-compiler.h
@@ -1384,7 +1384,8 @@ enum class PassThroughMode : SlangPassThroughIntegral
LLVM = SLANG_PASS_THROUGH_LLVM, ///< LLVM 'compiler'
SpirvOpt = SLANG_PASS_THROUGH_SPIRV_OPT, ///< pass thorugh spirv to spirv-opt
MetalC = SLANG_PASS_THROUGH_METAL,
- Tint = SLANG_PASS_THROUGH_TINT, ///< pass through spirv to Tint API
+ Tint = SLANG_PASS_THROUGH_TINT, ///< pass through spirv to Tint API
+ SpirvLink = SLANG_PASS_THROUGH_SPIRV_LINK, ///< pass through spirv to spirv-link
CountOf = SLANG_PASS_THROUGH_COUNT_OF,
};
void printDiagnosticArg(StringBuilder& sb, PassThroughMode val);
@@ -2886,6 +2887,12 @@ public:
// removed between IR linking and target source generation.
bool removeAvailableInDownstreamIR = false;
+ // Determines if program level compilation like getTargetCode() or getEntryPointCode()
+ // should return a fully linked downstream program or just the glue SPIR-V/DXIL that
+ // imports and uses the precompiled SPIR-V/DXIL from constituent modules.
+ // This is a no-op if modules are not precompiled.
+ bool shouldSkipDownstreamLinking();
+
protected:
CodeGenTarget m_targetFormat = CodeGenTarget::Unknown;
ExtensionTracker* m_extensionTracker = nullptr;
diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp
index ddb4ea67a..94ea66d71 100644
--- a/source/slang/slang-emit.cpp
+++ b/source/slang/slang-emit.cpp
@@ -2093,10 +2093,71 @@ SlangResult emitSPIRVForEntryPointsDirectly(
if (compiler)
{
#if 0
- // Dump the unoptimized SPIRV after lowering from slang IR -> SPIRV
+ // Dump the unoptimized/unlinked SPIRV after lowering from slang IR -> SPIRV
compiler->disassemble((uint32_t*)spirv.getBuffer(), int(spirv.getCount() / 4));
#endif
+ bool isPrecompilation = codeGenContext->getTargetProgram()->getOptionSet().getBoolOption(
+ CompilerOptionName::EmbedDownstreamIR);
+
+ if (!isPrecompilation && !codeGenContext->shouldSkipDownstreamLinking())
+ {
+ ComPtr<IArtifact> linkedArtifact;
+
+ // collect spirv files
+ List<uint32_t*> spirvFiles;
+ List<uint32_t> spirvSizes;
+
+ // Start with the SPIR-V we just generated.
+ // SPIRV-Tools-link expects the size in 32-bit words
+ // whereas the spirv blob size is in bytes.
+ spirvFiles.add((uint32_t*)spirv.getBuffer());
+ spirvSizes.add(int(spirv.getCount()) / 4);
+
+ // Iterate over all modules in the linkedIR. For each module, if it
+ // contains an embedded downstream ir instruction, add it to the list
+ // of spirv files.
+ auto program = codeGenContext->getProgram();
+
+ program->enumerateIRModules(
+ [&](IRModule* irModule)
+ {
+ for (auto globalInst : irModule->getModuleInst()->getChildren())
+ {
+ if (auto inst = as<IREmbeddedDownstreamIR>(globalInst))
+ {
+ if (inst->getTarget() == CodeGenTarget::SPIRV)
+ {
+ auto slice = inst->getBlob()->getStringSlice();
+ spirvFiles.add((uint32_t*)slice.begin());
+ spirvSizes.add(int(slice.getLength()) / 4);
+ }
+ }
+ }
+ });
+
+ SLANG_ASSERT(int(spirv.getCount()) % 4 == 0);
+ SLANG_ASSERT(spirvFiles.getCount() == spirvSizes.getCount());
+
+ if (spirvFiles.getCount() > 1)
+ {
+ SlangResult linkresult = compiler->link(
+ (const uint32_t**)spirvFiles.getBuffer(),
+ (const uint32_t*)spirvSizes.getBuffer(),
+ (uint32_t)spirvFiles.getCount(),
+ linkedArtifact.writeRef());
+
+ if (linkresult != SLANG_OK)
+ {
+ return SLANG_FAIL;
+ }
+
+ ComPtr<ISlangBlob> blob;
+ linkedArtifact->loadBlob(ArtifactKeep::No, blob.writeRef());
+ artifact = _Move(linkedArtifact);
+ }
+ }
+
if (!codeGenContext->shouldSkipSPIRVValidation())
{
StringBuilder runSpirvValEnvVar;
diff --git a/tools/gfx-unit-test/gfx-test-util.cpp b/tools/gfx-unit-test/gfx-test-util.cpp
index 2bbe65416..d7cfa0f65 100644
--- a/tools/gfx-unit-test/gfx-test-util.cpp
+++ b/tools/gfx-unit-test/gfx-test-util.cpp
@@ -80,7 +80,8 @@ Slang::Result loadComputeProgram(
Slang::ComPtr<gfx::IShaderProgram>& outShaderProgram,
const char* shaderModuleName,
const char* entryPointName,
- slang::ProgramLayout*& slangReflection)
+ slang::ProgramLayout*& slangReflection,
+ PrecompilationMode precompilationMode)
{
Slang::ComPtr<slang::IBlob> diagnosticsBlob;
slang::IModule* module = slangSession->loadModule(shaderModuleName, diagnosticsBlob.writeRef());
@@ -115,6 +116,14 @@ Slang::Result loadComputeProgram(
gfx::IShaderProgram::Desc programDesc = {};
programDesc.slangGlobalScope = composedProgram.get();
+ if (precompilationMode == PrecompilationMode::ExternalLink)
+ {
+ programDesc.downstreamLinkMode = gfx::IShaderProgram::DownstreamLinkMode::Deferred;
+ }
+ else
+ {
+ programDesc.downstreamLinkMode = gfx::IShaderProgram::DownstreamLinkMode::None;
+ }
auto shaderProgram = device->createProgram(programDesc);
diff --git a/tools/gfx-unit-test/gfx-test-util.h b/tools/gfx-unit-test/gfx-test-util.h
index 558670162..3397584fc 100644
--- a/tools/gfx-unit-test/gfx-test-util.h
+++ b/tools/gfx-unit-test/gfx-test-util.h
@@ -7,6 +7,13 @@
namespace gfx_test
{
+enum class PrecompilationMode
+{
+ None,
+ SlangIR,
+ InternalLink,
+ ExternalLink,
+};
/// Helper function for print out diagnostic messages output by Slang compiler.
void diagnoseIfNeeded(slang::IBlob* diagnosticsBlob);
@@ -24,7 +31,8 @@ Slang::Result loadComputeProgram(
Slang::ComPtr<gfx::IShaderProgram>& outShaderProgram,
const char* shaderModuleName,
const char* entryPointName,
- slang::ProgramLayout*& slangReflection);
+ slang::ProgramLayout*& slangReflection,
+ PrecompilationMode precompilationMode = PrecompilationMode::None);
Slang::Result loadComputeProgramFromSource(
gfx::IDevice* device,
diff --git a/tools/gfx-unit-test/precompiled-module-2.cpp b/tools/gfx-unit-test/precompiled-module-2.cpp
index ca9f8b565..792f328b0 100644
--- a/tools/gfx-unit-test/precompiled-module-2.cpp
+++ b/tools/gfx-unit-test/precompiled-module-2.cpp
@@ -17,7 +17,7 @@ static Slang::Result precompileProgram(
gfx::IDevice* device,
ISlangMutableFileSystem* fileSys,
const char* shaderModuleName,
- bool precompileToTarget)
+ PrecompilationMode precompilationMode)
{
Slang::ComPtr<slang::ISession> slangSession;
SLANG_RETURN_ON_FAIL(device->getSlangSession(slangSession.writeRef()));
@@ -37,7 +37,8 @@ static Slang::Result precompileProgram(
if (!module)
return SLANG_FAIL;
- if (precompileToTarget)
+ if (precompilationMode == PrecompilationMode::InternalLink ||
+ precompilationMode == PrecompilationMode::ExternalLink)
{
SlangCompileTarget target;
switch (device->getDeviceInfo().deviceType)
@@ -82,7 +83,7 @@ static Slang::Result precompileProgram(
void precompiledModule2TestImplCommon(
IDevice* device,
UnitTestContext* context,
- bool precompileToTarget)
+ PrecompilationMode precompilationMode)
{
Slang::ComPtr<ITransientResourceHeap> transientHeap;
ITransientResourceHeap::Desc transientHeapDesc = {};
@@ -100,7 +101,7 @@ void precompiledModule2TestImplCommon(
device,
memoryFileSystem.get(),
"precompiled-module-imported",
- precompileToTarget));
+ precompilationMode));
// Next, load the precompiled slang program.
Slang::ComPtr<slang::ISession> slangSession;
@@ -121,6 +122,17 @@ void precompiledModule2TestImplCommon(
}
sessionDesc.targets = &targetDesc;
sessionDesc.fileSystem = memoryFileSystem.get();
+
+ Slang::List<slang::CompilerOptionEntry> options;
+ slang::CompilerOptionEntry skipDownstreamLinkingOption;
+ skipDownstreamLinkingOption.name = slang::CompilerOptionName::SkipDownstreamLinking;
+ skipDownstreamLinkingOption.value.kind = slang::CompilerOptionValueKind::Int;
+ skipDownstreamLinkingOption.value.intValue0 =
+ precompilationMode == PrecompilationMode::ExternalLink;
+ options.add(skipDownstreamLinkingOption);
+
+ sessionDesc.compilerOptionEntries = options.getBuffer();
+ sessionDesc.compilerOptionEntryCount = options.getCount();
auto globalSession = slangSession->getGlobalSession();
globalSession->createSession(sessionDesc, slangSession.writeRef());
@@ -147,7 +159,8 @@ void precompiledModule2TestImplCommon(
shaderProgram,
"precompiled-module",
"computeMain",
- slangReflection));
+ slangReflection,
+ precompilationMode));
ComputePipelineStateDesc pipelineDesc = {};
pipelineDesc.program = shaderProgram.get();
@@ -208,12 +221,17 @@ void precompiledModule2TestImplCommon(
void precompiledModule2TestImpl(IDevice* device, UnitTestContext* context)
{
- precompiledModule2TestImplCommon(device, context, false);
+ precompiledModule2TestImplCommon(device, context, PrecompilationMode::SlangIR);
+}
+
+void precompiledTargetModule2InternalLinkTestImpl(IDevice* device, UnitTestContext* context)
+{
+ precompiledModule2TestImplCommon(device, context, PrecompilationMode::InternalLink);
}
-void precompiledTargetModule2TestImpl(IDevice* device, UnitTestContext* context)
+void precompiledTargetModule2ExternalLinkTestImpl(IDevice* device, UnitTestContext* context)
{
- precompiledModule2TestImplCommon(device, context, true);
+ precompiledModule2TestImplCommon(device, context, PrecompilationMode::ExternalLink);
}
SLANG_UNIT_TEST(precompiledModule2D3D12)
@@ -221,19 +239,42 @@ SLANG_UNIT_TEST(precompiledModule2D3D12)
runTestImpl(precompiledModule2TestImpl, unitTestContext, Slang::RenderApiFlag::D3D12);
}
-SLANG_UNIT_TEST(precompiledTargetModule2D3D12)
+SLANG_UNIT_TEST(precompiledTargetModuleInternalLink2D3D12)
{
- runTestImpl(precompiledTargetModule2TestImpl, unitTestContext, Slang::RenderApiFlag::D3D12);
+ runTestImpl(
+ precompiledTargetModule2InternalLinkTestImpl,
+ unitTestContext,
+ Slang::RenderApiFlag::D3D12);
}
+/*
+// Unavailable on D3D12/DXIL currently
+SLANG_UNIT_TEST(precompiledTargetModuleExternalLink2D3D12)
+{
+ runTestImpl(precompiledTargetModule2ExternalLinkTestImpl, unitTestContext,
+Slang::RenderApiFlag::D3D12);
+}
+*/
+
SLANG_UNIT_TEST(precompiledModule2Vulkan)
{
runTestImpl(precompiledModule2TestImpl, unitTestContext, Slang::RenderApiFlag::Vulkan);
}
-SLANG_UNIT_TEST(precompiledTargetModule2Vulkan)
+SLANG_UNIT_TEST(precompiledTargetModule2InternalLinkVulkan)
+{
+ runTestImpl(
+ precompiledTargetModule2InternalLinkTestImpl,
+ unitTestContext,
+ Slang::RenderApiFlag::Vulkan);
+}
+
+SLANG_UNIT_TEST(precompiledTargetModule2ExternalLinkVulkan)
{
- runTestImpl(precompiledTargetModule2TestImpl, unitTestContext, Slang::RenderApiFlag::Vulkan);
+ runTestImpl(
+ precompiledTargetModule2ExternalLinkTestImpl,
+ unitTestContext,
+ Slang::RenderApiFlag::Vulkan);
}
} // namespace gfx_test
diff --git a/tools/gfx/renderer-shared.cpp b/tools/gfx/renderer-shared.cpp
index ae4ea3eef..edd271ecb 100644
--- a/tools/gfx/renderer-shared.cpp
+++ b/tools/gfx/renderer-shared.cpp
@@ -1133,11 +1133,13 @@ Result ShaderProgramBase::compileShaders(RendererBase* device)
kernelCodes.add(downstreamIR);
}
- // If target precompilation was used, kernelCode may only represent the
- // glue code holding together the bits of precompiled target IR.
- // Collect those dependency target IRs too.
+ // If target precompilation with deferred downstream linking is enabled,
+ // kernelCode may only represent the glue code holding together the
+ // bits of precompiled target IR. It's the application's job to pull it
+ // together. Collect those dependency target IRs too.
ComPtr<slang::IModulePrecompileService_Experimental> componentPrecompileService;
- if (entryPointComponent->queryInterface(
+ if (this->desc.downstreamLinkMode == DownstreamLinkMode::Deferred &&
+ entryPointComponent->queryInterface(
slang::IModulePrecompileService_Experimental::getTypeGuid(),
(void**)componentPrecompileService.writeRef()) == SLANG_OK)
{
diff --git a/tools/gfx/vulkan/vk-device.h b/tools/gfx/vulkan/vk-device.h
index 27cc4c3ce..06e19ad7e 100644
--- a/tools/gfx/vulkan/vk-device.h
+++ b/tools/gfx/vulkan/vk-device.h
@@ -223,6 +223,9 @@ public:
VkSampler m_defaultSampler;
RefPtr<FramebufferImpl> m_emptyFramebuffer;
+
+ // If true, slang will skip downstream linking, so we need to do it ourselves
+ bool m_skipsDownstreamLinking = false;
};
} // namespace vk
diff --git a/tools/gfx/vulkan/vk-shader-program.cpp b/tools/gfx/vulkan/vk-shader-program.cpp
index 1627c95a7..2dd9b0326 100644
--- a/tools/gfx/vulkan/vk-shader-program.cpp
+++ b/tools/gfx/vulkan/vk-shader-program.cpp
@@ -74,10 +74,20 @@ Result ShaderProgramImpl::createShaderModule(
slang::EntryPointReflection* entryPointInfo,
List<ComPtr<ISlangBlob>>& kernelCodes)
{
- ComPtr<ISlangBlob> linkedKernel = m_device->m_glslang.linkSPIRV(kernelCodes);
- if (!linkedKernel)
+ ComPtr<ISlangBlob> linkedKernel;
+ ComPtr<slang::ISession> slangSession;
+ m_device->getSlangSession(slangSession.writeRef());
+ if (kernelCodes.getCount() == 1)
{
- return SLANG_FAIL;
+ linkedKernel = kernelCodes[0];
+ }
+ else
+ {
+ linkedKernel = m_device->m_glslang.linkSPIRV(kernelCodes);
+ if (!linkedKernel)
+ {
+ return SLANG_FAIL;
+ }
}
m_codeBlobs.add(linkedKernel);