#include #include #include #include #include #include #include #include #include #include #include #include #include #include #include "absl/status/status.h" #include "absl/status/statusor.h" #include #include namespace fs = std::filesystem; using ::slang::CompilerOptionEntry; using ::slang::CompilerOptionName; using ::slang::createGlobalSession; using ::slang::DeclReflection; using ::slang::FunctionReflection; using ::slang::IBlob; using ::slang::ICompileRequest; using ::slang::IGlobalSession; using ::slang::IModule; using ::slang::ISession; using ::slang::SessionDesc; using ::slang::TargetDesc; template using ComPtr = ::Slang::ComPtr; // Print any diagnostics carried by a Slang blob with optional context information. void printDiagnostics(const char* context, IBlob* diagnostics) { if (!diagnostics) { return; } std::size_t size = diagnostics->getBufferSize(); if (size == 0) { return; } std::string_view text(static_cast(diagnostics->getBufferPointer()), size); if (!text.empty() && text.back() == '\0') { text.remove_suffix(1); } if (text.empty()) { return; } if (context && *context) { std::cerr << context << " diagnostics:" << std::endl; } std::cerr.write(text.data(), text.size()); if (text.back() != '\n') { std::cerr << std::endl; } } // Helper to convert Slang API results into absl::Status values. absl::Status checkSlangResult(const char* context, SlangResult res, IBlob* diagnostics = nullptr) { printDiagnostics(context, diagnostics); if (SLANG_FAILED(res)) { std::ostringstream message; message << (context && *context ? context : "Slang call") << " failed with SlangResult " << res << " (0x" << std::hex << res << std::dec << ')'; return absl::InternalError(message.str()); } return absl::OkStatus(); } absl::Status writeTextFile(const fs::path& path, std::string_view contents) { std::ofstream file(path, std::ios::binary); if (!file) { std::ostringstream msg; msg << "Failed to open " << path << " for writing."; return absl::InternalError(msg.str()); } file.write(contents.data(), static_cast(contents.size())); file.close(); if (!file) { std::ostringstream msg; msg << "Failed to write " << path; return absl::InternalError(msg.str()); } return absl::OkStatus(); } void addCompilerOption(std::vector& options, CompilerOptionName name) { CompilerOptionEntry entry = {}; entry.name = name; entry.value.intValue0 = 1; options.push_back(entry); } struct FunctionInfo { std::string name; }; struct IncludeGuardInfo { bool present = false; std::string macro; std::string ifndefLine; std::string defineLine; std::string endifLine; }; struct ModuleRequest { fs::path modulePath; std::string moduleName; std::string searchPath; fs::path outputPath; }; std::string trim(std::string_view text) { std::size_t start = 0; std::size_t end = text.size(); while (start < end && std::isspace(static_cast(text[start]))) { ++start; } while (end > start && std::isspace(static_cast(text[end - 1]))) { --end; } return std::string(text.substr(start, end - start)); } bool isTopLevelFunction(DeclReflection* functionDecl) { if (!functionDecl) { return false; } using Kind = DeclReflection::Kind; for (DeclReflection* parent = functionDecl->getParent(); parent; parent = parent->getParent()) { switch (parent->getKind()) { case Kind::Module: case Kind::Namespace: return true; case Kind::Generic: continue; default: return false; } } return false; } std::unordered_set findPublicFunctionNames(const fs::path& sourcePath) { std::unordered_set names; std::ifstream input(sourcePath, std::ios::binary); if (!input) { return names; } std::string source((std::istreambuf_iterator(input)), std::istreambuf_iterator()); const std::size_t length = source.size(); std::size_t index = 0; bool publicPending = false; std::string candidate; int templateDepth = 0; while (index < length) { char c = source[index]; if (c == '/' && index + 1 < length) { char next = source[index + 1]; if (next == '/') { index += 2; while (index < length && source[index] != '\n') { ++index; } continue; } if (next == '*') { index += 2; while (index + 1 < length && !(source[index] == '*' && source[index + 1] == '/')) { ++index; } if (index + 1 < length) { index += 2; } continue; } } if (c == '"' || c == '\'') { char quote = c; ++index; while (index < length) { char current = source[index]; if (current == '\\') { index += 2; continue; } if (current == quote) { ++index; break; } ++index; } continue; } if (std::isalpha(static_cast(c)) || c == '_') { std::size_t startToken = index; ++index; while (index < length) { char ch = source[index]; if (std::isalnum(static_cast(ch)) || ch == '_') { ++index; } else { break; } } std::string token = source.substr(startToken, index - startToken); if (token == "public") { publicPending = true; candidate.clear(); templateDepth = 0; } else if (publicPending && templateDepth == 0) { candidate = token; } continue; } if (publicPending) { if (c == '<') { ++templateDepth; ++index; continue; } if (c == '>') { if (templateDepth > 0) { --templateDepth; } ++index; continue; } if (c == '(') { if (!candidate.empty() && templateDepth == 0) { names.insert(candidate); } publicPending = false; candidate.clear(); templateDepth = 0; ++index; continue; } if (c == ';' || c == '{' || c == '}') { publicPending = false; candidate.clear(); templateDepth = 0; ++index; continue; } } ++index; } return names; } // Recursively gather function declarations defined in the supplied Slang module. void collectFunctionInfos( DeclReflection* decl, const std::unordered_set& publicFunctions, std::vector& functions, std::unordered_set& seenNames) { if (!decl) { return; } using Kind = DeclReflection::Kind; switch (decl->getKind()) { case Kind::Func: if (auto* functionReflection = decl->asFunction()) { if (const char* name = functionReflection->getName()) { bool isPublic = publicFunctions.find(name) != publicFunctions.end(); if (*name && isPublic && seenNames.insert(name).second && isTopLevelFunction(decl)) { std::cerr << "Discovered entry point: " << name << std::endl; functions.push_back({name}); } } } break; case Kind::Generic: if (auto* genericDecl = decl->asGeneric()) { collectFunctionInfos( genericDecl->getInnerDecl(), publicFunctions, functions, seenNames); } break; default: break; } for (auto* child : decl->getChildren()) { collectFunctionInfos(child, publicFunctions, functions, seenNames); } } IncludeGuardInfo detectIncludeGuard(const fs::path& sourcePath) { IncludeGuardInfo info; std::ifstream input(sourcePath); if (!input) { return info; } std::vector lines; std::string line; while (std::getline(input, line)) { lines.push_back(line); } std::size_t ifndefIndex = std::numeric_limits::max(); for (std::size_t i = 0; i < lines.size(); ++i) { std::string trimmed = trim(lines[i]); if (trimmed.rfind("#ifndef", 0) == 0) { std::istringstream stream(trimmed); std::string directive; std::string macro; stream >> directive >> macro; if (!macro.empty()) { info.macro = macro; info.ifndefLine = lines[i]; ifndefIndex = i; } break; } } if (info.macro.empty()) { return info; } for (std::size_t i = ifndefIndex + 1; i < lines.size(); ++i) { std::string trimmed = trim(lines[i]); if (trimmed.rfind("#define", 0) == 0) { std::istringstream stream(trimmed); std::string directive; std::string macro; stream >> directive >> macro; if (macro == info.macro) { info.defineLine = lines[i]; break; } } } if (info.defineLine.empty()) { info = IncludeGuardInfo{}; return info; } for (std::size_t i = lines.size(); i-- > 0;) { std::string trimmed = trim(lines[i]); if (trimmed.rfind("#endif", 0) == 0) { info.endifLine = lines[i]; break; } } if (info.endifLine.empty()) { info = IncludeGuardInfo{}; return info; } info.present = true; return info; } absl::StatusOr parseModuleRequest(int argc, char** argv) { const char* programName = (argc > 0 && argv) ? argv[0] : "modular_slang"; if (argc < 2 || !argv) { std::ostringstream usage; usage << "Usage: " << programName << " "; return absl::InvalidArgumentError(usage.str()); } ModuleRequest request; request.modulePath = fs::absolute(argv[1]); if (!fs::exists(request.modulePath)) { std::ostringstream msg; msg << "Module not found: " << request.modulePath; return absl::NotFoundError(msg.str()); } if (request.modulePath.extension() != ".slang") { std::ostringstream msg; msg << "Expected a .slang file: " << request.modulePath; return absl::InvalidArgumentError(msg.str()); } request.moduleName = request.modulePath.stem().string(); request.searchPath = request.modulePath.has_parent_path() ? request.modulePath.parent_path().string() : fs::current_path().string(); request.outputPath = request.modulePath; request.outputPath.replace_extension(".hlsl"); return request; } std::vector makeCommonOptions() { std::vector options; addCompilerOption(options, CompilerOptionName::DisableNonEssentialValidations); addCompilerOption(options, CompilerOptionName::NoHLSLBinding); addCompilerOption(options, CompilerOptionName::NoMangle); addCompilerOption(options, CompilerOptionName::NoHLSLPackConstantBufferElements); addCompilerOption(options, CompilerOptionName::PlainFunctionEntryPoints); return options; } void configureTargetDesc( IGlobalSession* globalSession, std::vector& targetOptions, TargetDesc& outDesc) { outDesc = {}; outDesc.format = SLANG_HLSL; outDesc.profile = globalSession->findProfile("lib_6_6"); outDesc.flags = SLANG_TARGET_FLAG_GENERATE_WHOLE_PROGRAM; outDesc.compilerOptionEntries = targetOptions.data(); outDesc.compilerOptionEntryCount = static_cast(targetOptions.size()); } void configureSessionDesc( const TargetDesc& targetDesc, const ModuleRequest& request, std::vector& sessionOptions, std::array& searchPathStorage, SessionDesc& outDesc) { searchPathStorage[0] = request.searchPath.c_str(); outDesc = {}; outDesc.targets = &targetDesc; outDesc.targetCount = 1; outDesc.searchPaths = searchPathStorage.data(); outDesc.searchPathCount = static_cast(searchPathStorage.size()); outDesc.compilerOptionEntries = sessionOptions.data(); outDesc.compilerOptionEntryCount = static_cast(sessionOptions.size()); } absl::StatusOr> loadSlangModule(ISession* session, const std::string& moduleName) { ComPtr module; ComPtr diagnostics; module = session->loadModule(moduleName.c_str(), diagnostics.writeRef()); const std::string context = "loadModule: " + moduleName; printDiagnostics(context.c_str(), diagnostics); if (!module) { std::ostringstream msg; msg << "Failed to load module '" << moduleName << "'."; return absl::InternalError(msg.str()); } return module; } struct EntryPointsResult { std::vector functions; std::unordered_set publicFunctions; }; absl::StatusOr collectEntryPoints( IModule* module, const std::string& moduleName, const fs::path& sourcePath) { std::vector functions; std::unordered_set seenNames; std::unordered_set publicFunctions = findPublicFunctionNames(sourcePath); if (publicFunctions.empty()) { std::ostringstream msg; msg << "No public functions found in " << sourcePath.string() << '.'; return absl::NotFoundError(msg.str()); } DeclReflection* moduleReflection = module ? module->getModuleReflection() : nullptr; if (!moduleReflection) { std::ostringstream msg; msg << "Failed to retrieve reflection data for module '" << moduleName << "'."; return absl::InternalError(msg.str()); } collectFunctionInfos(moduleReflection, publicFunctions, functions, seenNames); if (functions.empty()) { std::ostringstream msg; msg << "No public functions found in module '" << moduleName << "'."; return absl::NotFoundError(msg.str()); } return EntryPointsResult{functions, publicFunctions}; } absl::StatusOr> createCompileRequest( ISession* session, const ModuleRequest& request, const TargetDesc& targetDesc, const std::vector& functions) { ComPtr compileRequest; if (absl::Status status = checkSlangResult( "ISession::createCompileRequest", session->createCompileRequest(compileRequest.writeRef())); !status.ok()) { return status; } compileRequest->setCodeGenTarget(SLANG_HLSL); compileRequest->setTargetProfile(0, targetDesc.profile); compileRequest->setTargetFlags(0, SLANG_TARGET_FLAG_GENERATE_WHOLE_PROGRAM); compileRequest->setMatrixLayoutMode(SLANG_MATRIX_LAYOUT_ROW_MAJOR); compileRequest->setLineDirectiveMode(SLANG_LINE_DIRECTIVE_MODE_NONE); compileRequest->addSearchPath(request.searchPath.c_str()); const int translationUnitIndex = compileRequest->addTranslationUnit( SLANG_SOURCE_LANGUAGE_SLANG, request.moduleName.c_str()); compileRequest->addTranslationUnitSourceFile( translationUnitIndex, request.modulePath.string().c_str()); for (const FunctionInfo& func : functions) { const int entryPointIndex = compileRequest->addEntryPoint( translationUnitIndex, func.name.c_str(), SLANG_STAGE_DISPATCH); if (entryPointIndex < 0) { std::ostringstream msg; msg << "Failed to register entry point '" << func.name << "'."; return absl::InternalError(msg.str()); } } return compileRequest; } absl::StatusOr collectGeneratedHlsl(ICompileRequest* compileRequest, const std::string& moduleName) { SlangResult compileResult = compileRequest->compile(); ComPtr diagnostics; compileRequest->getDiagnosticOutputBlob(diagnostics.writeRef()); if (absl::Status status = checkSlangResult( "ICompileRequest::compile", compileResult, diagnostics.get()); !status.ok()) { return status; } ComPtr targetCodeBlob; if (absl::Status status = checkSlangResult( "ICompileRequest::getTargetCodeBlob", compileRequest->getTargetCodeBlob(0, targetCodeBlob.writeRef())); !status.ok()) { return status; } if (!targetCodeBlob || targetCodeBlob->getBufferSize() == 0) { std::ostringstream msg; msg << "No HLSL was generated for module '" << moduleName << "'."; return absl::InternalError(msg.str()); } return std::string( static_cast(targetCodeBlob->getBufferPointer()), static_cast(targetCodeBlob->getBufferSize())); } std::string removeNvapiInclude(std::string hlslSource) { const std::string guardToken = "#ifdef SLANG_HLSL_ENABLE_NVAPI"; const std::string endifToken = "#endif"; std::size_t searchPos = 0; while (true) { const std::size_t blockStart = hlslSource.find(guardToken, searchPos); if (blockStart == std::string::npos) { break; } std::size_t blockEnd = hlslSource.find(endifToken, blockStart); if (blockEnd == std::string::npos) { break; } blockEnd += endifToken.size(); while (blockEnd < hlslSource.size() && (hlslSource[blockEnd] == '\r' || hlslSource[blockEnd] == '\n')) { ++blockEnd; } hlslSource.erase(blockStart, blockEnd - blockStart); searchPos = blockStart; } return hlslSource; } std::string applyIncludeGuard(const std::string& hlslSource, const IncludeGuardInfo& includeGuard) { if (!includeGuard.present) { return hlslSource; } const std::string guardIfndefToken = "#ifndef " + includeGuard.macro; const std::string guardDefineToken = "#define " + includeGuard.macro; const bool alreadyGuarded = hlslSource.find(guardIfndefToken) != std::string::npos && hlslSource.find(guardDefineToken) != std::string::npos; if (alreadyGuarded) { return hlslSource; } std::string body = hlslSource; if (!body.empty() && body.back() != '\n') { body += '\n'; } std::ostringstream wrapped; wrapped << includeGuard.ifndefLine << '\n'; wrapped << includeGuard.defineLine << '\n'; wrapped << '\n'; wrapped << body; if (!body.empty() && body.back() != '\n') { wrapped << '\n'; } wrapped << includeGuard.endifLine; if (!includeGuard.endifLine.empty() && includeGuard.endifLine.back() != '\n') { wrapped << '\n'; } return wrapped.str(); } absl::Status run(int argc, char** argv) { absl::StatusOr requestOr = parseModuleRequest(argc, argv); if (!requestOr.ok()) { return requestOr.status(); } ModuleRequest request = std::move(requestOr).value(); ComPtr globalSession; if (absl::Status status = checkSlangResult( "createGlobalSession", createGlobalSession(globalSession.writeRef())); !status.ok()) { return status; } auto commonOptions = makeCommonOptions(); std::vector targetOptions = commonOptions; TargetDesc targetDesc; configureTargetDesc(globalSession.get(), targetOptions, targetDesc); std::vector sessionOptions = commonOptions; SessionDesc sessionDesc; std::array searchPaths{}; configureSessionDesc(targetDesc, request, sessionOptions, searchPaths, sessionDesc); ComPtr session; if (absl::Status status = checkSlangResult( "IGlobalSession::createSession", globalSession->createSession(sessionDesc, session.writeRef())); !status.ok()) { return status; } absl::StatusOr> libraryModuleOr = loadSlangModule(session.get(), request.moduleName); if (!libraryModuleOr.ok()) { return libraryModuleOr.status(); } ComPtr libraryModule = std::move(libraryModuleOr).value(); absl::StatusOr entryPointsOr = collectEntryPoints(libraryModule.get(), request.moduleName, request.modulePath); if (!entryPointsOr.ok()) { return entryPointsOr.status(); } EntryPointsResult entryPoints = std::move(entryPointsOr).value(); absl::StatusOr> compileRequestOr = createCompileRequest(session.get(), request, targetDesc, entryPoints.functions); if (!compileRequestOr.ok()) { return compileRequestOr.status(); } ComPtr compileRequest = std::move(compileRequestOr).value(); absl::StatusOr hlslSourceOr = collectGeneratedHlsl(compileRequest.get(), request.moduleName); if (!hlslSourceOr.ok()) { return hlslSourceOr.status(); } std::string hlslSource = std::move(hlslSourceOr).value(); std::string filteredHlsl = removeNvapiInclude(hlslSource); fs::path rawOutputPath = request.outputPath; rawOutputPath.replace_extension(".raw.hlsl"); if (absl::Status status = writeTextFile(rawOutputPath, hlslSource); !status.ok()) { return status; } IncludeGuardInfo includeGuard = detectIncludeGuard(request.modulePath); std::string finalHlsl = applyIncludeGuard(filteredHlsl, includeGuard); if (absl::Status status = writeTextFile(request.outputPath, finalHlsl); !status.ok()) { return status; } std::cerr << "Generated HLSL written to " << request.outputPath << std::endl; return absl::OkStatus(); } int main(int argc, char** argv) { absl::Status status = run(argc, argv); if (!status.ok()) { std::cerr << status.message() << std::endl; return 1; } return 0; }