summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2024-02-15 00:05:51 -0800
committerGitHub <noreply@github.com>2024-02-15 00:05:51 -0800
commit5a623ec227726ad1d988a5d91f55f19b62a98e03 (patch)
tree94a3fd2f00ce1a95035f39cd3571c9e97a70d24e
parent2ced683f10fb82f63a2e2c3d7b5f099c53bb57b0 (diff)
Support loading serialized modules. (#3588)
* Support loading serialized modules. * Fix. * Fix vs solution files * Fix glsl module loading. * C++ fix. * Fix. * Try fix c++ error. * Try fix. * Fix. * Fix.
-rw-r--r--build/visual-studio/gfx-unit-test-tool/gfx-unit-test-tool.vcxproj5
-rw-r--r--build/visual-studio/gfx-unit-test-tool/gfx-unit-test-tool.vcxproj.filters15
-rw-r--r--slang.h27
-rw-r--r--source/slang/slang-ast-base.h2
-rw-r--r--source/slang/slang-ast-decl.h5
-rw-r--r--source/slang/slang-check-decl.cpp13
-rw-r--r--source/slang/slang-check-expr.cpp8
-rw-r--r--source/slang/slang-check-impl.h7
-rw-r--r--source/slang/slang-check-shader.cpp52
-rw-r--r--source/slang/slang-compiler.cpp112
-rwxr-xr-xsource/slang/slang-compiler.h61
-rw-r--r--source/slang/slang-module-library.cpp2
-rw-r--r--source/slang/slang-parser.cpp5
-rw-r--r--source/slang/slang-serialize-ast-type-info.h4
-rw-r--r--source/slang/slang-serialize-container.cpp11
-rw-r--r--source/slang/slang-serialize-container.h2
-rw-r--r--source/slang/slang-serialize-factory.cpp11
-rw-r--r--source/slang/slang-serialize-factory.h1
-rw-r--r--source/slang/slang-serialize-types.h1
-rw-r--r--source/slang/slang-serialize.cpp167
-rw-r--r--source/slang/slang-serialize.h23
-rw-r--r--source/slang/slang-syntax.h11
-rw-r--r--source/slang/slang.cpp202
-rw-r--r--tools/gfx-unit-test/gfx-test-util.cpp62
-rw-r--r--tools/gfx-unit-test/gfx-test-util.h10
-rw-r--r--tools/gfx-unit-test/precompiled-module-2.cpp181
-rw-r--r--tools/gfx-unit-test/precompiled-module-imported.slang11
-rw-r--r--tools/gfx-unit-test/precompiled-module-included.slang9
-rw-r--r--tools/gfx-unit-test/precompiled-module.cpp160
-rw-r--r--tools/gfx-unit-test/precompiled-module.slang14
30 files changed, 1063 insertions, 131 deletions
diff --git a/build/visual-studio/gfx-unit-test-tool/gfx-unit-test-tool.vcxproj b/build/visual-studio/gfx-unit-test-tool/gfx-unit-test-tool.vcxproj
index 335fa8156..df2f9b4dd 100644
--- a/build/visual-studio/gfx-unit-test-tool/gfx-unit-test-tool.vcxproj
+++ b/build/visual-studio/gfx-unit-test-tool/gfx-unit-test-tool.vcxproj
@@ -306,6 +306,8 @@
<ClCompile Include="..\..\..\tools\gfx-unit-test\link-time-constant.cpp" />
<ClCompile Include="..\..\..\tools\gfx-unit-test\mutable-shader-object.cpp" />
<ClCompile Include="..\..\..\tools\gfx-unit-test\nested-parameter-block.cpp" />
+ <ClCompile Include="..\..\..\tools\gfx-unit-test\precompiled-module-2.cpp" />
+ <ClCompile Include="..\..\..\tools\gfx-unit-test\precompiled-module.cpp" />
<ClCompile Include="..\..\..\tools\gfx-unit-test\ray-tracing-tests.cpp" />
<ClCompile Include="..\..\..\tools\gfx-unit-test\resolve-resource-tests.cpp" />
<ClCompile Include="..\..\..\tools\gfx-unit-test\root-mutable-shader-object.cpp" />
@@ -327,6 +329,9 @@
<None Include="..\..\..\tools\gfx-unit-test\link-time-constant.slang" />
<None Include="..\..\..\tools\gfx-unit-test\mutable-shader-object.slang" />
<None Include="..\..\..\tools\gfx-unit-test\nested-parameter-block.slang" />
+ <None Include="..\..\..\tools\gfx-unit-test\precompiled-module-imported.slang" />
+ <None Include="..\..\..\tools\gfx-unit-test\precompiled-module-included.slang" />
+ <None Include="..\..\..\tools\gfx-unit-test\precompiled-module.slang" />
<None Include="..\..\..\tools\gfx-unit-test\ray-tracing-test-shaders.slang" />
<None Include="..\..\..\tools\gfx-unit-test\resolve-resource-shader.slang" />
<None Include="..\..\..\tools\gfx-unit-test\root-shader-parameter.slang" />
diff --git a/build/visual-studio/gfx-unit-test-tool/gfx-unit-test-tool.vcxproj.filters b/build/visual-studio/gfx-unit-test-tool/gfx-unit-test-tool.vcxproj.filters
index fa5ca30e5..9ab4fb3c4 100644
--- a/build/visual-studio/gfx-unit-test-tool/gfx-unit-test-tool.vcxproj.filters
+++ b/build/visual-studio/gfx-unit-test-tool/gfx-unit-test-tool.vcxproj.filters
@@ -74,6 +74,12 @@
<ClCompile Include="..\..\..\tools\gfx-unit-test\nested-parameter-block.cpp">
<Filter>Source Files</Filter>
</ClCompile>
+ <ClCompile Include="..\..\..\tools\gfx-unit-test\precompiled-module-2.cpp">
+ <Filter>Source Files</Filter>
+ </ClCompile>
+ <ClCompile Include="..\..\..\tools\gfx-unit-test\precompiled-module.cpp">
+ <Filter>Source Files</Filter>
+ </ClCompile>
<ClCompile Include="..\..\..\tools\gfx-unit-test\ray-tracing-tests.cpp">
<Filter>Source Files</Filter>
</ClCompile>
@@ -133,6 +139,15 @@
<None Include="..\..\..\tools\gfx-unit-test\nested-parameter-block.slang">
<Filter>Source Files</Filter>
</None>
+ <None Include="..\..\..\tools\gfx-unit-test\precompiled-module-imported.slang">
+ <Filter>Source Files</Filter>
+ </None>
+ <None Include="..\..\..\tools\gfx-unit-test\precompiled-module-included.slang">
+ <Filter>Source Files</Filter>
+ </None>
+ <None Include="..\..\..\tools\gfx-unit-test\precompiled-module.slang">
+ <Filter>Source Files</Filter>
+ </None>
<None Include="..\..\..\tools\gfx-unit-test\ray-tracing-test-shaders.slang">
<Filter>Source Files</Filter>
</None>
diff --git a/slang.h b/slang.h
index 54466ca04..b9da7df41 100644
--- a/slang.h
+++ b/slang.h
@@ -4444,6 +4444,17 @@ namespace slang
ITypeConformance** outConformance,
SlangInt conformanceIdOverride,
ISlangBlob** outDiagnostics) = 0;
+
+ /** Load a module from a Slang module blob.
+ */
+ virtual SLANG_NO_THROW IModule* SLANG_MCALL loadModuleFromIRBlob(
+ const char* moduleName,
+ const char* path,
+ slang::IBlob* source,
+ slang::IBlob** outDiagnostics = nullptr) = 0;
+
+ virtual SLANG_NO_THROW SlangInt SLANG_MCALL getLoadedModuleCount() = 0;
+ virtual SLANG_NO_THROW IModule* SLANG_MCALL getLoadedModule(SlangInt index) = 0;
};
#define SLANG_UUID_ISession ISession::getTypeGuid()
@@ -4691,6 +4702,22 @@ namespace slang
/// Get the name of an entry point defined in the module.
virtual SLANG_NO_THROW SlangResult SLANG_MCALL
getDefinedEntryPoint(SlangInt32 index, IEntryPoint** outEntryPoint) = 0;
+
+ /// Get a serialized representation of the checked module.
+ virtual SLANG_NO_THROW SlangResult SLANG_MCALL serialize(ISlangBlob** outSerializedBlob) = 0;
+
+ /// Write the serialized representation of this module to a file.
+ virtual SLANG_NO_THROW SlangResult SLANG_MCALL writeToFile(char const* fileName) = 0;
+
+ /// Get the name of the module.
+ virtual SLANG_NO_THROW const char* SLANG_MCALL getName() = 0;
+
+ /// Get the path of the module.
+ virtual SLANG_NO_THROW const char* SLANG_MCALL getFilePath() = 0;
+
+ /// Get the unique identity of the module.
+ virtual SLANG_NO_THROW const char* SLANG_MCALL getUniqueIdentity() = 0;
+
};
#define SLANG_UUID_IModule IModule::getTypeGuid()
diff --git a/source/slang/slang-ast-base.h b/source/slang/slang-ast-base.h
index e11dbe259..b16eb5ddf 100644
--- a/source/slang/slang-ast-base.h
+++ b/source/slang/slang-ast-base.h
@@ -104,10 +104,10 @@ struct Scope : public NodeBase
// but the opposite it allowed.
ContainerDecl* containerDecl = nullptr;
- SLANG_UNREFLECTED
// The parent of this scope (where lookup should go if nothing is found locally)
Scope* parent = nullptr;
+ SLANG_UNREFLECTED
// The next sibling of this scope (a peer for lookup)
Scope* nextSibling = nullptr;
};
diff --git a/source/slang/slang-ast-decl.h b/source/slang/slang-ast-decl.h
index 8d598c474..5d8023eaf 100644
--- a/source/slang/slang-ast-decl.h
+++ b/source/slang/slang-ast-decl.h
@@ -633,4 +633,9 @@ InterfaceDecl* findParentInterfaceDecl(Decl* decl);
bool isLocalVar(const Decl* decl);
+
+// Add a sibling lookup scope for `dest` to refer to `source`.
+void addSiblingScopeForContainerDecl(ASTBuilder* builder, ContainerDecl* dest, ContainerDecl* source);
+void addSiblingScopeForContainerDecl(ASTBuilder* builder, Scope* destScope, ContainerDecl* source);
+
} // namespace Slang
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp
index 2a6ee8abc..f1409efe1 100644
--- a/source/slang/slang-check-decl.cpp
+++ b/source/slang/slang-check-decl.cpp
@@ -820,6 +820,11 @@ namespace Slang
return as<NamespaceDeclBase>(parentDecl) != nullptr || as<FileDecl>(parentDecl) != nullptr;
}
+ bool isUnsafeForceInlineFunc(FunctionDeclBase* funcDecl)
+ {
+ return funcDecl->hasModifier<UnsafeForceInlineEarlyAttribute>();
+ }
+
/// Is `decl` a global shader parameter declaration?
bool isGlobalShaderParameter(VarDeclBase* decl)
{
@@ -7302,7 +7307,7 @@ namespace Slang
{
// Create a new sub-scope to wire the module
// into our lookup chain.
- addSiblingScopeForContainerDecl(scope, fileDecl);
+ addSiblingScopeForContainerDecl(getASTBuilder(), scope, fileDecl);
}
void SemanticsVisitor::importModuleIntoScope(Scope* scope, ModuleDecl* moduleDecl)
@@ -7325,7 +7330,7 @@ namespace Slang
if (moduleScope->containerDecl != moduleDecl && moduleScope->containerDecl->parentDecl != moduleDecl)
continue;
- addSiblingScopeForContainerDecl(scope, moduleScope->containerDecl);
+ addSiblingScopeForContainerDecl(getASTBuilder(), scope, moduleScope->containerDecl);
}
// Also import any modules from nested `import` declarations
@@ -7547,7 +7552,7 @@ namespace Slang
if (addedScopes.add(s->containerDecl))
{
scopesAdded = true;
- addSiblingScopeForContainerDecl(scope, s->containerDecl);
+ addSiblingScopeForContainerDecl(getASTBuilder(), scope, s->containerDecl);
}
}
};
@@ -7608,7 +7613,7 @@ namespace Slang
{
ensureDecl(ns, DeclCheckState::ScopesWired);
}
- addSiblingScopeForContainerDecl(decl, otherNamespace);
+ addSiblingScopeForContainerDecl(getASTBuilder(), decl, otherNamespace);
}
}
// For file decls, we need to continue searching up in the parent module scope.
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp
index c04cb73e4..e9de74d8e 100644
--- a/source/slang/slang-check-expr.cpp
+++ b/source/slang/slang-check-expr.cpp
@@ -248,14 +248,14 @@ namespace Slang
return SourceLoc();
}
- void SemanticsVisitor::addSiblingScopeForContainerDecl(ContainerDecl* dest, ContainerDecl* source)
+ void addSiblingScopeForContainerDecl(ASTBuilder* builder, ContainerDecl* dest, ContainerDecl* source)
{
- addSiblingScopeForContainerDecl(dest->ownedScope, source);
+ addSiblingScopeForContainerDecl(builder, dest->ownedScope, source);
}
- void SemanticsVisitor::addSiblingScopeForContainerDecl(Scope* destScope, ContainerDecl* source)
+ void addSiblingScopeForContainerDecl(ASTBuilder* builder, Scope* destScope, ContainerDecl* source)
{
- auto subScope = getASTBuilder()->create<Scope>();
+ auto subScope = builder->create<Scope>();
subScope->containerDecl = source;
subScope->nextSibling = destScope->nextSibling;
diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h
index 1808274f3..28ed47c53 100644
--- a/source/slang/slang-check-impl.h
+++ b/source/slang/slang-check-impl.h
@@ -17,6 +17,8 @@ namespace Slang
bool isGlobalDecl(Decl* decl);
+ bool isUnsafeForceInlineFunc(FunctionDeclBase* funcDecl);
+
bool isUniformParameterType(Type* type);
Type* checkProperType(
@@ -1048,11 +1050,6 @@ namespace Slang
Scope* getScope(SyntaxNode* node);
- // Add a sibling lookup scope for `dest` to refer to `source`.
- void addSiblingScopeForContainerDecl(ContainerDecl* dest, ContainerDecl* source);
- void addSiblingScopeForContainerDecl(Scope* destScope, ContainerDecl* source);
-
-
void diagnoseDeprecatedDeclRefUsage(DeclRef<Decl> declRef, SourceLoc loc, Expr* originalExpr);
DeclRef<Decl> getDefaultDeclRef(Decl* decl)
diff --git a/source/slang/slang-check-shader.cpp b/source/slang/slang-check-shader.cpp
index 2e854554e..7a39f114b 100644
--- a/source/slang/slang-check-shader.cpp
+++ b/source/slang/slang-check-shader.cpp
@@ -905,58 +905,10 @@ namespace Slang
// should work for typical HLSL code.
//
Index translationUnitCount = translationUnits.getCount();
- for(Index tt = 0; tt < translationUnitCount; ++tt)
+ for (Index tt = 0; tt < translationUnitCount; ++tt)
{
auto translationUnit = translationUnits[tt];
- for( auto globalDecl : translationUnit->getModuleDecl()->members )
- {
- auto maybeFuncDecl = globalDecl;
- if( auto genericDecl = as<GenericDecl>(maybeFuncDecl) )
- {
- maybeFuncDecl = genericDecl->inner;
- }
-
- auto funcDecl = as<FuncDecl>(maybeFuncDecl);
- if(!funcDecl)
- continue;
-
- auto entryPointAttr = funcDecl->findModifier<EntryPointAttribute>();
- if(!entryPointAttr)
- continue;
-
- // We've discovered a valid entry point. It is a function (possibly
- // generic) that has a `[shader(...)]` attribute to mark it as an
- // entry point.
- //
- // We will now register that entry point as an `EntryPoint`
- // with an appropriately chosen profile.
- //
- // The profile will only include a stage, so that the profile "family"
- // and "version" are left unspecified. Downstream code will need
- // to be able to handle this case.
- //
- Profile profile;
- profile.setStage(entryPointAttr->stage);
-
- RefPtr<EntryPoint> entryPoint = EntryPoint::create(
- linkage,
- makeDeclRef(funcDecl),
- profile);
-
- validateEntryPoint(entryPoint, sink);
-
- // Note: in the case that the user didn't explicitly
- // specify entry points and we are instead compiling
- // a shader "library," then we do not want to automatically
- // combine the entry points into groups in the generated
- // `Program`, since that would be slightly too magical.
- //
- // Instead, each entry point will end up in a singleton
- // group, so that its entry-point parameters lay out
- // independent of the others.
- //
- translationUnit->module->_addEntryPoint(entryPoint);
- }
+ translationUnit->getModule()->_discoverEntryPoints(sink);
}
}
}
diff --git a/source/slang/slang-compiler.cpp b/source/slang/slang-compiler.cpp
index f8ad95108..c77f736bc 100644
--- a/source/slang/slang-compiler.cpp
+++ b/source/slang/slang-compiler.cpp
@@ -2417,4 +2417,116 @@ namespace Slang
}
return false;
}
+
+ SLANG_NO_THROW SlangResult SLANG_MCALL Module::serialize(ISlangBlob** outSerializedBlob)
+ {
+ SerialContainerUtil::WriteOptions writeOptions;
+ writeOptions.sourceManager = getLinkage()->getSourceManager();
+ OwnedMemoryStream memoryStream(FileAccess::Write);
+ SLANG_RETURN_ON_FAIL(SerialContainerUtil::write(this, writeOptions, &memoryStream));
+ *outSerializedBlob = RawBlob::create(
+ memoryStream.getContents().getBuffer(),
+ (size_t)memoryStream.getContents().getCount()).detach();
+ return SLANG_OK;
+ }
+
+ SLANG_NO_THROW SlangResult SLANG_MCALL Module::writeToFile(char const* fileName)
+ {
+ SerialContainerUtil::WriteOptions writeOptions;
+ writeOptions.sourceManager = getLinkage()->getSourceManager();
+ FileStream fileStream;
+ SLANG_RETURN_ON_FAIL(fileStream.init(fileName, FileMode::Create));
+ return SerialContainerUtil::write(this, writeOptions, &fileStream);
+ }
+
+ SLANG_NO_THROW const char* SLANG_MCALL Module::getName()
+ {
+ if (m_name)
+ return m_name->text.getBuffer();
+ return nullptr;
+ }
+
+ SLANG_NO_THROW const char* SLANG_MCALL Module::getFilePath()
+ {
+ if (m_pathInfo.hasFoundPath())
+ return m_pathInfo.foundPath.getBuffer();
+ return nullptr;
+ }
+
+ SLANG_NO_THROW const char* SLANG_MCALL Module::getUniqueIdentity()
+ {
+ if (m_pathInfo.hasUniqueIdentity())
+ return m_pathInfo.getMostUniqueIdentity().getBuffer();
+ return nullptr;
+ }
+
+ void validateEntryPoint(
+ EntryPoint* entryPoint,
+ DiagnosticSink* sink);
+
+ void Module::_discoverEntryPoints(DiagnosticSink* sink)
+ {
+ for (auto globalDecl : m_moduleDecl->members)
+ {
+ auto maybeFuncDecl = globalDecl;
+ if (auto genericDecl = as<GenericDecl>(maybeFuncDecl))
+ {
+ maybeFuncDecl = genericDecl->inner;
+ }
+
+ auto funcDecl = as<FuncDecl>(maybeFuncDecl);
+ if (!funcDecl)
+ continue;
+
+ Profile profile;
+
+ auto entryPointAttr = funcDecl->findModifier<EntryPointAttribute>();
+ if (entryPointAttr)
+ {
+ // We've discovered a valid entry point. It is a function (possibly
+ // generic) that has a `[shader(...)]` attribute to mark it as an
+ // entry point.
+ //
+ // We will now register that entry point as an `EntryPoint`
+ // with an appropriately chosen profile.
+ //
+ // The profile will only include a stage, so that the profile "family"
+ // and "version" are left unspecified. Downstream code will need
+ // to be able to handle this case.
+ //
+ profile.setStage(entryPointAttr->stage);
+ }
+ else
+ {
+ // If there isn't a [shader] attribute, look for a [numthreads] attribute
+ // since that implicitly means a compute shader.
+ auto numThreadsAttr = funcDecl->findModifier<NumThreadsAttribute>();
+ if (numThreadsAttr)
+ profile.setStage(Stage::Compute);
+ else
+ continue;
+ }
+
+ RefPtr<EntryPoint> entryPoint = EntryPoint::create(
+ getLinkage(),
+ makeDeclRef(funcDecl),
+ profile);
+
+ validateEntryPoint(entryPoint, sink);
+
+ // Note: in the case that the user didn't explicitly
+ // specify entry points and we are instead compiling
+ // a shader "library," then we do not want to automatically
+ // combine the entry points into groups in the generated
+ // `Program`, since that would be slightly too magical.
+ //
+ // Instead, each entry point will end up in a singleton
+ // group, so that its entry-point parameters lay out
+ // independent of the others.
+ //
+ _addEntryPoint(entryPoint);
+ }
+ }
+
}
+
diff --git a/source/slang/slang-compiler.h b/source/slang/slang-compiler.h
index 18a91f78a..d93eff0a6 100755
--- a/source/slang/slang-compiler.h
+++ b/source/slang/slang-compiler.h
@@ -1325,6 +1325,22 @@ namespace Slang
return Super::getEntryPointHash(entryPointIndex, targetIndex, outHash);
}
+ /// Get a serialized representation of the checked module.
+ virtual SLANG_NO_THROW SlangResult SLANG_MCALL serialize(ISlangBlob** outSerializedBlob) override;
+
+ /// Write the serialized representation of this module to a file.
+ virtual SLANG_NO_THROW SlangResult SLANG_MCALL writeToFile(char const* fileName) override;
+
+ /// Get the name of the module.
+ virtual SLANG_NO_THROW const char* SLANG_MCALL getName() override;
+
+ /// Get the path of the module.
+ virtual SLANG_NO_THROW const char* SLANG_MCALL getFilePath() override;
+
+ /// Get the unique identity of the module.
+ virtual SLANG_NO_THROW const char* SLANG_MCALL getUniqueIdentity() override;
+
+
virtual void buildHash(DigestBuilder<SHA1>& builder) SLANG_OVERRIDE;
/// Create a module (initially empty).
@@ -1354,6 +1370,10 @@ namespace Slang
///
void setModuleDecl(ModuleDecl* moduleDecl);// { m_moduleDecl = moduleDecl; }
+ void setName(String name);
+ void setName(Name* name) { m_name = name; }
+ void setPathInfo(PathInfo pathInfo) { m_pathInfo = pathInfo; }
+
/// Set the IR for this module.
///
/// This should only be called once, during creation of the module.
@@ -1395,6 +1415,8 @@ namespace Slang
///
void _collectShaderParams();
+ void _discoverEntryPoints(DiagnosticSink* sink);
+
class ModuleSpecializationInfo : public SpecializationInfo
{
public:
@@ -1426,6 +1448,9 @@ namespace Slang
DiagnosticSink* sink) SLANG_OVERRIDE;
private:
+ Name* m_name = nullptr;
+ PathInfo m_pathInfo;
+
// The AST for the module
ModuleDecl* m_moduleDecl = nullptr;
@@ -1539,6 +1564,13 @@ namespace Slang
Dictionary<String, String> getCombinedPreprocessorDefinitions();
+ void setModuleName(Name* name)
+ {
+ moduleName = name;
+ if (module)
+ module->setName(name);
+ }
+
protected:
void _addSourceFile(SourceFile* sourceFile);
/* Given an artifact, find a PathInfo.
@@ -1730,6 +1762,11 @@ namespace Slang
/// lookup additional loaded modules.
typedef Dictionary<Name*, Module*> LoadedModuleDictionary;
+ enum ModuleBlobType
+ {
+ Source, IR
+ };
+
/// A context for loading and re-using code modules.
class Linkage : public RefObject, public slang::ISession
{
@@ -1742,6 +1779,17 @@ namespace Slang
SLANG_NO_THROW slang::IModule* SLANG_MCALL loadModule(
const char* moduleName,
slang::IBlob** outDiagnostics = nullptr) override;
+ slang::IModule* loadModuleFromBlob(
+ const char* moduleName,
+ const char* path,
+ slang::IBlob* source,
+ ModuleBlobType blobType,
+ slang::IBlob** outDiagnostics = nullptr);
+ SLANG_NO_THROW slang::IModule* SLANG_MCALL loadModuleFromIRBlob(
+ const char* moduleName,
+ const char* path,
+ slang::IBlob* source,
+ slang::IBlob** outDiagnostics = nullptr) override;
SLANG_NO_THROW slang::IModule* SLANG_MCALL loadModuleFromSource(
const char* moduleName,
const char* path,
@@ -1786,6 +1834,8 @@ namespace Slang
ISlangBlob** outDiagnostics) override;
SLANG_NO_THROW SlangResult SLANG_MCALL createCompileRequest(
SlangCompileRequest** outCompileRequest) override;
+ virtual SLANG_NO_THROW SlangInt SLANG_MCALL getLoadedModuleCount() override;
+ virtual SLANG_NO_THROW slang::IModule* SLANG_MCALL getLoadedModule(SlangInt index) override;
// Updates the supplied builder with linkage-related information, which includes preprocessor
// defines, the compiler version, and other compiler options. This is then merged with the hash
@@ -1935,6 +1985,15 @@ namespace Slang
ISlangBlob* fileContentsBlob,
SourceLoc const& loc,
DiagnosticSink* sink,
+ const LoadedModuleDictionary* additionalLoadedModules,
+ ModuleBlobType blobType);
+
+ RefPtr<Module> loadModuleFromIRBlobImpl(
+ Name* name,
+ const PathInfo& filePathInfo,
+ ISlangBlob* fileContentsBlob,
+ SourceLoc const& loc,
+ DiagnosticSink* sink,
const LoadedModuleDictionary* additionalLoadedModules);
void loadParsedModule(
@@ -1952,6 +2011,8 @@ namespace Slang
DiagnosticSink* sink,
const LoadedModuleDictionary* loadedModules = nullptr);
+ void prepareDeserializedModule(Module* module, DiagnosticSink* sink);
+
SourceFile* findFile(Name* name, SourceLoc loc, IncludeSystem& outIncludeSystem);
struct IncludeResult
{
diff --git a/source/slang/slang-module-library.cpp b/source/slang/slang-module-library.cpp
index b7290008b..0975d6e8f 100644
--- a/source/slang/slang-module-library.cpp
+++ b/source/slang/slang-module-library.cpp
@@ -66,7 +66,7 @@ SlangResult loadModuleLibrary(const Byte* inBytes, size_t bytesCount, EndToEndCo
options.linkage = req->getLinkage();
options.sink = req->getSink();
- SLANG_RETURN_ON_FAIL(SerialContainerUtil::read(&riffContainer, options, containerData));
+ SLANG_RETURN_ON_FAIL(SerialContainerUtil::read(&riffContainer, options, nullptr, containerData));
for (const auto& module : containerData.modules)
{
diff --git a/source/slang/slang-parser.cpp b/source/slang/slang-parser.cpp
index 45fa5a125..477b43726 100644
--- a/source/slang/slang-parser.cpp
+++ b/source/slang/slang-parser.cpp
@@ -1280,21 +1280,24 @@ namespace Slang
Parser* parser, void* /*userData*/)
{
auto decl = parser->astBuilder->create<ModuleDeclarationDecl>();
+ auto moduleDecl = parser->getCurrentModuleDecl();
if (parser->LookAheadToken(TokenType::Identifier))
{
auto nameToken = parser->ReadToken(TokenType::Identifier);
decl->nameAndLoc.name = parser->getNamePool()->getName(nameToken.getContent());
decl->nameAndLoc.loc = nameToken.loc;
+ if (moduleDecl) moduleDecl->nameAndLoc = decl->nameAndLoc;
}
else if (parser->LookAheadToken(TokenType::StringLiteral))
{
auto nameToken = parser->ReadToken(TokenType::StringLiteral);
decl->nameAndLoc.name = parser->getNamePool()->getName(getStringLiteralTokenValue(nameToken));
decl->nameAndLoc.loc = nameToken.loc;
+ if (moduleDecl) moduleDecl->nameAndLoc = decl->nameAndLoc;
}
else
{
- if (auto moduleDecl = parser->getCurrentModuleDecl())
+ if (moduleDecl)
decl->nameAndLoc.name = moduleDecl->getName();
decl->nameAndLoc.loc = parser->tokenReader.peekLoc();
}
diff --git a/source/slang/slang-serialize-ast-type-info.h b/source/slang/slang-serialize-ast-type-info.h
index f7d8cab08..96c8a438f 100644
--- a/source/slang/slang-serialize-ast-type-info.h
+++ b/source/slang/slang-serialize-ast-type-info.h
@@ -52,8 +52,8 @@ inline void serializeValPointerValue(SerialWriter* writer, Val* ptrValue, Serial
inline void deserializeValPointerValue(SerialReader* reader, const SerialIndex* inSerial, void* outPtr)
{
- auto val = reader->getPointer(*(const SerialIndex*)inSerial).dynamicCast<Val>();
- *(Val**)outPtr = val;
+ auto val = reader->getValPointer(*(const SerialIndex*)inSerial);
+ *(void**)outPtr = val.m_ptr;
}
template<typename T>
diff --git a/source/slang/slang-serialize-container.cpp b/source/slang/slang-serialize-container.cpp
index 175f970c9..6a75064ab 100644
--- a/source/slang/slang-serialize-container.cpp
+++ b/source/slang/slang-serialize-container.cpp
@@ -213,7 +213,11 @@ namespace Slang {
}
ModuleSerialFilter filter(moduleDecl);
- SerialWriter writer(serialClasses, &filter);
+ auto astWriterFlag = SerialWriter::Flag::ZeroInitialize;
+ if ((options.optionFlags & SerialOptionFlag::ASTFunctionBody) == 0)
+ astWriterFlag = (SerialWriter::Flag::Enum)(astWriterFlag | SerialWriter::Flag::SkipFunctionBody);
+
+ SerialWriter writer(serialClasses, &filter, astWriterFlag);
writer.getExtraObjects().set(sourceLocWriter);
@@ -300,7 +304,7 @@ static List<ExtensionDecl*>& _getCandidateExtensionList(
return entry->candidateExtensions;
}
-/* static */Result SerialContainerUtil::read(RiffContainer* container, const ReadOptions& options, SerialContainerData& out)
+/* static */Result SerialContainerUtil::read(RiffContainer* container, const ReadOptions& options, const LoadedModuleDictionary* additionalLoadedModules, SerialContainerData& out)
{
out.clear();
@@ -441,7 +445,7 @@ static List<ExtensionDecl*>& _getCandidateExtensionList(
NamePool* namePool = linkage->getNamePool();
Name* moduleNameName = namePool->getName(moduleName);
- readModule = linkage->findOrImportModule(moduleNameName, SourceLoc::fromRaw(0), options.sink);
+ readModule = linkage->findOrImportModule(moduleNameName, SourceLoc::fromRaw(0), options.sink, additionalLoadedModules);
if (!readModule)
{
return SLANG_FAIL;
@@ -570,7 +574,6 @@ static List<ExtensionDecl*>& _getCandidateExtensionList(
else if (Val* val = dynamicCast<Val>(nodeBase))
{
val->_setUnique();
- astBuilder->m_cachedNodes.tryGetValueOrAdd(ValKey(val), val);
}
}
}
diff --git a/source/slang/slang-serialize-container.h b/source/slang/slang-serialize-container.h
index 7cbc97aa2..a2b596a24 100644
--- a/source/slang/slang-serialize-container.h
+++ b/source/slang/slang-serialize-container.h
@@ -107,7 +107,7 @@ struct SerialContainerUtil
static SlangResult write(const SerialContainerData& data, const WriteOptions& options, RiffContainer* container);
/// Read the container into outData
- static SlangResult read(RiffContainer* container, const ReadOptions& options, SerialContainerData& outData);
+ static SlangResult read(RiffContainer* container, const ReadOptions& options, const LoadedModuleDictionary* additionalLoadedModules, SerialContainerData& outData);
/// Verify IR serialization
static SlangResult verifyIRSerialize(IRModule* module, Session* session, const WriteOptions& options);
diff --git a/source/slang/slang-serialize-factory.cpp b/source/slang/slang-serialize-factory.cpp
index 351742e60..5eae5e740 100644
--- a/source/slang/slang-serialize-factory.cpp
+++ b/source/slang/slang-serialize-factory.cpp
@@ -45,6 +45,11 @@ void* DefaultSerialObjectFactory::create(SerialTypeKind typeKind, SerialSubType
return nullptr;
}
+void* DefaultSerialObjectFactory::getOrCreateVal(ValNodeDesc&& desc)
+{
+ return m_astBuilder->_getOrCreateImpl(_Move(desc));
+}
+
// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! ModuleSerialFilter !!!!!!!!!!!!!!!!!!!!!!!!
SerialIndex ModuleSerialFilter::writePointer(SerialWriter* writer, const RefObject* inPtr)
@@ -65,12 +70,6 @@ SerialIndex ModuleSerialFilter::writePointer(SerialWriter* writer, const NodeBas
NodeBase* ptr = const_cast<NodeBase*>(inPtr);
SLANG_ASSERT(ptr);
- // We don't serialize Scope
- if (as<Scope>(ptr))
- {
- writer->setPointerIndex(inPtr, SerialIndex(0));
- return SerialIndex(0);
- }
if (Decl* decl = as<Decl>(ptr))
{
diff --git a/source/slang/slang-serialize-factory.h b/source/slang/slang-serialize-factory.h
index 7e51a840d..caa2f9785 100644
--- a/source/slang/slang-serialize-factory.h
+++ b/source/slang/slang-serialize-factory.h
@@ -15,6 +15,7 @@ class DefaultSerialObjectFactory : public SerialObjectFactory
public:
virtual void* create(SerialTypeKind typeKind, SerialSubType subType) SLANG_OVERRIDE;
+ virtual void* getOrCreateVal(ValNodeDesc&& desc) SLANG_OVERRIDE;
DefaultSerialObjectFactory(ASTBuilder* astBuilder) :
m_astBuilder(astBuilder)
diff --git a/source/slang/slang-serialize-types.h b/source/slang/slang-serialize-types.h
index fb0dd2f9d..7623eb4cf 100644
--- a/source/slang/slang-serialize-types.h
+++ b/source/slang/slang-serialize-types.h
@@ -30,6 +30,7 @@ struct SerialOptionFlag
SourceLocation = 0x02, ///< If set will output SourceLoc information, that can be reconstructed when read after being stored.
ASTModule = 0x04, ///< If set will output AST modules - typically required, but potentially not desired (for example with obsfucation)
IRModule = 0x08, ///< If set will output IR modules - typically required
+ ASTFunctionBody = 0x10, ///< If set will serialize AST function bodies.
};
};
typedef SerialOptionFlag::Type SerialOptionFlags;
diff --git a/source/slang/slang-serialize.cpp b/source/slang/slang-serialize.cpp
index 1f5b6942d..1c8abb8f9 100644
--- a/source/slang/slang-serialize.cpp
+++ b/source/slang/slang-serialize.cpp
@@ -3,6 +3,7 @@
#include "slang-ast-base.h"
#include "slang-ast-builder.h"
+#include "slang-check-impl.h"
namespace Slang {
@@ -222,6 +223,50 @@ SerialWriter::SerialWriter(SerialClasses* classes, SerialFilter* filter, Flags f
m_ptrMap.add(nullptr, 0);
}
+struct SkipFunctionBodyRAII
+{
+ FunctionDeclBase* funcDecl = nullptr;
+ Stmt* oldBody = nullptr;
+ SkipFunctionBodyRAII(SerialWriter::Flags flags, const SerialClass* serialCls, const void* ptr)
+ {
+ if ((flags & SerialWriter::Flag::SkipFunctionBody) == 0)
+ return;
+
+ if (serialCls->typeKind != SerialTypeKind::NodeBase)
+ return;
+ auto cls = serialCls;
+ while (cls)
+ {
+ auto astNodeType = (ASTNodeType)cls->subType;
+ if (astNodeType == ASTNodeType::FunctionDeclBase)
+ {
+ funcDecl = (FunctionDeclBase*)ptr;
+ break;
+ }
+ cls = cls->super;
+ }
+ if (funcDecl)
+ {
+ oldBody = funcDecl->body;
+ // We always need to include body of unsafeForceInlineEarly functions
+ // since they will need to be available at IR lowering time of the
+ // user module for pre-linking inling.
+ if (!isUnsafeForceInlineFunc(funcDecl))
+ {
+ funcDecl->body = nullptr;
+ }
+ }
+
+ }
+ ~SkipFunctionBodyRAII()
+ {
+ if (funcDecl)
+ {
+ funcDecl->body = oldBody;
+ }
+ }
+};
+
SerialIndex SerialWriter::writeObject(const SerialClass* serialCls, const void* ptr)
{
if (serialCls->flags & SerialClassFlag::DontSerialize)
@@ -229,6 +274,16 @@ SerialIndex SerialWriter::writeObject(const SerialClass* serialCls, const void*
return SerialIndex(0);
}
+ if (serialCls->typeKind == SerialTypeKind::NodeBase &&
+ ReflectClassInfo::isSubClassOf(serialCls->subType, Val::kReflectClassInfo))
+ {
+ return writeValObject((Val*)ptr);
+ }
+
+ // If we are skipping function bodies, set the body field to nullptr, and
+ // restore it after serialization.
+ SkipFunctionBodyRAII clearFunctionBodyRAII(m_flags, serialCls, ptr);
+
// This pointer cannot be in the map
SLANG_ASSERT(m_ptrMap.tryGetValue(ptr) == nullptr);
@@ -279,6 +334,62 @@ SerialIndex SerialWriter::writeObject(const NodeBase* node)
return writeObject(serialClass, (const void*)node);
}
+SerialIndex SerialWriter::writeValObject(const Val* node)
+{
+ typedef SerialInfo::ValEntry ValEntry;
+
+ size_t size = node->getOperandCount() * sizeof(SerialInfo::SerialValOperand);
+ ValEntry* nodeEntry = (ValEntry*)m_arena.allocateAligned(sizeof(ValEntry) + size, SerialInfo::MAX_ALIGNMENT);
+
+ nodeEntry->typeKind = SerialTypeKind::NodeBase;
+ nodeEntry->subType = (SerialSubType)node->astNodeType;
+ nodeEntry->operandCount = (uint32_t)node->getOperandCount();
+ nodeEntry->info = SerialInfo::makeEntryInfo(SerialInfo::MAX_ALIGNMENT);
+
+ // We add before adding fields, so if the fields point to this, the entry will be set
+ auto index = _add(node, nodeEntry);
+
+ ShortList<SerialIndex, 4> serializedOperands;
+
+ for (Index i = 0; i < node->getOperandCount(); i++)
+ {
+ auto operand = node->m_operands[i];
+ switch (operand.kind)
+ {
+ case ValNodeOperandKind::ConstantValue:
+ serializedOperands.add((SerialIndex)0);
+ break;
+ case ValNodeOperandKind::ValNode:
+ case ValNodeOperandKind::ASTNode:
+ serializedOperands.add(addPointer(operand.values.nodeOperand));
+ break;
+ }
+ }
+
+ SLANG_ASSERT(serializedOperands.getCount() == node->getOperandCount());
+
+ auto serialOperands = (SerialInfo::SerialValOperand*)(nodeEntry + 1);
+ for (Index i = 0; i < node->getOperandCount(); i++)
+ {
+ auto serialOperand = serialOperands + i;
+ auto operand = node->m_operands[i];
+ serialOperand->type = (int)operand.kind;
+ switch (operand.kind)
+ {
+ case ValNodeOperandKind::ConstantValue:
+ serialOperand->payload = operand.values.intOperand;
+ break;
+ case ValNodeOperandKind::ValNode:
+ serialOperand->payload = (uint64_t)serializedOperands[i];
+ break;
+ case ValNodeOperandKind::ASTNode:
+ serialOperand->payload = (uint64_t)serializedOperands[i];
+ break;
+ }
+ }
+ return index;
+}
+
SerialIndex SerialWriter::writeObject(const RefObject* obj)
{
const SerialRefObject* serialObj = as<const SerialRefObject>(obj);
@@ -633,6 +744,9 @@ size_t SerialInfo::Entry::calcSize(SerialClasses* serialClasses) const
auto serialClass = serialClasses->getSerialClass(typeKind, entry->subType);
+ if (ReflectClassInfo::isSubClassOf(entry->subType, Val::kReflectClassInfo))
+ return sizeof(ValEntry) + static_cast<const ValEntry*>(this)->operandCount * sizeof(SerialValOperand);
+
// Align by the alignment of the entry
size_t alignment = getAlignment(entry->info);
size_t size = sizeof(ObjectEntry) + serialClass->size;
@@ -722,6 +836,49 @@ SerialPointer SerialReader::getPointer(SerialIndex index)
return ptr;
}
+SerialPointer SerialReader::getValPointer(SerialIndex index)
+{
+ if (index == SerialIndex(0))
+ {
+ return SerialPointer();
+ }
+
+ SLANG_ASSERT(SerialIndexRaw(index) < SerialIndexRaw(m_entries.getCount()));
+
+ SerialPointer& ptr = m_objects[Index(index)];
+
+ if (ptr.m_ptr)
+ return ptr;
+
+ const SerialInfo::ValEntry* entry = (SerialInfo::ValEntry*)m_entries[Index(index)];
+ ValNodeDesc desc;
+ desc.type = (ASTNodeType)entry->subType;
+ auto readPtr = (SerialInfo::SerialValOperand*)(entry + 1);
+ for (uint32_t i = 0; i < entry->operandCount; i++)
+ {
+ auto serialOperand = readPtr[i];
+ ValNodeOperand operand;
+ operand.kind = (ValNodeOperandKind)(serialOperand.type);
+ switch (operand.kind)
+ {
+ case ValNodeOperandKind::ConstantValue:
+ operand.values.intOperand = serialOperand.payload;
+ break;
+ case ValNodeOperandKind::ASTNode:
+ operand.values.nodeOperand = (NodeBase*)getPointer((SerialIndex)serialOperand.payload).m_ptr;
+ break;
+ case ValNodeOperandKind::ValNode:
+ operand.values.nodeOperand = (Val*)getValPointer((SerialIndex)serialOperand.payload).m_ptr;
+ break;
+ }
+ desc.operands.add(operand);
+ }
+ desc.init();
+ ptr.m_kind = SerialTypeKind::NodeBase;
+ ptr.m_ptr = this->m_objectFactory->getOrCreateVal(_Move(desc));
+ return ptr;
+}
+
String SerialReader::getString(SerialIndex index)
{
if (index == SerialIndex(0))
@@ -902,6 +1059,12 @@ SlangResult SerialReader::constructObjects(NamePool* namePool)
case SerialTypeKind::NodeBase:
{
auto objectEntry = static_cast<const SerialInfo::ObjectEntry*>(entry);
+
+ // Don't create object for Vals.
+ if (objectEntry->typeKind == SerialTypeKind::NodeBase &&
+ ReflectClassInfo::isSubClassOf(objectEntry->subType, Val::kReflectClassInfo))
+ break;
+
void* obj = m_objectFactory->create(objectEntry->typeKind, objectEntry->subType);
if (!obj)
{
@@ -912,7 +1075,7 @@ SlangResult SerialReader::constructObjects(NamePool* namePool)
}
case SerialTypeKind::Array:
{
- // Don't need to construct an object, as will be accessed an interpreted by the object that holds it
+ // Don't need to construct an object, as will be accessed and interpreted by the object that holds it
break;
}
}
@@ -944,6 +1107,8 @@ SlangResult SerialReader::deserializeObjects()
{
return SLANG_FAIL;
}
+ if (ReflectClassInfo::isSubClassOf(objectEntry->subType, Val::kReflectClassInfo))
+ continue;
const uint8_t* src = (const uint8_t*)(objectEntry + 1);
uint8_t* dst = (uint8_t*)dstPtr.m_ptr;
diff --git a/source/slang/slang-serialize.h b/source/slang/slang-serialize.h
index 3071dc174..a91ff21e9 100644
--- a/source/slang/slang-serialize.h
+++ b/source/slang/slang-serialize.h
@@ -27,6 +27,8 @@ docs/design/serialization.md
// Predeclare
typedef uint32_t SerialSourceLoc;
class NodeBase;
+class Val;
+struct ValNodeDesc;
// Pre-declare
class SerialClasses;
@@ -119,11 +121,23 @@ struct SerialInfo
uint32_t _pad0; ///< Necessary, because a node *can* have MAX_ALIGNEMENT
};
+ struct ValEntry : Entry
+ {
+ SerialSubType subType;
+ uint32_t operandCount;
+ };
+
struct ArrayEntry : Entry
{
uint16_t elementSize;
uint32_t elementCount;
};
+
+ struct SerialValOperand
+ {
+ int type;
+ uint64_t payload;
+ };
};
typedef uint32_t SerialIndexRaw;
@@ -185,6 +199,7 @@ class SerialObjectFactory
{
public:
virtual void* create(SerialTypeKind typeKind, SerialSubType subType) = 0;
+ virtual void* getOrCreateVal(ValNodeDesc&& desc) = 0;
};
class SerialExtraObjects
@@ -229,6 +244,8 @@ public:
const void* getArray(SerialIndex index, Index& outCount);
SerialPointer getPointer(SerialIndex index);
+ SerialPointer getValPointer(SerialIndex index);
+
String getString(SerialIndex index);
Name* getName(SerialIndex index);
UnownedStringSlice getStringSlice(SerialIndex index);
@@ -329,7 +346,10 @@ public:
/// If set will zero initialize backing memory. This is slower but
/// is desirable to make two serializations of the same thing produce the
/// identical serialized result.
- ZeroInitialize = 0x1
+ ZeroInitialize = 0x1,
+
+ /// If set will not serialize function body.
+ SkipFunctionBody = 0x2,
};
};
@@ -342,6 +362,7 @@ public:
/// Write the object at the pointer
SerialIndex writeObject(const NodeBase* ptr);
SerialIndex writeObject(const RefObject* ptr);
+ SerialIndex writeValObject(const Val* ptr);
/// Add an array - may need to convert to serialized format
template <typename T>
diff --git a/source/slang/slang-syntax.h b/source/slang/slang-syntax.h
index 6b134e523..807db9a85 100644
--- a/source/slang/slang-syntax.h
+++ b/source/slang/slang-syntax.h
@@ -295,12 +295,6 @@ namespace Slang
// being in templates, because gcc/clang get angry.
//
template<typename T>
- void FilteredModifierList<T>::Iterator::operator++()
- {
- current = adjust(current->next);
- }
- //
- template<typename T>
Modifier* FilteredModifierList<T>::adjust(Modifier* modifier)
{
Modifier* m = modifier;
@@ -315,6 +309,11 @@ namespace Slang
}
}
+ template<typename T>
+ void FilteredModifierList<T>::Iterator::operator++()
+ {
+ current = FilteredModifierList<T>::adjust(current->next);
+ }
//
enum class UserDefinedAttributeTargets
diff --git a/source/slang/slang.cpp b/source/slang/slang.cpp
index 267d91173..395285b41 100644
--- a/source/slang/slang.cpp
+++ b/source/slang/slang.cpp
@@ -26,7 +26,7 @@
#include "slang-parser.h"
#include "slang-preprocessor.h"
#include "slang-type-layout.h"
-
+#
#include "slang-options.h"
#include "slang-repro.h"
@@ -480,11 +480,12 @@ SlangResult Session::_readBuiltinModule(ISlangFileSystem* fileSystem, Scope* sco
// Hmm - don't have a suitable sink yet, so attempt to just not have one
options.sink = nullptr;
- SLANG_RETURN_ON_FAIL(SerialContainerUtil::read(&riffContainer, options, containerData));
+ SLANG_RETURN_ON_FAIL(SerialContainerUtil::read(&riffContainer, options, nullptr, containerData));
for (auto& srcModule : containerData.modules)
{
RefPtr<Module> module(new Module(linkage, srcModule.astBuilder));
+ module->setName(moduleName);
ModuleDecl* moduleDecl = as<ModuleDecl>(srcModule.astRootNode);
// Set the module back reference on the decl
@@ -1086,10 +1087,11 @@ SLANG_NO_THROW slang::IModule* SLANG_MCALL Linkage::loadModule(
}
}
-SLANG_NO_THROW slang::IModule* SLANG_MCALL Linkage::loadModuleFromSource(
+slang::IModule* Linkage::loadModuleFromBlob(
const char* moduleName,
const char* path,
slang::IBlob* source,
+ ModuleBlobType blobType,
slang::IBlob** outDiagnostics)
{
SLANG_AST_BUILDER_RAII(getASTBuilder());
@@ -1124,7 +1126,8 @@ SLANG_NO_THROW slang::IModule* SLANG_MCALL Linkage::loadModuleFromSource(
source,
SourceLoc(),
&sink,
- nullptr);
+ nullptr,
+ blobType);
sink.getBlobIfNeeded(outDiagnostics);
return asExternal(module);
@@ -1136,6 +1139,24 @@ SLANG_NO_THROW slang::IModule* SLANG_MCALL Linkage::loadModuleFromSource(
}
}
+SLANG_NO_THROW slang::IModule* SLANG_MCALL Linkage::loadModuleFromSource(
+ const char* moduleName,
+ const char* path,
+ slang::IBlob* source,
+ slang::IBlob** outDiagnostics)
+{
+ return loadModuleFromBlob(moduleName, path, source, ModuleBlobType::Source, outDiagnostics);
+}
+
+SLANG_NO_THROW slang::IModule* SLANG_MCALL Linkage::loadModuleFromIRBlob(
+ const char* moduleName,
+ const char* path,
+ slang::IBlob* source,
+ slang::IBlob** outDiagnostics)
+{
+ return loadModuleFromBlob(moduleName, path, source, ModuleBlobType::IR, outDiagnostics);
+}
+
SLANG_NO_THROW SlangResult SLANG_MCALL Linkage::createCompositeComponentType(
slang::IComponentType* const* componentTypes,
SlangInt componentTypeCount,
@@ -1398,6 +1419,18 @@ SLANG_NO_THROW SlangResult SLANG_MCALL Linkage::createCompileRequest(
return SLANG_OK;
}
+SLANG_NO_THROW SlangInt SLANG_MCALL Linkage::getLoadedModuleCount()
+{
+ return loadedModulesList.getCount();
+}
+
+SLANG_NO_THROW slang::IModule* SLANG_MCALL Linkage::getLoadedModule(SlangInt index)
+{
+ if (index >= 0 && index < loadedModulesList.getCount())
+ return loadedModulesList[index].get();
+ return nullptr;
+}
+
void Linkage::buildHash(DigestBuilder<SHA1>& builder, SlangInt targetIndex)
{
// Add the Slang compiler version to the hash
@@ -2909,8 +2942,7 @@ int FrontEndCompileRequest::addTranslationUnit(SourceLanguage language, Name* mo
translationUnit->compileRequest = this;
translationUnit->sourceLanguage = SourceLanguage(language);
- translationUnit->moduleName = moduleName;
-
+ translationUnit->setModuleName(moduleName);
return addTranslationUnit(translationUnit);
}
@@ -3097,6 +3129,54 @@ void Linkage::loadParsedModule(
loadedModulesList.add(loadedModule);
}
+RefPtr<Module> Linkage::loadModuleFromIRBlobImpl(
+ Name* name,
+ const PathInfo& filePathInfo,
+ ISlangBlob* fileContentsBlob,
+ SourceLoc const& loc,
+ DiagnosticSink* sink,
+ const LoadedModuleDictionary* additionalLoadedModules)
+{
+ RefPtr<Module> resultModule = new Module(this, getASTBuilder());
+ resultModule->setName(name);
+ ModuleBeingImportedRAII moduleBeingImported(
+ this,
+ resultModule,
+ name,
+ loc);
+
+ String mostUniqueIdentity = filePathInfo.getMostUniqueIdentity();
+ SLANG_ASSERT(mostUniqueIdentity.getLength() > 0);
+
+ mapPathToLoadedModule.add(mostUniqueIdentity, resultModule);
+ mapNameToLoadedModules.add(name, resultModule);
+
+ RiffContainer container;
+ MemoryStreamBase readStream(FileAccess::Read, fileContentsBlob->getBufferPointer(), fileContentsBlob->getBufferSize());
+ SLANG_RETURN_NULL_ON_FAIL(RiffUtil::read(&readStream, container));
+ SerialContainerUtil::ReadOptions readOptions;
+ readOptions.linkage = this;
+ readOptions.astBuilder = getASTBuilder();
+ readOptions.session = getSessionImpl();
+ readOptions.sharedASTBuilder = getASTBuilder()->getSharedASTBuilder();
+ readOptions.sink = sink;
+ readOptions.sourceManager = getSourceManager();
+ readOptions.namePool = getNamePool();
+ SerialContainerData containerData;
+ SLANG_RETURN_NULL_ON_FAIL(SerialContainerUtil::read(&container, readOptions, additionalLoadedModules, containerData));
+ if (containerData.modules.getCount() != 1)
+ return nullptr;
+ auto moduleEntry = containerData.modules.getFirst();
+ resultModule->setIRModule(moduleEntry.irModule);
+ resultModule->setModuleDecl(as<ModuleDecl>(moduleEntry.astRootNode));
+
+ prepareDeserializedModule(resultModule, sink);
+
+ loadedModulesList.add(resultModule);
+ resultModule->setPathInfo(filePathInfo);
+ return resultModule;
+}
+
Module* Linkage::loadModule(String const& name)
{
// TODO: We either need to have a diagnostics sink
@@ -3129,15 +3209,19 @@ RefPtr<Module> Linkage::loadModule(
ISlangBlob* sourceBlob,
SourceLoc const& srcLoc,
DiagnosticSink* sink,
- const LoadedModuleDictionary* additionalLoadedModules)
+ const LoadedModuleDictionary* additionalLoadedModules,
+ ModuleBlobType blobType)
{
+ if (blobType == ModuleBlobType::IR)
+ return loadModuleFromIRBlobImpl(name, filePathInfo, sourceBlob, srcLoc, sink, additionalLoadedModules);
+
RefPtr<FrontEndCompileRequest> frontEndReq = new FrontEndCompileRequest(this, nullptr, sink);
frontEndReq->additionalLoadedModules = additionalLoadedModules;
RefPtr<TranslationUnitRequest> translationUnit = new TranslationUnitRequest(frontEndReq);
translationUnit->compileRequest = frontEndReq;
- translationUnit->moduleName = name;
+ translationUnit->setModuleName(name);
Stage impliedStage;
translationUnit->sourceLanguage = SourceLanguage::Slang;
@@ -3216,6 +3300,7 @@ RefPtr<Module> Linkage::loadModule(
return nullptr;
}
+ module->setPathInfo(filePathInfo);
return module;
}
@@ -3299,7 +3384,6 @@ RefPtr<Module> Linkage::findOrImportModule(
return previouslyLoadedModule;
}
- auto fileName = getFileNameFromModuleName(name);
// Next, try to find the file of the given name,
// using our ordinary include-handling logic.
@@ -3310,46 +3394,61 @@ RefPtr<Module> Linkage::findOrImportModule(
PathInfo pathIncludedFromInfo = getSourceManager()->getPathInfo(loc, SourceLocType::Actual);
PathInfo filePathInfo;
- ComPtr<ISlangBlob> fileContents;
+ auto moduleSourceFileName = getFileNameFromModuleName(name);
- // We have to load via the found path - as that is how file was originally loaded
- if (SLANG_FAILED(includeSystem.findFile(fileName, pathIncludedFromInfo.foundPath, filePathInfo)))
+ // Look for a precompiled module first, if not exist, load from source.
+ for (int checkBinaryModule = 1; checkBinaryModule >= 0; checkBinaryModule--)
{
- if (name && name->text == "glsl")
- {
- // This is a builtin glsl module, just load it from embedded definition.
- fileContents = getSessionImpl()->getGLSLLibraryCode();
- filePathInfo = PathInfo::makeFromString("glsl");
- }
+ String fileName;
+ if (checkBinaryModule == 1)
+ fileName = Path::replaceExt(moduleSourceFileName, "slang-module");
else
+ fileName = moduleSourceFileName;
+
+ ComPtr<ISlangBlob> fileContents;
+
+ // We have to load via the found path - as that is how file was originally loaded
+ if (SLANG_FAILED(includeSystem.findFile(fileName, pathIncludedFromInfo.foundPath, filePathInfo)))
{
- sink->diagnose(loc, Diagnostics::cannotFindFile, fileName);
- mapNameToLoadedModules[name] = nullptr;
- return nullptr;
+ if (name && name->text == "glsl")
+ {
+ // This is a builtin glsl module, just load it from embedded definition.
+ fileContents = getSessionImpl()->getGLSLLibraryCode();
+ filePathInfo = PathInfo::makeFromString("glsl");
+ checkBinaryModule = 0;
+ }
+ else
+ {
+ continue;
+ }
}
- }
- // Maybe this was loaded previously at a different relative name?
- if (mapPathToLoadedModule.tryGetValue(filePathInfo.getMostUniqueIdentity(), loadedModule))
- return loadedModule;
+ // Maybe this was loaded previously at a different relative name?
+ if (mapPathToLoadedModule.tryGetValue(filePathInfo.getMostUniqueIdentity(), loadedModule))
+ return loadedModule;
- // Try to load it
- if( !fileContents && SLANG_FAILED(includeSystem.loadFile(filePathInfo, fileContents)))
- {
- sink->diagnose(loc, Diagnostics::cannotOpenFile, fileName);
- mapNameToLoadedModules[name] = nullptr;
- return nullptr;
+ // Try to load it
+ if (!fileContents && SLANG_FAILED(includeSystem.loadFile(filePathInfo, fileContents)))
+ {
+ continue;
+ }
+
+ // We've found a file that we can load for the given module, so
+ // go ahead and perform the module-load action
+ return loadModule(
+ name,
+ filePathInfo,
+ fileContents,
+ loc,
+ sink,
+ loadedModules,
+ (checkBinaryModule == 1 ? ModuleBlobType::IR : ModuleBlobType::Source));
}
- // We've found a file that we can load for the given module, so
- // go ahead and perform the module-load action
- return loadModule(
- name,
- filePathInfo,
- fileContents,
- loc,
- sink,
- loadedModules);
+ // Error: we cannot find the file.
+ sink->diagnose(loc, Diagnostics::cannotOpenFile, moduleSourceFileName);
+ mapNameToLoadedModules[name] = nullptr;
+ return nullptr;
}
SourceFile* Linkage::findFile(Name* name, SourceLoc loc, IncludeSystem& outIncludeSystem)
@@ -3567,8 +3666,15 @@ void Module::addFileDependency(SourceFile* sourceFile)
void Module::setModuleDecl(ModuleDecl* moduleDecl)
{
m_moduleDecl = moduleDecl;
+ moduleDecl->module = this;
+}
+
+void Module::setName(String name)
+{
+ m_name = getLinkage()->getNamePool()->getName(name);
}
+
RefPtr<EntryPoint> Module::findEntryPointByName(UnownedStringSlice const& name)
{
// TODO: We should consider having this function be expanded to be able
@@ -4828,6 +4934,22 @@ void Linkage::setFileSystem(ISlangFileSystem* inFileSystem)
getSourceManager()->setFileSystemExt(m_fileSystemExt);
}
+void Linkage::prepareDeserializedModule(Module* module, DiagnosticSink* sink)
+{
+ module->_collectShaderParams();
+ module->_discoverEntryPoints(sink);
+
+ // Hook up fileDecl's scope to module's scope.
+ auto moduleDecl = module->getModuleDecl();
+ for (auto globalDecl : moduleDecl->members)
+ {
+ if (auto fileDecl = as<FileDecl>(globalDecl))
+ {
+ addSiblingScopeForContainerDecl(m_astBuilder, moduleDecl->ownedScope, fileDecl);
+ }
+ }
+}
+
void Linkage::setRequireCacheFileSystem(bool requireCacheFileSystem)
{
if (requireCacheFileSystem == m_requireCacheFileSystem)
diff --git a/tools/gfx-unit-test/gfx-test-util.cpp b/tools/gfx-unit-test/gfx-test-util.cpp
index 298283a4a..748ced5eb 100644
--- a/tools/gfx-unit-test/gfx-test-util.cpp
+++ b/tools/gfx-unit-test/gfx-test-util.cpp
@@ -71,6 +71,54 @@ namespace gfx_test
return SLANG_OK;
}
+ Slang::Result loadComputeProgram(
+ gfx::IDevice* device,
+ slang::ISession* slangSession,
+ Slang::ComPtr<gfx::IShaderProgram>& outShaderProgram,
+ const char* shaderModuleName,
+ const char* entryPointName,
+ slang::ProgramLayout*& slangReflection)
+ {
+ Slang::ComPtr<slang::IBlob> diagnosticsBlob;
+ slang::IModule* module = slangSession->loadModule(shaderModuleName, diagnosticsBlob.writeRef());
+ diagnoseIfNeeded(diagnosticsBlob);
+ if (!module)
+ return SLANG_FAIL;
+
+ ComPtr<slang::IEntryPoint> computeEntryPoint;
+ SLANG_RETURN_ON_FAIL(
+ module->findEntryPointByName(entryPointName, computeEntryPoint.writeRef()));
+
+ Slang::List<slang::IComponentType*> componentTypes;
+ componentTypes.add(module);
+ componentTypes.add(computeEntryPoint);
+
+ Slang::ComPtr<slang::IComponentType> composedProgram;
+ SlangResult result = slangSession->createCompositeComponentType(
+ componentTypes.getBuffer(),
+ componentTypes.getCount(),
+ composedProgram.writeRef(),
+ diagnosticsBlob.writeRef());
+ diagnoseIfNeeded(diagnosticsBlob);
+ SLANG_RETURN_ON_FAIL(result);
+
+ ComPtr<slang::IComponentType> linkedProgram;
+ result = composedProgram->link(linkedProgram.writeRef(), diagnosticsBlob.writeRef());
+ diagnoseIfNeeded(diagnosticsBlob);
+ SLANG_RETURN_ON_FAIL(result);
+
+ composedProgram = linkedProgram;
+ slangReflection = composedProgram->getLayout();
+
+ gfx::IShaderProgram::Desc programDesc = {};
+ programDesc.slangGlobalScope = composedProgram.get();
+
+ auto shaderProgram = device->createProgram(programDesc);
+
+ outShaderProgram = shaderProgram;
+ return SLANG_OK;
+ }
+
Slang::Result loadComputeProgramFromSource(
gfx::IDevice* device,
Slang::ComPtr<gfx::IShaderProgram>& outShaderProgram,
@@ -222,10 +270,7 @@ namespace gfx_test
SLANG_IGNORE_TEST
}
deviceDesc.slang.slangGlobalSession = context->slangGlobalSession;
- Slang::List<const char*> searchPaths;
- searchPaths.add("");
- searchPaths.add("../../tools/gfx-unit-test");
- searchPaths.add("tools/gfx-unit-test");
+ Slang::List<const char*> searchPaths = getSlangSearchPaths();
searchPaths.addRange(additionalSearchPaths);
deviceDesc.slang.searchPaths = searchPaths.getBuffer();
deviceDesc.slang.searchPathCount = (gfx::GfxCount)searchPaths.getCount();
@@ -253,6 +298,15 @@ namespace gfx_test
return device;
}
+ Slang::List<const char*> getSlangSearchPaths()
+ {
+ Slang::List<const char*> searchPaths;
+ searchPaths.add("");
+ searchPaths.add("../../tools/gfx-unit-test");
+ searchPaths.add("tools/gfx-unit-test");
+ return searchPaths;
+ }
+
#if GFX_ENABLE_RENDERDOC_INTEGRATION
RENDERDOC_API_1_1_2* rdoc_api = NULL;
void initializeRenderDoc()
diff --git a/tools/gfx-unit-test/gfx-test-util.h b/tools/gfx-unit-test/gfx-test-util.h
index 501deeae0..643830413 100644
--- a/tools/gfx-unit-test/gfx-test-util.h
+++ b/tools/gfx-unit-test/gfx-test-util.h
@@ -18,6 +18,14 @@ namespace gfx_test
const char* entryPointName,
slang::ProgramLayout*& slangReflection);
+ Slang::Result loadComputeProgram(
+ gfx::IDevice* device,
+ slang::ISession* slangSession,
+ Slang::ComPtr<gfx::IShaderProgram>& outShaderProgram,
+ const char* shaderModuleName,
+ const char* entryPointName,
+ slang::ProgramLayout*& slangReflection);
+
Slang::Result loadComputeProgramFromSource(
gfx::IDevice* device,
Slang::ComPtr<gfx::IShaderProgram>& outShaderProgram,
@@ -79,6 +87,8 @@ namespace gfx_test
Slang::RenderApiFlag::Enum api,
Slang::List<const char*> additionalSearchPaths = {},
gfx::IDevice::ShaderCacheDesc shaderCache = {});
+
+ Slang::List<const char*> getSlangSearchPaths();
void initializeRenderDoc();
void renderDocBeginFrame();
diff --git a/tools/gfx-unit-test/precompiled-module-2.cpp b/tools/gfx-unit-test/precompiled-module-2.cpp
new file mode 100644
index 000000000..3da77e05c
--- /dev/null
+++ b/tools/gfx-unit-test/precompiled-module-2.cpp
@@ -0,0 +1,181 @@
+#include "tools/unit-test/slang-unit-test.h"
+
+#include "slang-gfx.h"
+#include "gfx-test-util.h"
+#include "tools/gfx-util/shader-cursor.h"
+#include "source/core/slang-basic.h"
+#include "source/core/slang-blob.h"
+#include "source/core/slang-memory-file-system.h"
+#include "source/core/slang-io.h"
+
+using namespace gfx;
+
+namespace gfx_test
+{
+ // Test that mixing precompiled and non-precompiled modules is working.
+
+ static Slang::Result precompileProgram(
+ gfx::IDevice* device,
+ ISlangMutableFileSystem* fileSys,
+ const char* shaderModuleName)
+ {
+ Slang::ComPtr<slang::ISession> slangSession;
+ SLANG_RETURN_ON_FAIL(device->getSlangSession(slangSession.writeRef()));
+ slang::SessionDesc sessionDesc = {};
+ auto searchPaths = getSlangSearchPaths();
+ sessionDesc.searchPathCount = searchPaths.getCount();
+ sessionDesc.searchPaths = searchPaths.getBuffer();
+ auto globalSession = slangSession->getGlobalSession();
+ globalSession->createSession(sessionDesc, slangSession.writeRef());
+
+ Slang::ComPtr<slang::IBlob> diagnosticsBlob;
+ slang::IModule* module = slangSession->loadModule(shaderModuleName, diagnosticsBlob.writeRef());
+ diagnoseIfNeeded(diagnosticsBlob);
+ if (!module)
+ return SLANG_FAIL;
+
+ // Write loaded modules to memory file system.
+ for (SlangInt i = 0; i < slangSession->getLoadedModuleCount(); i++)
+ {
+ auto module = slangSession->getLoadedModule(i);
+ auto path = module->getFilePath();
+ if (path)
+ {
+ auto name = module->getName();
+ ComPtr<ISlangBlob> outBlob;
+ module->serialize(outBlob.writeRef());
+ fileSys->saveFileBlob((Slang::String(name) + ".slang-module").getBuffer(), outBlob);
+ }
+ }
+ return SLANG_OK;
+ }
+
+ void precompiledModule2TestImpl(IDevice* device, UnitTestContext* context)
+ {
+ Slang::ComPtr<ITransientResourceHeap> transientHeap;
+ ITransientResourceHeap::Desc transientHeapDesc = {};
+ transientHeapDesc.constantBufferSize = 4096;
+ GFX_CHECK_CALL_ABORT(
+ device->createTransientResourceHeap(transientHeapDesc, transientHeap.writeRef()));
+
+ // First, load and compile the slang source.
+ ComPtr<ISlangMutableFileSystem> memoryFileSystem = ComPtr<ISlangMutableFileSystem>(new Slang::MemoryFileSystem());
+
+ ComPtr<IShaderProgram> shaderProgram;
+ slang::ProgramLayout* slangReflection;
+ GFX_CHECK_CALL_ABORT(precompileProgram(device, memoryFileSystem.get(), "precompiled-module-imported"));
+
+ // Next, load the precompiled slang program.
+ Slang::ComPtr<slang::ISession> slangSession;
+ device->getSlangSession(slangSession.writeRef());
+ slang::SessionDesc sessionDesc = {};
+ sessionDesc.targetCount = 1;
+ slang::TargetDesc targetDesc = {};
+ switch (device->getDeviceInfo().deviceType)
+ {
+ case gfx::DeviceType::DirectX12:
+ targetDesc.format = SLANG_DXIL;
+ targetDesc.profile = device->getSlangSession()->getGlobalSession()->findProfile("sm_6_1");
+ break;
+ case gfx::DeviceType::Vulkan:
+ targetDesc.format = SLANG_SPIRV;
+ targetDesc.profile = device->getSlangSession()->getGlobalSession()->findProfile("GLSL_460");
+ break;
+ }
+ sessionDesc.targets = &targetDesc;
+ sessionDesc.fileSystem = memoryFileSystem.get();
+ auto globalSession = slangSession->getGlobalSession();
+ globalSession->createSession(sessionDesc, slangSession.writeRef());
+
+ const char* moduleSrc = R"(
+ import "precompiled-module-imported";
+
+ // Main entry-point.
+
+ using namespace ns;
+
+ [shader("compute")]
+ [numthreads(4, 1, 1)]
+ void computeMain(
+ uint3 sv_dispatchThreadID : SV_DispatchThreadID,
+ uniform RWStructuredBuffer <float> buffer)
+ {
+ buffer[sv_dispatchThreadID.x] = helperFunc() + helperFunc1();
+ }
+ )";
+ memoryFileSystem->saveFile("precompiled-module.slang", moduleSrc, strlen(moduleSrc));
+ GFX_CHECK_CALL_ABORT(loadComputeProgram(device, slangSession, shaderProgram, "precompiled-module", "computeMain", slangReflection));
+
+ ComputePipelineStateDesc pipelineDesc = {};
+ pipelineDesc.program = shaderProgram.get();
+ ComPtr<gfx::IPipelineState> pipelineState;
+ GFX_CHECK_CALL_ABORT(
+ device->createComputePipelineState(pipelineDesc, pipelineState.writeRef()));
+
+ const int numberCount = 4;
+ float initialData[] = { 0.0f, 0.0f, 0.0f, 0.0f };
+ IBufferResource::Desc bufferDesc = {};
+ bufferDesc.sizeInBytes = numberCount * sizeof(float);
+ bufferDesc.format = gfx::Format::Unknown;
+ bufferDesc.elementSize = sizeof(float);
+ bufferDesc.allowedStates = ResourceStateSet(
+ ResourceState::ShaderResource,
+ ResourceState::UnorderedAccess,
+ ResourceState::CopyDestination,
+ ResourceState::CopySource);
+ bufferDesc.defaultState = ResourceState::UnorderedAccess;
+ bufferDesc.memoryType = MemoryType::DeviceLocal;
+
+ ComPtr<IBufferResource> numbersBuffer;
+ GFX_CHECK_CALL_ABORT(device->createBufferResource(
+ bufferDesc,
+ (void*)initialData,
+ numbersBuffer.writeRef()));
+
+ ComPtr<IResourceView> bufferView;
+ IResourceView::Desc viewDesc = {};
+ viewDesc.type = IResourceView::Type::UnorderedAccess;
+ viewDesc.format = Format::Unknown;
+ GFX_CHECK_CALL_ABORT(
+ device->createBufferView(numbersBuffer, nullptr, viewDesc, bufferView.writeRef()));
+
+ // We have done all the set up work, now it is time to start recording a command buffer for
+ // GPU execution.
+ {
+ ICommandQueue::Desc queueDesc = { ICommandQueue::QueueType::Graphics };
+ auto queue = device->createCommandQueue(queueDesc);
+
+ auto commandBuffer = transientHeap->createCommandBuffer();
+ auto encoder = commandBuffer->encodeComputeCommands();
+
+ auto rootObject = encoder->bindPipeline(pipelineState);
+
+ ShaderCursor entryPointCursor(
+ rootObject->getEntryPoint(0)); // get a cursor the the first entry-point.
+ // Bind buffer view to the entry point.
+ entryPointCursor.getPath("buffer").setResource(bufferView);
+
+ encoder->dispatchCompute(1, 1, 1);
+ encoder->endEncoding();
+ commandBuffer->close();
+ queue->executeCommandBuffer(commandBuffer);
+ queue->waitOnHost();
+ }
+
+ compareComputeResult(
+ device,
+ numbersBuffer,
+ Slang::makeArray<float>(3.0f, 3.0f, 3.0f, 3.0f));
+ }
+
+ SLANG_UNIT_TEST(precompiledModule2D3D12)
+ {
+ runTestImpl(precompiledModule2TestImpl, unitTestContext, Slang::RenderApiFlag::D3D12);
+ }
+
+ SLANG_UNIT_TEST(precompiledModule2Vulkan)
+ {
+ runTestImpl(precompiledModule2TestImpl, unitTestContext, Slang::RenderApiFlag::Vulkan);
+ }
+
+}
diff --git a/tools/gfx-unit-test/precompiled-module-imported.slang b/tools/gfx-unit-test/precompiled-module-imported.slang
new file mode 100644
index 000000000..5c59e99b0
--- /dev/null
+++ b/tools/gfx-unit-test/precompiled-module-imported.slang
@@ -0,0 +1,11 @@
+module "precompiled-module-imported";
+
+__include "precompiled-module-included.slang";
+
+namespace ns
+{
+ public int helperFunc()
+ {
+ return 1;
+ }
+} \ No newline at end of file
diff --git a/tools/gfx-unit-test/precompiled-module-included.slang b/tools/gfx-unit-test/precompiled-module-included.slang
new file mode 100644
index 000000000..88d9e57d7
--- /dev/null
+++ b/tools/gfx-unit-test/precompiled-module-included.slang
@@ -0,0 +1,9 @@
+implementing "precompiled-module-imported";
+
+namespace ns
+{
+ public int helperFunc1()
+ {
+ return 2;
+ }
+} \ No newline at end of file
diff --git a/tools/gfx-unit-test/precompiled-module.cpp b/tools/gfx-unit-test/precompiled-module.cpp
new file mode 100644
index 000000000..026575120
--- /dev/null
+++ b/tools/gfx-unit-test/precompiled-module.cpp
@@ -0,0 +1,160 @@
+#include "tools/unit-test/slang-unit-test.h"
+
+#include "slang-gfx.h"
+#include "gfx-test-util.h"
+#include "tools/gfx-util/shader-cursor.h"
+#include "source/core/slang-basic.h"
+#include "source/core/slang-blob.h"
+#include "source/core/slang-memory-file-system.h"
+
+using namespace gfx;
+
+namespace gfx_test
+{
+ static Slang::Result precompileProgram(
+ gfx::IDevice* device,
+ ISlangMutableFileSystem* fileSys,
+ const char* shaderModuleName)
+ {
+ Slang::ComPtr<slang::ISession> slangSession;
+ SLANG_RETURN_ON_FAIL(device->getSlangSession(slangSession.writeRef()));
+ slang::SessionDesc sessionDesc = {};
+ auto searchPaths = getSlangSearchPaths();
+ sessionDesc.searchPathCount = searchPaths.getCount();
+ sessionDesc.searchPaths = searchPaths.getBuffer();
+ auto globalSession = slangSession->getGlobalSession();
+ globalSession->createSession(sessionDesc, slangSession.writeRef());
+
+ Slang::ComPtr<slang::IBlob> diagnosticsBlob;
+ slang::IModule* module = slangSession->loadModule(shaderModuleName, diagnosticsBlob.writeRef());
+ diagnoseIfNeeded(diagnosticsBlob);
+ if (!module)
+ return SLANG_FAIL;
+
+ // Write loaded modules to memory file system.
+ for (SlangInt i = 0; i < slangSession->getLoadedModuleCount(); i++)
+ {
+ auto module = slangSession->getLoadedModule(i);
+ auto path = module->getFilePath();
+ if (path)
+ {
+ auto name = module->getName();
+ ComPtr<ISlangBlob> outBlob;
+ module->serialize(outBlob.writeRef());
+ fileSys->saveFileBlob((Slang::String(name) + ".slang-module").getBuffer(), outBlob);
+ }
+ }
+ return SLANG_OK;
+ }
+
+ void precompiledModuleTestImpl(IDevice* device, UnitTestContext* context)
+ {
+ Slang::ComPtr<ITransientResourceHeap> transientHeap;
+ ITransientResourceHeap::Desc transientHeapDesc = {};
+ transientHeapDesc.constantBufferSize = 4096;
+ GFX_CHECK_CALL_ABORT(
+ device->createTransientResourceHeap(transientHeapDesc, transientHeap.writeRef()));
+
+ // First, load and compile the slang source.
+ ComPtr<ISlangMutableFileSystem> memoryFileSystem = ComPtr<ISlangMutableFileSystem>(new Slang::MemoryFileSystem());
+
+ ComPtr<IShaderProgram> shaderProgram;
+ slang::ProgramLayout* slangReflection;
+ GFX_CHECK_CALL_ABORT(precompileProgram(device, memoryFileSystem.get(), "precompiled-module"));
+
+ // Next, load the precompiled slang program.
+ Slang::ComPtr<slang::ISession> slangSession;
+ device->getSlangSession(slangSession.writeRef());
+ slang::SessionDesc sessionDesc = {};
+ sessionDesc.targetCount = 1;
+ slang::TargetDesc targetDesc = {};
+ switch (device->getDeviceInfo().deviceType)
+ {
+ case gfx::DeviceType::DirectX12:
+ targetDesc.format = SLANG_DXIL;
+ targetDesc.profile = device->getSlangSession()->getGlobalSession()->findProfile("sm_6_1");
+ break;
+ case gfx::DeviceType::Vulkan:
+ targetDesc.format = SLANG_SPIRV;
+ targetDesc.profile = device->getSlangSession()->getGlobalSession()->findProfile("GLSL_460");
+ break;
+ }
+ sessionDesc.targets = &targetDesc;
+ sessionDesc.fileSystem = memoryFileSystem.get();
+ auto globalSession = slangSession->getGlobalSession();
+ globalSession->createSession(sessionDesc, slangSession.writeRef());
+ GFX_CHECK_CALL_ABORT(loadComputeProgram(device, slangSession, shaderProgram, "precompiled-module", "computeMain", slangReflection));
+
+ ComputePipelineStateDesc pipelineDesc = {};
+ pipelineDesc.program = shaderProgram.get();
+ ComPtr<gfx::IPipelineState> pipelineState;
+ GFX_CHECK_CALL_ABORT(
+ device->createComputePipelineState(pipelineDesc, pipelineState.writeRef()));
+
+ const int numberCount = 4;
+ float initialData[] = { 0.0f, 0.0f, 0.0f, 0.0f };
+ IBufferResource::Desc bufferDesc = {};
+ bufferDesc.sizeInBytes = numberCount * sizeof(float);
+ bufferDesc.format = gfx::Format::Unknown;
+ bufferDesc.elementSize = sizeof(float);
+ bufferDesc.allowedStates = ResourceStateSet(
+ ResourceState::ShaderResource,
+ ResourceState::UnorderedAccess,
+ ResourceState::CopyDestination,
+ ResourceState::CopySource);
+ bufferDesc.defaultState = ResourceState::UnorderedAccess;
+ bufferDesc.memoryType = MemoryType::DeviceLocal;
+
+ ComPtr<IBufferResource> numbersBuffer;
+ GFX_CHECK_CALL_ABORT(device->createBufferResource(
+ bufferDesc,
+ (void*)initialData,
+ numbersBuffer.writeRef()));
+
+ ComPtr<IResourceView> bufferView;
+ IResourceView::Desc viewDesc = {};
+ viewDesc.type = IResourceView::Type::UnorderedAccess;
+ viewDesc.format = Format::Unknown;
+ GFX_CHECK_CALL_ABORT(
+ device->createBufferView(numbersBuffer, nullptr, viewDesc, bufferView.writeRef()));
+
+ // We have done all the set up work, now it is time to start recording a command buffer for
+ // GPU execution.
+ {
+ ICommandQueue::Desc queueDesc = { ICommandQueue::QueueType::Graphics };
+ auto queue = device->createCommandQueue(queueDesc);
+
+ auto commandBuffer = transientHeap->createCommandBuffer();
+ auto encoder = commandBuffer->encodeComputeCommands();
+
+ auto rootObject = encoder->bindPipeline(pipelineState);
+
+ ShaderCursor entryPointCursor(
+ rootObject->getEntryPoint(0)); // get a cursor the the first entry-point.
+ // Bind buffer view to the entry point.
+ entryPointCursor.getPath("buffer").setResource(bufferView);
+
+ encoder->dispatchCompute(1, 1, 1);
+ encoder->endEncoding();
+ commandBuffer->close();
+ queue->executeCommandBuffer(commandBuffer);
+ queue->waitOnHost();
+ }
+
+ compareComputeResult(
+ device,
+ numbersBuffer,
+ Slang::makeArray<float>(3.0f, 3.0f, 3.0f, 3.0f));
+ }
+
+ SLANG_UNIT_TEST(precompiledModuleD3D12)
+ {
+ runTestImpl(precompiledModuleTestImpl, unitTestContext, Slang::RenderApiFlag::D3D12);
+ }
+
+ SLANG_UNIT_TEST(precompiledModuleVulkan)
+ {
+ runTestImpl(precompiledModuleTestImpl, unitTestContext, Slang::RenderApiFlag::Vulkan);
+ }
+
+}
diff --git a/tools/gfx-unit-test/precompiled-module.slang b/tools/gfx-unit-test/precompiled-module.slang
new file mode 100644
index 000000000..be7231432
--- /dev/null
+++ b/tools/gfx-unit-test/precompiled-module.slang
@@ -0,0 +1,14 @@
+import "precompiled-module-imported";
+
+using namespace ns;
+
+// Main entry-point.
+
+[shader("compute")]
+[numthreads(4, 1, 1)]
+void computeMain(
+ uint3 sv_dispatchThreadID : SV_DispatchThreadID,
+ uniform RWStructuredBuffer <float> buffer)
+{
+ buffer[sv_dispatchThreadID.x] = helperFunc() + helperFunc1();
+}