summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--include/slang-gfx.h9
-rw-r--r--include/slang.h3
-rw-r--r--source/compiler-core/slang-downstream-compiler.h13
-rw-r--r--source/compiler-core/slang-glslang-compiler.cpp41
-rw-r--r--source/compiler-core/slang-glslang-compiler.h8
-rw-r--r--source/slang-glslang/slang-glslang.cpp2
-rw-r--r--source/slang/slang-compiler.cpp6
-rw-r--r--source/slang/slang-compiler.h9
-rw-r--r--source/slang/slang-emit.cpp63
-rw-r--r--tools/gfx-unit-test/gfx-test-util.cpp11
-rw-r--r--tools/gfx-unit-test/gfx-test-util.h10
-rw-r--r--tools/gfx-unit-test/precompiled-module-2.cpp65
-rw-r--r--tools/gfx/renderer-shared.cpp10
-rw-r--r--tools/gfx/vulkan/vk-device.h3
-rw-r--r--tools/gfx/vulkan/vk-shader-program.cpp16
15 files changed, 245 insertions, 24 deletions
diff --git a/include/slang-gfx.h b/include/slang-gfx.h
index 6f46fed33..db9dcbacb 100644
--- a/include/slang-gfx.h
+++ b/include/slang-gfx.h
@@ -163,6 +163,12 @@ public:
SeparateEntryPointCompilation
};
+ enum class DownstreamLinkMode
+ {
+ None,
+ Deferred,
+ };
+
struct Desc
{
// TODO: Tess doesn't like this but doesn't know what to do about it
@@ -180,6 +186,9 @@ public:
// An array of Slang entry points. The size of the array must be `entryPointCount`.
// Each element must define only 1 Slang EntryPoint.
slang::IComponentType** slangEntryPoints = nullptr;
+
+ // Indicates whether the app is responsible for final downstream linking.
+ DownstreamLinkMode downstreamLinkMode = DownstreamLinkMode::None;
};
struct CreateDesc2
diff --git a/include/slang.h b/include/slang.h
index 66fd317c6..d000dab9f 100644
--- a/include/slang.h
+++ b/include/slang.h
@@ -653,6 +653,7 @@ typedef uint32_t SlangSizeT;
SLANG_PASS_THROUGH_SPIRV_OPT, ///< SPIRV-opt
SLANG_PASS_THROUGH_METAL, ///< Metal compiler
SLANG_PASS_THROUGH_TINT, ///< Tint WGSL compiler
+ SLANG_PASS_THROUGH_SPIRV_LINK, ///< SPIRV-link
SLANG_PASS_THROUGH_COUNT_OF,
};
@@ -1008,6 +1009,8 @@ typedef uint32_t SlangSizeT;
EmitReflectionJSON, // bool
SaveGLSLModuleBinSource,
+
+ SkipDownstreamLinking, // bool, experimental
CountOf,
};
diff --git a/source/compiler-core/slang-downstream-compiler.h b/source/compiler-core/slang-downstream-compiler.h
index 82aaef107..c96003cc4 100644
--- a/source/compiler-core/slang-downstream-compiler.h
+++ b/source/compiler-core/slang-downstream-compiler.h
@@ -343,6 +343,19 @@ public:
/// True if underlying compiler uses file system to communicate source
virtual SLANG_NO_THROW bool SLANG_MCALL isFileBased() = 0;
+
+ virtual SLANG_NO_THROW int SLANG_MCALL link(
+ const uint32_t** modules,
+ const uint32_t* moduleSizes,
+ const uint32_t moduleCount,
+ IArtifact** outArtifact)
+ {
+ SLANG_UNREFERENCED_PARAMETER(modules);
+ SLANG_UNREFERENCED_PARAMETER(moduleSizes);
+ SLANG_UNREFERENCED_PARAMETER(moduleCount);
+ SLANG_UNREFERENCED_PARAMETER(outArtifact);
+ return 0;
+ }
};
class DownstreamCompilerBase : public ComBaseObject, public IDownstreamCompiler
diff --git a/source/compiler-core/slang-glslang-compiler.cpp b/source/compiler-core/slang-glslang-compiler.cpp
index b619f468f..540b437c5 100644
--- a/source/compiler-core/slang-glslang-compiler.cpp
+++ b/source/compiler-core/slang-glslang-compiler.cpp
@@ -49,6 +49,11 @@ public:
validate(const uint32_t* contents, int contentsSize) SLANG_OVERRIDE;
virtual SLANG_NO_THROW SlangResult SLANG_MCALL
disassemble(const uint32_t* contents, int contentsSize) SLANG_OVERRIDE;
+ int link(
+ const uint32_t** modules,
+ const uint32_t* moduleSizes,
+ const uint32_t moduleCount,
+ IArtifact** outArtifact) SLANG_OVERRIDE;
/// Must be called before use
SlangResult init(ISlangSharedLibrary* library);
@@ -66,6 +71,7 @@ protected:
glslang_CompileFunc_1_2 m_compile_1_2 = nullptr;
glslang_ValidateSPIRVFunc m_validate = nullptr;
glslang_DisassembleSPIRVFunc m_disassemble = nullptr;
+ glslang_LinkSPIRVFunc m_link = nullptr;
ComPtr<ISlangSharedLibrary> m_sharedLibrary;
@@ -80,6 +86,7 @@ SlangResult GlslangDownstreamCompiler::init(ISlangSharedLibrary* library)
m_validate = (glslang_ValidateSPIRVFunc)library->findFuncByName("glslang_validateSPIRV");
m_disassemble =
(glslang_DisassembleSPIRVFunc)library->findFuncByName("glslang_disassembleSPIRV");
+ m_link = (glslang_LinkSPIRVFunc)library->findFuncByName("glslang_linkSPIRV");
if (m_compile_1_0 == nullptr && m_compile_1_1 == nullptr && m_compile_1_2 == nullptr)
{
@@ -323,6 +330,32 @@ SlangResult GlslangDownstreamCompiler::disassemble(const uint32_t* contents, int
return SLANG_FAIL;
}
+SlangResult GlslangDownstreamCompiler::link(
+ const uint32_t** modules,
+ const uint32_t* moduleSizes,
+ const uint32_t moduleCount,
+ IArtifact** outArtifact)
+{
+ glslang_LinkRequest request;
+ memset(&request, 0, sizeof(request));
+
+ request.modules = modules;
+ request.moduleSizes = moduleSizes;
+ request.moduleCount = moduleCount;
+
+ if (!m_link(&request))
+ {
+ return SLANG_FAIL;
+ }
+
+ auto artifact = ArtifactUtil::createArtifactForCompileTarget(SLANG_SPIRV);
+ artifact->addRepresentationUnknown(
+ Slang::RawBlob::create(request.linkResult, request.linkResultSize * sizeof(uint32_t)));
+
+ *outArtifact = artifact.detach();
+ return SLANG_OK;
+}
+
bool GlslangDownstreamCompiler::canConvert(const ArtifactDesc& from, const ArtifactDesc& to)
{
// Can only disassemble blobs that are SPIR-V
@@ -467,6 +500,14 @@ SlangResult SpirvDisDownstreamCompilerUtil::locateCompilers(
return locateGlslangSpirvDownstreamCompiler(path, loader, set, SLANG_PASS_THROUGH_SPIRV_DIS);
}
+SlangResult SpirvLinkDownstreamCompilerUtil::locateCompilers(
+ const String& path,
+ ISlangSharedLibraryLoader* loader,
+ DownstreamCompilerSet* set)
+{
+ return locateGlslangSpirvDownstreamCompiler(path, loader, set, SLANG_PASS_THROUGH_SPIRV_LINK);
+}
+
#else // SLANG_ENABLE_GLSLANG_SUPPORT
/* static */ SlangResult GlslangDownstreamCompilerUtil::locateCompilers(
diff --git a/source/compiler-core/slang-glslang-compiler.h b/source/compiler-core/slang-glslang-compiler.h
index 73cc61135..d56ad7114 100644
--- a/source/compiler-core/slang-glslang-compiler.h
+++ b/source/compiler-core/slang-glslang-compiler.h
@@ -32,6 +32,14 @@ struct SpirvDisDownstreamCompilerUtil
DownstreamCompilerSet* set);
};
+struct SpirvLinkDownstreamCompilerUtil
+{
+ static SlangResult locateCompilers(
+ const String& path,
+ ISlangSharedLibraryLoader* loader,
+ DownstreamCompilerSet* set);
+};
+
} // namespace Slang
#endif
diff --git a/source/slang-glslang/slang-glslang.cpp b/source/slang-glslang/slang-glslang.cpp
index bbb3f6afc..b3370d803 100644
--- a/source/slang-glslang/slang-glslang.cpp
+++ b/source/slang-glslang/slang-glslang.cpp
@@ -1037,7 +1037,7 @@ extern "C"
request->linkResultSize = linkedBinary.size();
}
- return success;
+ return success == SPV_SUCCESS;
}
catch (...)
{
diff --git a/source/slang/slang-compiler.cpp b/source/slang/slang-compiler.cpp
index 58cc55e71..3839e0722 100644
--- a/source/slang/slang-compiler.cpp
+++ b/source/slang/slang-compiler.cpp
@@ -2669,6 +2669,12 @@ bool CodeGenContext::shouldDumpIR()
return getTargetProgram()->getOptionSet().getBoolOption(CompilerOptionName::DumpIr);
}
+bool CodeGenContext::shouldSkipDownstreamLinking()
+{
+ return getTargetProgram()->getOptionSet().getBoolOption(
+ CompilerOptionName::SkipDownstreamLinking);
+}
+
bool CodeGenContext::shouldReportCheckpointIntermediates()
{
return getTargetProgram()->getOptionSet().getBoolOption(
diff --git a/source/slang/slang-compiler.h b/source/slang/slang-compiler.h
index 10da32400..cfcbe816f 100644
--- a/source/slang/slang-compiler.h
+++ b/source/slang/slang-compiler.h
@@ -1384,7 +1384,8 @@ enum class PassThroughMode : SlangPassThroughIntegral
LLVM = SLANG_PASS_THROUGH_LLVM, ///< LLVM 'compiler'
SpirvOpt = SLANG_PASS_THROUGH_SPIRV_OPT, ///< pass thorugh spirv to spirv-opt
MetalC = SLANG_PASS_THROUGH_METAL,
- Tint = SLANG_PASS_THROUGH_TINT, ///< pass through spirv to Tint API
+ Tint = SLANG_PASS_THROUGH_TINT, ///< pass through spirv to Tint API
+ SpirvLink = SLANG_PASS_THROUGH_SPIRV_LINK, ///< pass through spirv to spirv-link
CountOf = SLANG_PASS_THROUGH_COUNT_OF,
};
void printDiagnosticArg(StringBuilder& sb, PassThroughMode val);
@@ -2886,6 +2887,12 @@ public:
// removed between IR linking and target source generation.
bool removeAvailableInDownstreamIR = false;
+ // Determines if program level compilation like getTargetCode() or getEntryPointCode()
+ // should return a fully linked downstream program or just the glue SPIR-V/DXIL that
+ // imports and uses the precompiled SPIR-V/DXIL from constituent modules.
+ // This is a no-op if modules are not precompiled.
+ bool shouldSkipDownstreamLinking();
+
protected:
CodeGenTarget m_targetFormat = CodeGenTarget::Unknown;
ExtensionTracker* m_extensionTracker = nullptr;
diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp
index ddb4ea67a..94ea66d71 100644
--- a/source/slang/slang-emit.cpp
+++ b/source/slang/slang-emit.cpp
@@ -2093,10 +2093,71 @@ SlangResult emitSPIRVForEntryPointsDirectly(
if (compiler)
{
#if 0
- // Dump the unoptimized SPIRV after lowering from slang IR -> SPIRV
+ // Dump the unoptimized/unlinked SPIRV after lowering from slang IR -> SPIRV
compiler->disassemble((uint32_t*)spirv.getBuffer(), int(spirv.getCount() / 4));
#endif
+ bool isPrecompilation = codeGenContext->getTargetProgram()->getOptionSet().getBoolOption(
+ CompilerOptionName::EmbedDownstreamIR);
+
+ if (!isPrecompilation && !codeGenContext->shouldSkipDownstreamLinking())
+ {
+ ComPtr<IArtifact> linkedArtifact;
+
+ // collect spirv files
+ List<uint32_t*> spirvFiles;
+ List<uint32_t> spirvSizes;
+
+ // Start with the SPIR-V we just generated.
+ // SPIRV-Tools-link expects the size in 32-bit words
+ // whereas the spirv blob size is in bytes.
+ spirvFiles.add((uint32_t*)spirv.getBuffer());
+ spirvSizes.add(int(spirv.getCount()) / 4);
+
+ // Iterate over all modules in the linkedIR. For each module, if it
+ // contains an embedded downstream ir instruction, add it to the list
+ // of spirv files.
+ auto program = codeGenContext->getProgram();
+
+ program->enumerateIRModules(
+ [&](IRModule* irModule)
+ {
+ for (auto globalInst : irModule->getModuleInst()->getChildren())
+ {
+ if (auto inst = as<IREmbeddedDownstreamIR>(globalInst))
+ {
+ if (inst->getTarget() == CodeGenTarget::SPIRV)
+ {
+ auto slice = inst->getBlob()->getStringSlice();
+ spirvFiles.add((uint32_t*)slice.begin());
+ spirvSizes.add(int(slice.getLength()) / 4);
+ }
+ }
+ }
+ });
+
+ SLANG_ASSERT(int(spirv.getCount()) % 4 == 0);
+ SLANG_ASSERT(spirvFiles.getCount() == spirvSizes.getCount());
+
+ if (spirvFiles.getCount() > 1)
+ {
+ SlangResult linkresult = compiler->link(
+ (const uint32_t**)spirvFiles.getBuffer(),
+ (const uint32_t*)spirvSizes.getBuffer(),
+ (uint32_t)spirvFiles.getCount(),
+ linkedArtifact.writeRef());
+
+ if (linkresult != SLANG_OK)
+ {
+ return SLANG_FAIL;
+ }
+
+ ComPtr<ISlangBlob> blob;
+ linkedArtifact->loadBlob(ArtifactKeep::No, blob.writeRef());
+ artifact = _Move(linkedArtifact);
+ }
+ }
+
if (!codeGenContext->shouldSkipSPIRVValidation())
{
StringBuilder runSpirvValEnvVar;
diff --git a/tools/gfx-unit-test/gfx-test-util.cpp b/tools/gfx-unit-test/gfx-test-util.cpp
index 2bbe65416..d7cfa0f65 100644
--- a/tools/gfx-unit-test/gfx-test-util.cpp
+++ b/tools/gfx-unit-test/gfx-test-util.cpp
@@ -80,7 +80,8 @@ Slang::Result loadComputeProgram(
Slang::ComPtr<gfx::IShaderProgram>& outShaderProgram,
const char* shaderModuleName,
const char* entryPointName,
- slang::ProgramLayout*& slangReflection)
+ slang::ProgramLayout*& slangReflection,
+ PrecompilationMode precompilationMode)
{
Slang::ComPtr<slang::IBlob> diagnosticsBlob;
slang::IModule* module = slangSession->loadModule(shaderModuleName, diagnosticsBlob.writeRef());
@@ -115,6 +116,14 @@ Slang::Result loadComputeProgram(
gfx::IShaderProgram::Desc programDesc = {};
programDesc.slangGlobalScope = composedProgram.get();
+ if (precompilationMode == PrecompilationMode::ExternalLink)
+ {
+ programDesc.downstreamLinkMode = gfx::IShaderProgram::DownstreamLinkMode::Deferred;
+ }
+ else
+ {
+ programDesc.downstreamLinkMode = gfx::IShaderProgram::DownstreamLinkMode::None;
+ }
auto shaderProgram = device->createProgram(programDesc);
diff --git a/tools/gfx-unit-test/gfx-test-util.h b/tools/gfx-unit-test/gfx-test-util.h
index 558670162..3397584fc 100644
--- a/tools/gfx-unit-test/gfx-test-util.h
+++ b/tools/gfx-unit-test/gfx-test-util.h
@@ -7,6 +7,13 @@
namespace gfx_test
{
+enum class PrecompilationMode
+{
+ None,
+ SlangIR,
+ InternalLink,
+ ExternalLink,
+};
/// Helper function for print out diagnostic messages output by Slang compiler.
void diagnoseIfNeeded(slang::IBlob* diagnosticsBlob);
@@ -24,7 +31,8 @@ Slang::Result loadComputeProgram(
Slang::ComPtr<gfx::IShaderProgram>& outShaderProgram,
const char* shaderModuleName,
const char* entryPointName,
- slang::ProgramLayout*& slangReflection);
+ slang::ProgramLayout*& slangReflection,
+ PrecompilationMode precompilationMode = PrecompilationMode::None);
Slang::Result loadComputeProgramFromSource(
gfx::IDevice* device,
diff --git a/tools/gfx-unit-test/precompiled-module-2.cpp b/tools/gfx-unit-test/precompiled-module-2.cpp
index ca9f8b565..792f328b0 100644
--- a/tools/gfx-unit-test/precompiled-module-2.cpp
+++ b/tools/gfx-unit-test/precompiled-module-2.cpp
@@ -17,7 +17,7 @@ static Slang::Result precompileProgram(
gfx::IDevice* device,
ISlangMutableFileSystem* fileSys,
const char* shaderModuleName,
- bool precompileToTarget)
+ PrecompilationMode precompilationMode)
{
Slang::ComPtr<slang::ISession> slangSession;
SLANG_RETURN_ON_FAIL(device->getSlangSession(slangSession.writeRef()));
@@ -37,7 +37,8 @@ static Slang::Result precompileProgram(
if (!module)
return SLANG_FAIL;
- if (precompileToTarget)
+ if (precompilationMode == PrecompilationMode::InternalLink ||
+ precompilationMode == PrecompilationMode::ExternalLink)
{
SlangCompileTarget target;
switch (device->getDeviceInfo().deviceType)
@@ -82,7 +83,7 @@ static Slang::Result precompileProgram(
void precompiledModule2TestImplCommon(
IDevice* device,
UnitTestContext* context,
- bool precompileToTarget)
+ PrecompilationMode precompilationMode)
{
Slang::ComPtr<ITransientResourceHeap> transientHeap;
ITransientResourceHeap::Desc transientHeapDesc = {};
@@ -100,7 +101,7 @@ void precompiledModule2TestImplCommon(
device,
memoryFileSystem.get(),
"precompiled-module-imported",
- precompileToTarget));
+ precompilationMode));
// Next, load the precompiled slang program.
Slang::ComPtr<slang::ISession> slangSession;
@@ -121,6 +122,17 @@ void precompiledModule2TestImplCommon(
}
sessionDesc.targets = &targetDesc;
sessionDesc.fileSystem = memoryFileSystem.get();
+
+ Slang::List<slang::CompilerOptionEntry> options;
+ slang::CompilerOptionEntry skipDownstreamLinkingOption;
+ skipDownstreamLinkingOption.name = slang::CompilerOptionName::SkipDownstreamLinking;
+ skipDownstreamLinkingOption.value.kind = slang::CompilerOptionValueKind::Int;
+ skipDownstreamLinkingOption.value.intValue0 =
+ precompilationMode == PrecompilationMode::ExternalLink;
+ options.add(skipDownstreamLinkingOption);
+
+ sessionDesc.compilerOptionEntries = options.getBuffer();
+ sessionDesc.compilerOptionEntryCount = options.getCount();
auto globalSession = slangSession->getGlobalSession();
globalSession->createSession(sessionDesc, slangSession.writeRef());
@@ -147,7 +159,8 @@ void precompiledModule2TestImplCommon(
shaderProgram,
"precompiled-module",
"computeMain",
- slangReflection));
+ slangReflection,
+ precompilationMode));
ComputePipelineStateDesc pipelineDesc = {};
pipelineDesc.program = shaderProgram.get();
@@ -208,12 +221,17 @@ void precompiledModule2TestImplCommon(
void precompiledModule2TestImpl(IDevice* device, UnitTestContext* context)
{
- precompiledModule2TestImplCommon(device, context, false);
+ precompiledModule2TestImplCommon(device, context, PrecompilationMode::SlangIR);
+}
+
+void precompiledTargetModule2InternalLinkTestImpl(IDevice* device, UnitTestContext* context)
+{
+ precompiledModule2TestImplCommon(device, context, PrecompilationMode::InternalLink);
}
-void precompiledTargetModule2TestImpl(IDevice* device, UnitTestContext* context)
+void precompiledTargetModule2ExternalLinkTestImpl(IDevice* device, UnitTestContext* context)
{
- precompiledModule2TestImplCommon(device, context, true);
+ precompiledModule2TestImplCommon(device, context, PrecompilationMode::ExternalLink);
}
SLANG_UNIT_TEST(precompiledModule2D3D12)
@@ -221,19 +239,42 @@ SLANG_UNIT_TEST(precompiledModule2D3D12)
runTestImpl(precompiledModule2TestImpl, unitTestContext, Slang::RenderApiFlag::D3D12);
}
-SLANG_UNIT_TEST(precompiledTargetModule2D3D12)
+SLANG_UNIT_TEST(precompiledTargetModuleInternalLink2D3D12)
{
- runTestImpl(precompiledTargetModule2TestImpl, unitTestContext, Slang::RenderApiFlag::D3D12);
+ runTestImpl(
+ precompiledTargetModule2InternalLinkTestImpl,
+ unitTestContext,
+ Slang::RenderApiFlag::D3D12);
}
+/*
+// Unavailable on D3D12/DXIL currently
+SLANG_UNIT_TEST(precompiledTargetModuleExternalLink2D3D12)
+{
+ runTestImpl(precompiledTargetModule2ExternalLinkTestImpl, unitTestContext,
+Slang::RenderApiFlag::D3D12);
+}
+*/
+
SLANG_UNIT_TEST(precompiledModule2Vulkan)
{
runTestImpl(precompiledModule2TestImpl, unitTestContext, Slang::RenderApiFlag::Vulkan);
}
-SLANG_UNIT_TEST(precompiledTargetModule2Vulkan)
+SLANG_UNIT_TEST(precompiledTargetModule2InternalLinkVulkan)
+{
+ runTestImpl(
+ precompiledTargetModule2InternalLinkTestImpl,
+ unitTestContext,
+ Slang::RenderApiFlag::Vulkan);
+}
+
+SLANG_UNIT_TEST(precompiledTargetModule2ExternalLinkVulkan)
{
- runTestImpl(precompiledTargetModule2TestImpl, unitTestContext, Slang::RenderApiFlag::Vulkan);
+ runTestImpl(
+ precompiledTargetModule2ExternalLinkTestImpl,
+ unitTestContext,
+ Slang::RenderApiFlag::Vulkan);
}
} // namespace gfx_test
diff --git a/tools/gfx/renderer-shared.cpp b/tools/gfx/renderer-shared.cpp
index ae4ea3eef..edd271ecb 100644
--- a/tools/gfx/renderer-shared.cpp
+++ b/tools/gfx/renderer-shared.cpp
@@ -1133,11 +1133,13 @@ Result ShaderProgramBase::compileShaders(RendererBase* device)
kernelCodes.add(downstreamIR);
}
- // If target precompilation was used, kernelCode may only represent the
- // glue code holding together the bits of precompiled target IR.
- // Collect those dependency target IRs too.
+ // If target precompilation with deferred downstream linking is enabled,
+ // kernelCode may only represent the glue code holding together the
+ // bits of precompiled target IR. It's the application's job to pull it
+ // together. Collect those dependency target IRs too.
ComPtr<slang::IModulePrecompileService_Experimental> componentPrecompileService;
- if (entryPointComponent->queryInterface(
+ if (this->desc.downstreamLinkMode == DownstreamLinkMode::Deferred &&
+ entryPointComponent->queryInterface(
slang::IModulePrecompileService_Experimental::getTypeGuid(),
(void**)componentPrecompileService.writeRef()) == SLANG_OK)
{
diff --git a/tools/gfx/vulkan/vk-device.h b/tools/gfx/vulkan/vk-device.h
index 27cc4c3ce..06e19ad7e 100644
--- a/tools/gfx/vulkan/vk-device.h
+++ b/tools/gfx/vulkan/vk-device.h
@@ -223,6 +223,9 @@ public:
VkSampler m_defaultSampler;
RefPtr<FramebufferImpl> m_emptyFramebuffer;
+
+ // If true, slang will skip downstream linking, so we need to do it ourselves
+ bool m_skipsDownstreamLinking = false;
};
} // namespace vk
diff --git a/tools/gfx/vulkan/vk-shader-program.cpp b/tools/gfx/vulkan/vk-shader-program.cpp
index 1627c95a7..2dd9b0326 100644
--- a/tools/gfx/vulkan/vk-shader-program.cpp
+++ b/tools/gfx/vulkan/vk-shader-program.cpp
@@ -74,10 +74,20 @@ Result ShaderProgramImpl::createShaderModule(
slang::EntryPointReflection* entryPointInfo,
List<ComPtr<ISlangBlob>>& kernelCodes)
{
- ComPtr<ISlangBlob> linkedKernel = m_device->m_glslang.linkSPIRV(kernelCodes);
- if (!linkedKernel)
+ ComPtr<ISlangBlob> linkedKernel;
+ ComPtr<slang::ISession> slangSession;
+ m_device->getSlangSession(slangSession.writeRef());
+ if (kernelCodes.getCount() == 1)
{
- return SLANG_FAIL;
+ linkedKernel = kernelCodes[0];
+ }
+ else
+ {
+ linkedKernel = m_device->m_glslang.linkSPIRV(kernelCodes);
+ if (!linkedKernel)
+ {
+ return SLANG_FAIL;
+ }
}
m_codeBlobs.add(linkedKernel);