summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authoryum <yum.food.vr@gmail.com>2025-10-11 17:24:53 -0700
committeryum <yum.food.vr@gmail.com>2025-10-11 17:24:53 -0700
commit56bae6342544974194468661e0827f425e2a79bb (patch)
tree4997e40f5432908487b23882879366ede5d8f265
parente0f317e187aa4ccf90a3a0e77b941b46f76101c5 (diff)
more cleanup
-rw-r--r--CMakeLists.txt10
-rw-r--r--main.cc859
2 files changed, 463 insertions, 406 deletions
diff --git a/CMakeLists.txt b/CMakeLists.txt
index f289121..71b0a88 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -5,11 +5,19 @@ set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_CXX_EXTENSIONS OFF)
+include(FetchContent)
+
set(SLANG_SDK_ROOT "${CMAKE_CURRENT_SOURCE_DIR}/slang/build/slang-2025.18.2-windows-x86_64")
if(NOT EXISTS "${SLANG_SDK_ROOT}")
message(FATAL_ERROR "Expected Slang SDK directory at ${SLANG_SDK_ROOT}")
endif()
+FetchContent_Declare(
+ abseil
+ URL https://github.com/abseil/abseil-cpp/archive/refs/tags/20240116.1.zip
+)
+FetchContent_MakeAvailable(abseil)
+
add_executable(modular_slang main.cc)
set_target_properties(modular_slang PROPERTIES
RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/bin"
@@ -21,6 +29,8 @@ target_include_directories(modular_slang PRIVATE
target_link_libraries(modular_slang PRIVATE
"${SLANG_SDK_ROOT}/lib/slang.lib"
+ absl::status
+ absl::statusor
)
if(MSVC)
diff --git a/main.cc b/main.cc
index 32bd7f4..ba6ac3b 100644
--- a/main.cc
+++ b/main.cc
@@ -1,3 +1,4 @@
+#include <array>
#include <filesystem>
#include <cctype>
#include <fstream>
@@ -10,6 +11,9 @@
#include <unordered_set>
#include <vector>
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+
#include <slang.h>
#include <slang-com-ptr.h>
@@ -32,471 +36,514 @@ template <typename T>
using ComPtr = ::Slang::ComPtr<T>;
// Print any diagnostics carried by a Slang blob with optional context information.
-void printDiagnostics(const char* context, IBlob* diagnostics)
-{
- if (!diagnostics)
- {
- return;
- }
-
- std::size_t size = diagnostics->getBufferSize();
- if (size == 0)
- {
- return;
- }
-
- std::string_view text(static_cast<const char*>(diagnostics->getBufferPointer()), size);
- if (!text.empty() && text.back() == '\0')
- {
- text.remove_suffix(1);
- }
+void printDiagnostics(const char* context, IBlob* diagnostics) {
+ if (!diagnostics) {
+ return;
+ }
+
+ std::size_t size = diagnostics->getBufferSize();
+ if (size == 0) {
+ return;
+ }
+
+ std::string_view text(static_cast<const char*>(diagnostics->getBufferPointer()), size);
+ if (!text.empty() && text.back() == '\0') {
+ text.remove_suffix(1);
+ }
+
+ if (text.empty()) {
+ return;
+ }
+
+ if (context && *context) {
+ std::cerr << context << " diagnostics:" << std::endl;
+ }
+
+ std::cerr.write(text.data(), text.size());
+ if (text.back() != '\n') {
+ std::cerr << std::endl;
+ }
+}
- if (text.empty())
- {
- return;
- }
+// Helper to convert Slang API results into absl::Status values.
+absl::Status checkSlangResult(const char* context, SlangResult res, IBlob* diagnostics = nullptr) {
+ printDiagnostics(context, diagnostics);
- if (context && *context)
- {
- std::cerr << context << " diagnostics:" << std::endl;
- }
+ if (SLANG_FAILED(res)) {
+ std::ostringstream message;
+ message << (context && *context ? context : "Slang call")
+ << " failed with SlangResult " << res
+ << " (0x" << std::hex << res << std::dec << ')';
+ return absl::InternalError(message.str());
+ }
- std::cerr.write(text.data(), text.size());
- if (text.back() != '\n')
- {
- std::cerr << std::endl;
- }
+ return absl::OkStatus();
}
-// Helper to check Slang API results and surface diagnostic details when available.
-void checkResult(const char* context, SlangResult res, IBlob* diagnostics = nullptr)
-{
- printDiagnostics(context, diagnostics);
-
- if (SLANG_FAILED(res))
- {
- std::cerr << (context && *context ? context : "Slang call")
- << " failed with SlangResult " << res
- << " (0x" << std::hex << res << std::dec << ')' << std::endl;
- exit(1);
- }
-}
+absl::Status writeTextFile(const fs::path& path, std::string_view contents) {
+ std::ofstream file(path, std::ios::binary);
+ if (!file) {
+ std::ostringstream msg;
+ msg << "Failed to open " << path << " for writing.";
+ return absl::InternalError(msg.str());
+ }
-bool writeTextFile(const fs::path& path, std::string_view contents)
-{
- std::ofstream file(path, std::ios::binary);
- if (!file)
- {
- std::cerr << "Warning: Failed to open " << path << " for writing." << std::endl;
- return false;
- }
+ file.write(contents.data(), static_cast<std::streamsize>(contents.size()));
+ file.close();
- file.write(contents.data(), static_cast<std::streamsize>(contents.size()));
- file.close();
+ if (!file) {
+ std::ostringstream msg;
+ msg << "Failed to write " << path;
+ return absl::InternalError(msg.str());
+ }
- if (!file)
- {
- std::cerr << "Warning: Failed to write " << path << std::endl;
- return false;
- }
-
- return true;
+ return absl::OkStatus();
}
-void addCompilerOption(std::vector<CompilerOptionEntry>& options, CompilerOptionName name)
-{
- CompilerOptionEntry entry = {};
- entry.name = name;
- entry.value.intValue0 = 1;
- options.push_back(entry);
+void addCompilerOption(std::vector<CompilerOptionEntry>& options, CompilerOptionName name) {
+ CompilerOptionEntry entry = {};
+ entry.name = name;
+ entry.value.intValue0 = 1;
+ options.push_back(entry);
}
-struct FunctionInfo
-{
- std::string name;
+struct FunctionInfo {
+ std::string name;
+};
+
+struct IncludeGuardInfo {
+ bool present = false;
+ std::string macro;
+ std::string ifndefLine;
+ std::string defineLine;
+ std::string endifLine;
};
-struct IncludeGuardInfo
-{
- bool present = false;
- std::string macro;
- std::string ifndefLine;
- std::string defineLine;
- std::string endifLine;
+struct ModuleRequest {
+ fs::path modulePath;
+ std::string moduleName;
+ std::string searchPath;
+ fs::path outputPath;
};
-std::string trim(std::string_view text)
-{
- std::size_t start = 0;
- std::size_t end = text.size();
+std::string trim(std::string_view text) {
+ std::size_t start = 0;
+ std::size_t end = text.size();
- while (start < end && std::isspace(static_cast<unsigned char>(text[start])))
- {
- ++start;
- }
+ while (start < end && std::isspace(static_cast<unsigned char>(text[start]))) {
+ ++start;
+ }
- while (end > start && std::isspace(static_cast<unsigned char>(text[end - 1])))
- {
- --end;
- }
+ while (end > start && std::isspace(static_cast<unsigned char>(text[end - 1]))) {
+ --end;
+ }
- return std::string(text.substr(start, end - start));
+ return std::string(text.substr(start, end - start));
}
-bool isTopLevelFunction(DeclReflection* functionDecl)
-{
- if (!functionDecl)
- {
- return false;
- }
-
- using Kind = DeclReflection::Kind;
- for (DeclReflection* parent = functionDecl->getParent(); parent;
- parent = parent->getParent())
- {
- switch (parent->getKind())
- {
- case Kind::Module:
- case Kind::Namespace:
- return true;
- case Kind::Generic:
- continue;
- default:
- return false;
- }
+bool isTopLevelFunction(DeclReflection* functionDecl) {
+ if (!functionDecl) {
+ return false;
+ }
+
+ using Kind = DeclReflection::Kind;
+ for (DeclReflection* parent = functionDecl->getParent(); parent;
+ parent = parent->getParent()) {
+ switch (parent->getKind()) {
+ case Kind::Module:
+ case Kind::Namespace:
+ return true;
+ case Kind::Generic:
+ continue;
+ default:
+ return false;
}
+ }
- return false;
+ return false;
}
// Recursively gather function declarations defined in the supplied Slang module.
void collectFunctionInfos(
DeclReflection* decl,
std::vector<FunctionInfo>& functions,
- std::unordered_set<std::string>& seenNames)
-{
- if (!decl)
- {
- return;
- }
-
- using Kind = DeclReflection::Kind;
-
- switch (decl->getKind())
- {
- case Kind::Func:
- if (auto* functionReflection = decl->asFunction())
- {
- if (const char* name = functionReflection->getName())
- {
- // Heuristic: functions that don't start with underscore are considered public
- // (Slang convention: private/internal functions typically start with _)
- bool isPublic = name[0] != '_';
-
- if (*name && seenNames.insert(name).second && isTopLevelFunction(decl) && isPublic)
- {
- std::cerr << "Discovered entry point: " << name << std::endl;
- functions.push_back({name});
- }
- }
- }
- break;
- case Kind::Generic:
- if (auto* genericDecl = decl->asGeneric())
- {
- collectFunctionInfos(
- genericDecl->getInnerDecl(),
- functions,
- seenNames);
+ std::unordered_set<std::string>& seenNames) {
+ if (!decl) {
+ return;
+ }
+
+ using Kind = DeclReflection::Kind;
+
+ switch (decl->getKind()) {
+ case Kind::Func:
+ if (auto* functionReflection = decl->asFunction()) {
+ if (const char* name = functionReflection->getName()) {
+ // Heuristic: functions that don't start with underscore are considered public
+ // (Slang convention: private/internal functions typically start with _)
+ bool isPublic = name[0] != '_';
+
+ if (*name && seenNames.insert(name).second && isTopLevelFunction(decl) && isPublic) {
+ std::cerr << "Discovered entry point: " << name << std::endl;
+ functions.push_back({name});
}
+ }
+ }
+ break;
+ case Kind::Generic:
+ if (auto* genericDecl = decl->asGeneric()) {
+ collectFunctionInfos(
+ genericDecl->getInnerDecl(),
+ functions,
+ seenNames);
+ }
+ break;
+ default:
+ break;
+ }
+
+ for (auto* child : decl->getChildren()) {
+ collectFunctionInfos(child, functions, seenNames);
+ }
+}
+
+IncludeGuardInfo detectIncludeGuard(const fs::path& sourcePath) {
+ IncludeGuardInfo info;
+
+ std::ifstream input(sourcePath);
+ if (!input) {
+ return info;
+ }
+
+ std::vector<std::string> lines;
+ std::string line;
+ while (std::getline(input, line)) {
+ lines.push_back(line);
+ }
+
+ std::size_t ifndefIndex = std::numeric_limits<std::size_t>::max();
+ for (std::size_t i = 0; i < lines.size(); ++i) {
+ std::string trimmed = trim(lines[i]);
+ if (trimmed.rfind("#ifndef", 0) == 0) {
+ std::istringstream stream(trimmed);
+ std::string directive;
+ std::string macro;
+ stream >> directive >> macro;
+ if (!macro.empty()) {
+ info.macro = macro;
+ info.ifndefLine = lines[i];
+ ifndefIndex = i;
+ }
+ break;
+ }
+ }
+
+ if (info.macro.empty()) {
+ return info;
+ }
+
+ for (std::size_t i = ifndefIndex + 1; i < lines.size(); ++i) {
+ std::string trimmed = trim(lines[i]);
+ if (trimmed.rfind("#define", 0) == 0) {
+ std::istringstream stream(trimmed);
+ std::string directive;
+ std::string macro;
+ stream >> directive >> macro;
+ if (macro == info.macro) {
+ info.defineLine = lines[i];
break;
- default:
- break;
+ }
}
+ }
- for (auto* child : decl->getChildren())
- {
- collectFunctionInfos(child, functions, seenNames);
+ if (info.defineLine.empty()) {
+ info = IncludeGuardInfo{};
+ return info;
+ }
+
+ for (std::size_t i = lines.size(); i-- > 0;) {
+ std::string trimmed = trim(lines[i]);
+ if (trimmed.rfind("#endif", 0) == 0) {
+ info.endifLine = lines[i];
+ break;
}
+ }
+
+ if (info.endifLine.empty()) {
+ info = IncludeGuardInfo{};
+ return info;
+ }
+
+ info.present = true;
+ return info;
}
-IncludeGuardInfo detectIncludeGuard(const fs::path& sourcePath)
-{
- IncludeGuardInfo info;
+absl::StatusOr<ModuleRequest> parseModuleRequest(int argc, char** argv) {
+ const char* programName = (argc > 0 && argv) ? argv[0] : "modular_slang";
+
+ if (argc < 2 || !argv) {
+ std::ostringstream usage;
+ usage << "Usage: " << programName << " <module.slang>";
+ return absl::InvalidArgumentError(usage.str());
+ }
+
+ ModuleRequest request;
+ request.modulePath = fs::absolute(argv[1]);
+
+ if (!fs::exists(request.modulePath)) {
+ std::ostringstream msg;
+ msg << "Module not found: " << request.modulePath;
+ return absl::NotFoundError(msg.str());
+ }
+
+ if (request.modulePath.extension() != ".slang") {
+ std::ostringstream msg;
+ msg << "Expected a .slang file: " << request.modulePath;
+ return absl::InvalidArgumentError(msg.str());
+ }
+
+ request.moduleName = request.modulePath.stem().string();
+ request.searchPath = request.modulePath.has_parent_path()
+ ? request.modulePath.parent_path().string()
+ : fs::current_path().string();
+ request.outputPath = request.modulePath;
+ request.outputPath.replace_extension(".hlsl");
+
+ return request;
+}
- std::ifstream input(sourcePath);
- if (!input)
- {
- return info;
- }
+std::vector<CompilerOptionEntry> makeCommonOptions() {
+ std::vector<CompilerOptionEntry> options;
+ addCompilerOption(options, CompilerOptionName::DisableNonEssentialValidations);
+ addCompilerOption(options, CompilerOptionName::NoHLSLBinding);
+ addCompilerOption(options, CompilerOptionName::NoMangle);
+ addCompilerOption(options, CompilerOptionName::NoHLSLPackConstantBufferElements);
+ addCompilerOption(options, CompilerOptionName::NoEntryPointUniformParamTransform);
+ return options;
+}
- std::vector<std::string> lines;
- std::string line;
- while (std::getline(input, line))
- {
- lines.push_back(line);
- }
+void configureTargetDesc(
+ IGlobalSession* globalSession,
+ std::vector<CompilerOptionEntry>& targetOptions,
+ TargetDesc& outDesc) {
+ outDesc = {};
+ outDesc.format = SLANG_HLSL;
+ outDesc.profile = globalSession->findProfile("lib_6_6");
+ outDesc.flags = SLANG_TARGET_FLAG_GENERATE_WHOLE_PROGRAM;
+ outDesc.compilerOptionEntries = targetOptions.data();
+ outDesc.compilerOptionEntryCount = static_cast<uint32_t>(targetOptions.size());
+}
- std::size_t ifndefIndex = std::numeric_limits<std::size_t>::max();
- for (std::size_t i = 0; i < lines.size(); ++i)
- {
- std::string trimmed = trim(lines[i]);
- if (trimmed.rfind("#ifndef", 0) == 0)
- {
- std::istringstream stream(trimmed);
- std::string directive;
- std::string macro;
- stream >> directive >> macro;
- if (!macro.empty())
- {
- info.macro = macro;
- info.ifndefLine = lines[i];
- ifndefIndex = i;
- }
- break;
- }
- }
+void configureSessionDesc(
+ const TargetDesc& targetDesc,
+ const ModuleRequest& request,
+ std::vector<CompilerOptionEntry>& sessionOptions,
+ std::array<const char*, 1>& searchPathStorage,
+ SessionDesc& outDesc) {
+ searchPathStorage[0] = request.searchPath.c_str();
+
+ outDesc = {};
+ outDesc.targets = &targetDesc;
+ outDesc.targetCount = 1;
+ outDesc.searchPaths = searchPathStorage.data();
+ outDesc.searchPathCount = static_cast<uint32_t>(searchPathStorage.size());
+ outDesc.compilerOptionEntries = sessionOptions.data();
+ outDesc.compilerOptionEntryCount = static_cast<uint32_t>(sessionOptions.size());
+}
- if (info.macro.empty())
- {
- return info;
- }
+absl::StatusOr<ComPtr<IModule>> loadSlangModule(ISession* session, const std::string& moduleName) {
+ ComPtr<IModule> module;
+ ComPtr<IBlob> diagnostics;
+ module = session->loadModule(moduleName.c_str(), diagnostics.writeRef());
- for (std::size_t i = ifndefIndex + 1; i < lines.size(); ++i)
- {
- std::string trimmed = trim(lines[i]);
- if (trimmed.rfind("#define", 0) == 0)
- {
- std::istringstream stream(trimmed);
- std::string directive;
- std::string macro;
- stream >> directive >> macro;
- if (macro == info.macro)
- {
- info.defineLine = lines[i];
- break;
- }
- }
- }
+ const std::string context = "loadModule: " + moduleName;
+ printDiagnostics(context.c_str(), diagnostics);
- if (info.defineLine.empty())
- {
- info = IncludeGuardInfo{};
- return info;
- }
+ if (!module) {
+ std::ostringstream msg;
+ msg << "Failed to load module '" << moduleName << "'.";
+ return absl::InternalError(msg.str());
+ }
- for (std::size_t i = lines.size(); i-- > 0;)
- {
- std::string trimmed = trim(lines[i]);
- if (trimmed.rfind("#endif", 0) == 0)
- {
- info.endifLine = lines[i];
- break;
- }
- }
+ return module;
+}
+
+absl::StatusOr<std::vector<FunctionInfo>> collectEntryPoints(
+ IModule* module,
+ const std::string& moduleName) {
+ std::vector<FunctionInfo> functions;
+ std::unordered_set<std::string> seenNames;
+
+ DeclReflection* moduleReflection = module ? module->getModuleReflection() : nullptr;
+ if (!moduleReflection) {
+ std::ostringstream msg;
+ msg << "Failed to retrieve reflection data for module '"
+ << moduleName << "'.";
+ return absl::InternalError(msg.str());
+ }
+
+ collectFunctionInfos(moduleReflection, functions, seenNames);
+
+ if (functions.empty()) {
+ std::ostringstream msg;
+ msg << "No functions found in module '" << moduleName << "'.";
+ return absl::NotFoundError(msg.str());
+ }
+
+ return functions;
+}
- if (info.endifLine.empty())
- {
- info = IncludeGuardInfo{};
- return info;
+absl::StatusOr<ComPtr<ICompileRequest>> createCompileRequest(
+ ISession* session,
+ const ModuleRequest& request,
+ const TargetDesc& targetDesc,
+ const std::vector<FunctionInfo>& functions) {
+ ComPtr<ICompileRequest> compileRequest;
+ ABSL_RETURN_IF_ERROR(checkSlangResult(
+ "ISession::createCompileRequest",
+ session->createCompileRequest(compileRequest.writeRef())));
+
+ compileRequest->setCodeGenTarget(SLANG_HLSL);
+ compileRequest->setTargetProfile(0, targetDesc.profile);
+ compileRequest->setTargetFlags(0, SLANG_TARGET_FLAG_GENERATE_WHOLE_PROGRAM);
+ compileRequest->setMatrixLayoutMode(SLANG_MATRIX_LAYOUT_ROW_MAJOR);
+ compileRequest->setLineDirectiveMode(SLANG_LINE_DIRECTIVE_MODE_NONE);
+
+ compileRequest->addSearchPath(request.searchPath.c_str());
+
+ const int translationUnitIndex = compileRequest->addTranslationUnit(
+ SLANG_SOURCE_LANGUAGE_SLANG,
+ request.moduleName.c_str());
+ compileRequest->addTranslationUnitSourceFile(
+ translationUnitIndex,
+ request.modulePath.string().c_str());
+
+ for (const FunctionInfo& func : functions) {
+ const int entryPointIndex = compileRequest->addEntryPoint(
+ translationUnitIndex,
+ func.name.c_str(),
+ SLANG_STAGE_DISPATCH);
+ if (entryPointIndex < 0) {
+ std::ostringstream msg;
+ msg << "Failed to register entry point '" << func.name << "'.";
+ return absl::InternalError(msg.str());
}
+ }
- info.present = true;
- return info;
+ return compileRequest;
}
-int main(int argc, char** argv)
-{
- if (argc < 2) {
- std::cerr << "Usage: " << (argc > 0 ? argv[0] : "modular_slang") << " <module.slang>" << std::endl;
- return 1;
- }
- fs::path modulePath = fs::absolute(argv[1]);
- if (!fs::exists(modulePath)) {
- std::cerr << "Module not found: " << modulePath << std::endl;
- return 1;
- }
- if (modulePath.extension() != ".slang") {
- std::cerr << "Expected a .slang file: " << modulePath << std::endl;
- return 1;
- }
+absl::StatusOr<std::string> collectGeneratedHlsl(ICompileRequest* compileRequest, const std::string& moduleName) {
+ SlangResult compileResult = compileRequest->compile();
+ ComPtr<IBlob> diagnostics;
+ compileRequest->getDiagnosticOutputBlob(diagnostics.writeRef());
+ ABSL_RETURN_IF_ERROR(checkSlangResult("ICompileRequest::compile", compileResult, diagnostics.get()));
+
+ ComPtr<IBlob> targetCodeBlob;
+ ABSL_RETURN_IF_ERROR(checkSlangResult(
+ "ICompileRequest::getTargetCodeBlob",
+ compileRequest->getTargetCodeBlob(0, targetCodeBlob.writeRef())));
+
+ if (!targetCodeBlob || targetCodeBlob->getBufferSize() == 0) {
+ std::ostringstream msg;
+ msg << "No HLSL was generated for module '" << moduleName << "'.";
+ return absl::InternalError(msg.str());
+ }
+
+ return std::string(
+ static_cast<const char*>(targetCodeBlob->getBufferPointer()),
+ static_cast<std::size_t>(targetCodeBlob->getBufferSize()));
+}
- std::string moduleName = modulePath.stem().string();
- std::string searchPath = modulePath.has_parent_path()
- ? modulePath.parent_path().string()
- : fs::current_path().string();
-
- fs::path outputPath = modulePath;
- outputPath.replace_extension(".hlsl");
-
- // Create session
- ComPtr<IGlobalSession> globalSession;
- checkResult("createGlobalSession", createGlobalSession(globalSession.writeRef()));
-
- // Configure session and target
- TargetDesc targetDesc = {};
- targetDesc.format = SLANG_HLSL;
- targetDesc.profile = globalSession->findProfile("lib_6_6");
- targetDesc.flags = SLANG_TARGET_FLAG_GENERATE_WHOLE_PROGRAM;
-
- std::vector<CompilerOptionEntry> options;
- addCompilerOption(options, CompilerOptionName::DisableNonEssentialValidations);
- addCompilerOption(options, CompilerOptionName::NoHLSLBinding);
- addCompilerOption(options, CompilerOptionName::NoMangle);
- addCompilerOption(options, CompilerOptionName::NoHLSLPackConstantBufferElements);
- addCompilerOption(options, CompilerOptionName::NoEntryPointUniformParamTransform);
-
- std::vector<CompilerOptionEntry> targetOptions(options);
- targetDesc.compilerOptionEntries = targetOptions.data();
- targetDesc.compilerOptionEntryCount = static_cast<uint32_t>(targetOptions.size());
-
- SessionDesc sessionDesc = {};
- sessionDesc.targets = &targetDesc;
- sessionDesc.targetCount = 1;
- const char* searchPaths[] = { searchPath.c_str() };
- sessionDesc.searchPaths = searchPaths;
- sessionDesc.searchPathCount = 1;
-
- std::vector<CompilerOptionEntry> sessionOptions(options);
- sessionDesc.compilerOptionEntries = sessionOptions.data();
- sessionDesc.compilerOptionEntryCount = static_cast<uint32_t>(sessionOptions.size());
-
- ComPtr<ISession> session;
- checkResult("IGlobalSession::createSession",
- globalSession->createSession(sessionDesc, session.writeRef()));
-
- // Load the "module" aka the library
- ComPtr<IModule> libraryModule;
- {
- ComPtr<IBlob> diagnosticsBlob;
- libraryModule = session->loadModule(moduleName.c_str(), diagnosticsBlob.writeRef());
- const std::string diagnosticsContext = "loadModule: " + moduleName;
- printDiagnostics(diagnosticsContext.c_str(), diagnosticsBlob);
- if (!libraryModule)
- {
- std::cerr << "Failed to load module '" << moduleName << "'." << std::endl;
- return 1;
- }
- }
+std::string applyIncludeGuard(const std::string& hlslSource, const IncludeGuardInfo& includeGuard) {
+ if (!includeGuard.present) {
+ return hlslSource;
+ }
+
+ const std::string guardIfndefToken = "#ifndef " + includeGuard.macro;
+ const std::string guardDefineToken = "#define " + includeGuard.macro;
+ const bool alreadyGuarded =
+ hlslSource.find(guardIfndefToken) != std::string::npos &&
+ hlslSource.find(guardDefineToken) != std::string::npos;
+
+ if (alreadyGuarded) {
+ return hlslSource;
+ }
+
+ std::string body = hlslSource;
+ if (!body.empty() && body.back() != '\n') {
+ body += '\n';
+ }
+
+ std::ostringstream wrapped;
+ wrapped << includeGuard.ifndefLine << '\n';
+ wrapped << includeGuard.defineLine << '\n';
+ wrapped << '\n';
+ wrapped << body;
+ if (!body.empty() && body.back() != '\n') {
+ wrapped << '\n';
+ }
+ wrapped << includeGuard.endifLine;
+ if (!includeGuard.endifLine.empty() && includeGuard.endifLine.back() != '\n') {
+ wrapped << '\n';
+ }
+
+ return wrapped.str();
+}
- // Discover top-level functions to treat as entry points
- std::vector<FunctionInfo> functions;
- std::unordered_set<std::string> seenNames;
- DeclReflection* moduleReflection = libraryModule->getModuleReflection();
- if (!moduleReflection) {
- std::cerr << "Failed to retrieve reflection data for module '" << moduleName << "'."
- << std::endl;
- return 1;
- }
- collectFunctionInfos(moduleReflection, functions, seenNames);
- if (functions.empty()) {
- std::cerr << "No functions found in module '" << moduleName << "'." << std::endl;
- return 1;
- }
+absl::Status run(int argc, char** argv) {
+ ABSL_ASSIGN_OR_RETURN(ModuleRequest request, parseModuleRequest(argc, argv));
- // Compile
- ComPtr<ICompileRequest> compileRequest;
- checkResult(
- "ISession::createCompileRequest",
- session->createCompileRequest(compileRequest.writeRef()));
+ ComPtr<IGlobalSession> globalSession;
+ ABSL_RETURN_IF_ERROR(checkSlangResult(
+ "createGlobalSession",
+ createGlobalSession(globalSession.writeRef())));
- compileRequest->setCodeGenTarget(SLANG_HLSL);
- compileRequest->setTargetProfile(0, targetDesc.profile);
- compileRequest->setTargetFlags(0, SLANG_TARGET_FLAG_GENERATE_WHOLE_PROGRAM);
- compileRequest->setMatrixLayoutMode(SLANG_MATRIX_LAYOUT_ROW_MAJOR);
- compileRequest->setLineDirectiveMode(SLANG_LINE_DIRECTIVE_MODE_NONE);
+ auto commonOptions = makeCommonOptions();
- compileRequest->addSearchPath(searchPath.c_str());
+ std::vector<CompilerOptionEntry> targetOptions = commonOptions;
+ TargetDesc targetDesc;
+ configureTargetDesc(globalSession.get(), targetOptions, targetDesc);
- const int translationUnitIndex = compileRequest->addTranslationUnit(
- SLANG_SOURCE_LANGUAGE_SLANG,
- moduleName.c_str());
- compileRequest->addTranslationUnitSourceFile(
- translationUnitIndex,
- modulePath.string().c_str());
-
- for (const FunctionInfo& func : functions)
- {
- const int entryPointIndex = compileRequest->addEntryPoint(
- translationUnitIndex,
- func.name.c_str(),
- SLANG_STAGE_DISPATCH);
- if (entryPointIndex < 0)
- {
- std::cerr << "Failed to register entry point '" << func.name << "'." << std::endl;
- return 1;
- }
- }
+ std::vector<CompilerOptionEntry> sessionOptions = commonOptions;
+ SessionDesc sessionDesc;
+ std::array<const char*, 1> searchPaths{};
+ configureSessionDesc(targetDesc, request, sessionOptions, searchPaths, sessionDesc);
- SlangResult compileResult = compileRequest->compile();
- ComPtr<IBlob> compileDiagnostics;
- compileRequest->getDiagnosticOutputBlob(compileDiagnostics.writeRef());
- checkResult("ICompileRequest::compile", compileResult, compileDiagnostics);
+ ComPtr<ISession> session;
+ ABSL_RETURN_IF_ERROR(checkSlangResult(
+ "IGlobalSession::createSession",
+ globalSession->createSession(sessionDesc, session.writeRef())));
- ComPtr<IBlob> targetCodeBlob;
- checkResult(
- "ICompileRequest::getTargetCodeBlob",
- compileRequest->getTargetCodeBlob(0, targetCodeBlob.writeRef()));
+ ABSL_ASSIGN_OR_RETURN(ComPtr<IModule> libraryModule,
+ loadSlangModule(session.get(), request.moduleName));
- if (!targetCodeBlob || targetCodeBlob->getBufferSize() == 0)
- {
- std::cerr << "No HLSL was generated for module '" << moduleName << "'." << std::endl;
- return 1;
- }
+ ABSL_ASSIGN_OR_RETURN(std::vector<FunctionInfo> functions,
+ collectEntryPoints(libraryModule.get(), request.moduleName));
- std::string hlslSource(
- static_cast<const char*>(targetCodeBlob->getBufferPointer()),
- static_cast<std::size_t>(targetCodeBlob->getBufferSize()));
-
- fs::path rawOutputPath = outputPath;
- rawOutputPath.replace_extension(".raw.hlsl");
- writeTextFile(rawOutputPath, hlslSource);
-
- std::string finalHlsl = hlslSource;
-
- IncludeGuardInfo includeGuard = detectIncludeGuard(modulePath);
- if (includeGuard.present)
- {
- const std::string guardIfndefToken = "#ifndef " + includeGuard.macro;
- const std::string guardDefineToken = "#define " + includeGuard.macro;
- const bool alreadyGuarded =
- finalHlsl.find(guardIfndefToken) != std::string::npos &&
- finalHlsl.find(guardDefineToken) != std::string::npos;
-
- if (!alreadyGuarded)
- {
- std::string body = finalHlsl;
- if (!body.empty() && body.back() != '\n')
- {
- body += '\n';
- }
-
- std::ostringstream wrapped;
- wrapped << includeGuard.ifndefLine << '\n';
- wrapped << includeGuard.defineLine << '\n';
- wrapped << '\n';
- wrapped << body;
- if (!body.empty() && body.back() != '\n')
- {
- wrapped << '\n';
- }
- wrapped << includeGuard.endifLine;
- if (!includeGuard.endifLine.empty() && includeGuard.endifLine.back() != '\n')
- {
- wrapped << '\n';
- }
-
- finalHlsl = wrapped.str();
- }
- }
+ ABSL_ASSIGN_OR_RETURN(ComPtr<ICompileRequest> compileRequest,
+ createCompileRequest(session.get(), request, targetDesc, functions));
- // 6. Write HLSL output to a sibling .hlsl file
- if (!writeTextFile(outputPath, finalHlsl))
- {
- return 1;
- }
+ ABSL_ASSIGN_OR_RETURN(std::string hlslSource,
+ collectGeneratedHlsl(compileRequest.get(), request.moduleName));
+
+ fs::path rawOutputPath = request.outputPath;
+ rawOutputPath.replace_extension(".raw.hlsl");
+ ABSL_RETURN_IF_ERROR(writeTextFile(rawOutputPath, hlslSource));
+
+ IncludeGuardInfo includeGuard = detectIncludeGuard(request.modulePath);
+ std::string finalHlsl = applyIncludeGuard(hlslSource, includeGuard);
+ ABSL_RETURN_IF_ERROR(writeTextFile(request.outputPath, finalHlsl));
+
+ std::cerr << "Generated HLSL written to " << request.outputPath << std::endl;
+ return absl::OkStatus();
+}
- std::cerr << "Generated HLSL written to " << outputPath << std::endl;
+int main(int argc, char** argv) {
+ absl::Status status = run(argc, argv);
+ if (!status.ok()) {
+ std::cerr << status.message() << std::endl;
+ return 1;
+ }
- return 0;
+ return 0;
}