summaryrefslogtreecommitdiffstats
path: root/main.cc
diff options
context:
space:
mode:
authoryum <yum.food.vr@gmail.com>2025-10-11 13:04:57 -0700
committeryum <yum.food.vr@gmail.com>2025-10-11 13:04:57 -0700
commitd1ecda540ddd6e9ab53f7981de65c3e435c1132c (patch)
tree6907da6e0d4df248ce19ea825eccdf261f28ae9c /main.cc
parent06780e36b2aeb1257607c89570fd99903508b82e (diff)
more stuff
Diffstat (limited to 'main.cc')
-rw-r--r--main.cc98
1 files changed, 39 insertions, 59 deletions
diff --git a/main.cc b/main.cc
index 7c67428..2329ce8 100644
--- a/main.cc
+++ b/main.cc
@@ -106,6 +106,14 @@ bool writeTextFile(const fs::path& path, std::string_view contents)
return true;
}
+void addCompilerOption(std::vector<CompilerOptionEntry>& options, CompilerOptionName name)
+{
+ CompilerOptionEntry entry = {};
+ entry.name = name;
+ entry.value.intValue0 = 1;
+ options.push_back(entry);
+}
+
struct ParameterInfo
{
std::string name;
@@ -312,7 +320,11 @@ void collectFunctionInfos(
{
if (const char* name = functionReflection->getName())
{
- if (*name && seenNames.insert(name).second && isTopLevelFunction(decl))
+ // Heuristic: functions that don't start with underscore are considered public
+ // (Slang convention: private/internal functions typically start with _)
+ bool isPublic = name[0] != '_';
+
+ if (*name && seenNames.insert(name).second && isTopLevelFunction(decl) && isPublic)
{
FunctionInfo info;
info.name = name;
@@ -352,6 +364,7 @@ void collectFunctionInfos(
info.parameters.push_back(std::move(paramInfo));
}
+ std::cerr << "Discovered entry point: " << info.name << std::endl;
functions.push_back(std::move(info));
}
}
@@ -681,24 +694,20 @@ IncludeGuardInfo detectIncludeGuard(const fs::path& sourcePath)
int main(int argc, char** argv)
{
- if (argc < 2)
- {
+ if (argc < 2) {
std::cerr << "Usage: " << (argc > 0 ? argv[0] : "modular_slang") << " <module.slang>" << std::endl;
return 1;
}
-
fs::path modulePath = fs::absolute(argv[1]);
- if (!fs::exists(modulePath))
- {
+ if (!fs::exists(modulePath)) {
std::cerr << "Module not found: " << modulePath << std::endl;
return 1;
}
-
- if (modulePath.extension() != ".slang")
- {
+ if (modulePath.extension() != ".slang") {
std::cerr << "Expected a .slang file: " << modulePath << std::endl;
return 1;
}
+
std::string moduleName = modulePath.stem().string();
std::string searchPath = modulePath.has_parent_path()
? modulePath.parent_path().string()
@@ -707,31 +716,23 @@ int main(int argc, char** argv)
fs::path outputPath = modulePath;
outputPath.replace_extension(".hlsl");
- IncludeGuardInfo includeGuard = detectIncludeGuard(modulePath);
-
- // 1. Session Creation
+ // Create session
ComPtr<IGlobalSession> globalSession;
checkResult("createGlobalSession", createGlobalSession(globalSession.writeRef()));
- // 2. Target Configuration
+ // Configure session and target
TargetDesc targetDesc = {};
targetDesc.format = SLANG_HLSL;
targetDesc.profile = globalSession->findProfile("lib_6_6");
targetDesc.flags = SLANG_TARGET_FLAG_GENERATE_WHOLE_PROGRAM;
- std::vector<CompilerOptionEntry> targetOptions;
- {
- CompilerOptionEntry entry = {};
- entry.name = CompilerOptionName::NoHLSLBinding;
- entry.value.intValue0 = 1;
- targetOptions.push_back(entry);
- }
- {
- CompilerOptionEntry entry = {};
- entry.name = CompilerOptionName::NoHLSLPackConstantBufferElements;
- entry.value.intValue0 = 1;
- targetOptions.push_back(entry);
- }
+ std::vector<CompilerOptionEntry> options;
+ addCompilerOption(options, CompilerOptionName::DisableNonEssentialValidations);
+ addCompilerOption(options, CompilerOptionName::NoHLSLBinding);
+ addCompilerOption(options, CompilerOptionName::NoMangle);
+ addCompilerOption(options, CompilerOptionName::NoHLSLPackConstantBufferElements);
+
+ std::vector<CompilerOptionEntry> targetOptions(options);
targetDesc.compilerOptionEntries = targetOptions.data();
targetDesc.compilerOptionEntryCount = static_cast<uint32_t>(targetOptions.size());
@@ -742,26 +743,15 @@ int main(int argc, char** argv)
sessionDesc.searchPaths = searchPaths;
sessionDesc.searchPathCount = 1;
- std::vector<CompilerOptionEntry> sessionOptions;
- {
- CompilerOptionEntry entry = {};
- entry.name = CompilerOptionName::NoMangle;
- entry.value.intValue0 = 1;
- sessionOptions.push_back(entry);
- }
- {
- CompilerOptionEntry entry = {};
- entry.name = CompilerOptionName::DisableNonEssentialValidations;
- entry.value.intValue0 = 1;
- sessionOptions.push_back(entry);
- }
+ std::vector<CompilerOptionEntry> sessionOptions(options);
sessionDesc.compilerOptionEntries = sessionOptions.data();
sessionDesc.compilerOptionEntryCount = static_cast<uint32_t>(sessionOptions.size());
ComPtr<ISession> session;
- checkResult("IGlobalSession::createSession", globalSession->createSession(sessionDesc, session.writeRef()));
+ checkResult("IGlobalSession::createSession",
+ globalSession->createSession(sessionDesc, session.writeRef()));
- // 3. Module Loading (from the supplied Slang source file)
+ // Load the "module" aka the library
ComPtr<IModule> libraryModule;
{
ComPtr<IBlob> diagnosticsBlob;
@@ -774,25 +764,23 @@ int main(int argc, char** argv)
return 1;
}
}
- // 4. Discover top-level functions to treat as entry points
+
+ // Discover top-level functions to treat as entry points
std::vector<FunctionInfo> functions;
std::unordered_set<std::string> seenNames;
DeclReflection* moduleReflection = libraryModule->getModuleReflection();
- if (!moduleReflection)
- {
+ if (!moduleReflection) {
std::cerr << "Failed to retrieve reflection data for module '" << moduleName << "'."
<< std::endl;
return 1;
}
collectFunctionInfos(moduleReflection, functions, seenNames);
-
- if (functions.empty())
- {
+ if (functions.empty()) {
std::cerr << "No functions found in module '" << moduleName << "'." << std::endl;
return 1;
}
- // 5. Compile the translation unit with whole-program emission
+ // Compile
ComPtr<ICompileRequest> compileRequest;
checkResult(
"ISession::createCompileRequest",
@@ -801,10 +789,8 @@ int main(int argc, char** argv)
compileRequest->setCodeGenTarget(SLANG_HLSL);
compileRequest->setTargetProfile(0, targetDesc.profile);
compileRequest->setTargetFlags(0, SLANG_TARGET_FLAG_GENERATE_WHOLE_PROGRAM);
-
- SlangCompileFlags compileFlags = compileRequest->getCompileFlags();
- compileFlags |= SLANG_COMPILE_FLAG_NO_MANGLING;
- compileRequest->setCompileFlags(compileFlags);
+ compileRequest->setMatrixLayoutMode(SLANG_MATRIX_LAYOUT_ROW_MAJOR);
+ compileRequest->setLineDirectiveMode(SLANG_LINE_DIRECTIVE_MODE_NONE);
compileRequest->addSearchPath(searchPath.c_str());
@@ -854,6 +840,7 @@ int main(int argc, char** argv)
std::string finalHlsl = rewriteHLSLWithWrappers(hlslSource, functions);
+ IncludeGuardInfo includeGuard = detectIncludeGuard(modulePath);
if (includeGuard.present)
{
const std::string guardIfndefToken = "#ifndef " + includeGuard.macro;
@@ -895,13 +882,6 @@ int main(int argc, char** argv)
return 1;
}
- // Also stream to standard output to retain previous behavior
- std::cout.write(finalHlsl.data(), static_cast<std::streamsize>(finalHlsl.size()));
- if (finalHlsl.empty() || finalHlsl.back() != '\n')
- {
- std::cout << std::endl;
- }
-
std::cerr << "Generated HLSL written to " << outputPath << std::endl;
return 0;