summaryrefslogtreecommitdiffstats
path: root/main.cc
diff options
context:
space:
mode:
authoryum <yum.food.vr@gmail.com>2025-10-11 11:27:19 -0700
committeryum <yum.food.vr@gmail.com>2025-10-11 11:27:19 -0700
commit06780e36b2aeb1257607c89570fd99903508b82e (patch)
treed8496afcb9a3e3808c41ea50f4bd3e490595faf0 /main.cc
parent1db7ecb9c0235b9317b5c318685bbbfa8a2309d1 (diff)
smol cleanup
Diffstat (limited to 'main.cc')
-rw-r--r--main.cc207
1 files changed, 120 insertions, 87 deletions
diff --git a/main.cc b/main.cc
index 3163fee..7c67428 100644
--- a/main.cc
+++ b/main.cc
@@ -14,15 +14,35 @@
#include <slang.h>
#include <slang-com-ptr.h>
+namespace fs = std::filesystem;
+
+using ::slang::CompilerOptionEntry;
+using ::slang::CompilerOptionName;
+using ::slang::createGlobalSession;
+using ::slang::DeclReflection;
+using ::slang::FunctionReflection;
+using ::slang::IBlob;
+using ::slang::ICompileRequest;
+using ::slang::IGlobalSession;
+using ::slang::IModule;
+using ::slang::ISession;
+using ::slang::SessionDesc;
+using ::slang::TargetDesc;
+using ::slang::TypeReflection;
+using ::slang::VariableReflection;
+
+template <typename T>
+using ComPtr = ::Slang::ComPtr<T>;
+
// Print any diagnostics carried by a Slang blob with optional context information.
-void printDiagnostics(const char* context, slang::IBlob* diagnostics)
+void printDiagnostics(const char* context, IBlob* diagnostics)
{
if (!diagnostics)
{
return;
}
- size_t size = diagnostics->getBufferSize();
+ std::size_t size = diagnostics->getBufferSize();
if (size == 0)
{
return;
@@ -52,7 +72,7 @@ void printDiagnostics(const char* context, slang::IBlob* diagnostics)
}
// Helper to check Slang API results and surface diagnostic details when available.
-void checkResult(const char* context, SlangResult res, slang::IBlob* diagnostics = nullptr)
+void checkResult(const char* context, SlangResult res, IBlob* diagnostics = nullptr)
{
printDiagnostics(context, diagnostics);
@@ -65,6 +85,27 @@ void checkResult(const char* context, SlangResult res, slang::IBlob* diagnostics
}
}
+bool writeTextFile(const fs::path& path, std::string_view contents)
+{
+ std::ofstream file(path, std::ios::binary);
+ if (!file)
+ {
+ std::cerr << "Warning: Failed to open " << path << " for writing." << std::endl;
+ return false;
+ }
+
+ file.write(contents.data(), static_cast<std::streamsize>(contents.size()));
+ file.close();
+
+ if (!file)
+ {
+ std::cerr << "Warning: Failed to write " << path << std::endl;
+ return false;
+ }
+
+ return true;
+}
+
struct ParameterInfo
{
std::string name;
@@ -89,8 +130,8 @@ struct IncludeGuardInfo
std::string trim(std::string_view text)
{
- size_t start = 0;
- size_t end = text.size();
+ std::size_t start = 0;
+ std::size_t end = text.size();
while (start < end && std::isspace(static_cast<unsigned char>(text[start])))
{
@@ -105,15 +146,15 @@ std::string trim(std::string_view text)
return std::string(text.substr(start, end - start));
}
-std::string getTypeName(slang::TypeReflection* type)
+std::string getTypeName(TypeReflection* type)
{
if (!type)
{
return {};
}
- using Kind = slang::TypeReflection::Kind;
- using Scalar = slang::TypeReflection::ScalarType;
+ using Kind = TypeReflection::Kind;
+ using Scalar = TypeReflection::ScalarType;
auto scalarToString = [](Scalar scalar) -> std::string
{
@@ -151,7 +192,7 @@ std::string getTypeName(slang::TypeReflection* type)
case Kind::Vector:
{
const std::string elementTypeName = getTypeName(type->getElementType());
- const size_t elementCount = type->getElementCount();
+ const std::size_t elementCount = type->getElementCount();
if (!elementTypeName.empty() && elementCount > 0)
{
return elementTypeName + std::to_string(elementCount);
@@ -181,7 +222,7 @@ std::string getTypeName(slang::TypeReflection* type)
}
}
- Slang::ComPtr<ISlangBlob> fullNameBlob;
+ ComPtr<ISlangBlob> fullNameBlob;
if (SLANG_SUCCEEDED(type->getFullName(fullNameBlob.writeRef())) && fullNameBlob &&
fullNameBlob->getBufferSize() > 0)
{
@@ -192,8 +233,8 @@ std::string getTypeName(slang::TypeReflection* type)
{
if (name.substr(0, 7) == "vector<")
{
- size_t commaPos = name.find(',');
- size_t endPos = name.rfind('>');
+ std::size_t commaPos = name.find(',');
+ std::size_t endPos = name.rfind('>');
if (commaPos != std::string_view::npos && endPos != std::string_view::npos &&
commaPos + 1 < endPos)
{
@@ -204,9 +245,9 @@ std::string getTypeName(slang::TypeReflection* type)
}
if (name.substr(0, 7) == "matrix<")
{
- size_t firstComma = name.find(',');
- size_t secondComma = name.find(',', firstComma + 1);
- size_t endPos = name.rfind('>');
+ std::size_t firstComma = name.find(',');
+ std::size_t secondComma = name.find(',', firstComma + 1);
+ std::size_t endPos = name.rfind('>');
if (firstComma != std::string_view::npos && secondComma != std::string_view::npos &&
endPos != std::string_view::npos)
{
@@ -225,15 +266,15 @@ std::string getTypeName(slang::TypeReflection* type)
return {};
}
-bool isTopLevelFunction(slang::DeclReflection* functionDecl)
+bool isTopLevelFunction(DeclReflection* functionDecl)
{
if (!functionDecl)
{
return false;
}
- using Kind = slang::DeclReflection::Kind;
- for (slang::DeclReflection* parent = functionDecl->getParent(); parent;
+ using Kind = DeclReflection::Kind;
+ for (DeclReflection* parent = functionDecl->getParent(); parent;
parent = parent->getParent())
{
switch (parent->getKind())
@@ -253,7 +294,7 @@ bool isTopLevelFunction(slang::DeclReflection* functionDecl)
// Recursively gather function declarations defined in the supplied Slang module.
void collectFunctionInfos(
- slang::DeclReflection* decl,
+ DeclReflection* decl,
std::vector<FunctionInfo>& functions,
std::unordered_set<std::string>& seenNames)
{
@@ -262,7 +303,7 @@ void collectFunctionInfos(
return;
}
- using Kind = slang::DeclReflection::Kind;
+ using Kind = DeclReflection::Kind;
switch (decl->getKind())
{
@@ -276,7 +317,7 @@ void collectFunctionInfos(
FunctionInfo info;
info.name = name;
- if (slang::TypeReflection* returnType = functionReflection->getReturnType())
+ if (TypeReflection* returnType = functionReflection->getReturnType())
{
info.returnType = getTypeName(returnType);
}
@@ -289,7 +330,7 @@ void collectFunctionInfos(
info.parameters.reserve(paramCount);
for (unsigned i = 0; i < paramCount; ++i)
{
- slang::VariableReflection* paramReflection =
+ VariableReflection* paramReflection =
functionReflection->getParameterByIndex(i);
ParameterInfo paramInfo;
if (auto* typeReflection = paramReflection->getType())
@@ -352,17 +393,17 @@ std::string rewriteHLSLWithWrappers(
std::unordered_map<std::string, std::vector<EntryPointField>> baseNameToFields;
const std::string structPrefix = "struct EntryPointParams_";
- size_t searchPos = 0;
+ std::size_t searchPos = 0;
while (true)
{
- size_t structPos = result.find(structPrefix, searchPos);
+ std::size_t structPos = result.find(structPrefix, searchPos);
if (structPos == std::string::npos)
{
break;
}
- size_t indexPos = structPos + structPrefix.size();
- size_t indexEnd = indexPos;
+ std::size_t indexPos = structPos + structPrefix.size();
+ std::size_t indexEnd = indexPos;
while (indexEnd < result.size() && std::isdigit(static_cast<unsigned char>(result[indexEnd])))
{
++indexEnd;
@@ -375,21 +416,21 @@ std::string rewriteHLSLWithWrappers(
int bufferIndex = std::stoi(result.substr(indexPos, indexEnd - indexPos));
- size_t braceOpen = result.find('{', indexEnd);
+ std::size_t braceOpen = result.find('{', indexEnd);
if (braceOpen == std::string::npos)
{
break;
}
- size_t braceClose = result.find("};", braceOpen);
+ std::size_t braceClose = result.find("};", braceOpen);
if (braceClose == std::string::npos)
{
break;
}
- size_t fieldPos = braceOpen + 1;
+ std::size_t fieldPos = braceOpen + 1;
while (fieldPos < braceClose)
{
- size_t semicolon = result.find(';', fieldPos);
+ std::size_t semicolon = result.find(';', fieldPos);
if (semicolon == std::string::npos || semicolon > braceClose)
{
break;
@@ -398,18 +439,18 @@ std::string rewriteHLSLWithWrappers(
std::string line = trim(std::string_view(result).substr(fieldPos, semicolon - fieldPos));
if (!line.empty())
{
- size_t lastSpace = line.find_last_of(" \t");
+ std::size_t lastSpace = line.find_last_of(" \t");
if (lastSpace != std::string::npos && lastSpace + 1 < line.size())
{
std::string fieldName = line.substr(lastSpace + 1);
- size_t bracketPos = fieldName.find('[');
+ std::size_t bracketPos = fieldName.find('[');
if (bracketPos != std::string::npos)
{
fieldName = fieldName.substr(0, bracketPos);
}
std::string baseName = fieldName;
- size_t underscorePos = baseName.rfind('_');
+ std::size_t underscorePos = baseName.rfind('_');
if (underscorePos != std::string::npos)
{
baseName = baseName.substr(0, underscorePos);
@@ -428,13 +469,13 @@ std::string rewriteHLSLWithWrappers(
}
const std::string attributeToken = "[shader(\"dispatch\")]export";
- size_t searchFrom = 0;
+ std::size_t searchFrom = 0;
for (const FunctionInfo& func : functions)
{
- size_t attrPos = result.find(attributeToken, searchFrom);
+ std::size_t attrPos = result.find(attributeToken, searchFrom);
if (attrPos != std::string::npos)
{
- size_t attrEnd = attrPos + attributeToken.size();
+ std::size_t attrEnd = attrPos + attributeToken.size();
if (attrEnd < result.size() && result[attrEnd] == '\r')
{
++attrEnd;
@@ -447,7 +488,7 @@ std::string rewriteHLSLWithWrappers(
searchFrom = attrPos;
}
- size_t namePos = result.find(func.name + "(", searchFrom);
+ std::size_t namePos = result.find(func.name + "(", searchFrom);
if (namePos == std::string::npos)
{
namePos = result.find(func.name + "(");
@@ -463,13 +504,13 @@ std::string rewriteHLSLWithWrappers(
std::ostringstream wrapperBuilder;
wrapperBuilder << "\n";
- for (size_t functionIndex = 0; functionIndex < functions.size(); ++functionIndex)
+ for (std::size_t functionIndex = 0; functionIndex < functions.size(); ++functionIndex)
{
const FunctionInfo& func = functions[functionIndex];
const std::string entryName = "__slang_entry_" + func.name;
std::string parameterList;
- for (size_t i = 0; i < func.parameters.size(); ++i)
+ for (std::size_t i = 0; i < func.parameters.size(); ++i)
{
if (i > 0)
{
@@ -483,7 +524,7 @@ std::string rewriteHLSLWithWrappers(
std::unordered_set<std::string> emittedAssignments;
- for (size_t paramIndex = 0; paramIndex < func.parameters.size(); ++paramIndex)
+ for (std::size_t paramIndex = 0; paramIndex < func.parameters.size(); ++paramIndex)
{
const ParameterInfo& param = func.parameters[paramIndex];
@@ -532,9 +573,9 @@ std::string rewriteHLSLWithWrappers(
const std::string wrappers = wrapperBuilder.str();
if (!wrappers.empty())
{
- const size_t endifPos = result.rfind("#endif");
- const size_t ifndefPos = result.find("#ifndef");
- const size_t definePos = result.find("#define", ifndefPos != std::string::npos ? ifndefPos : 0);
+ const std::size_t endifPos = result.rfind("#endif");
+ const std::size_t ifndefPos = result.find("#ifndef");
+ const std::size_t definePos = result.find("#define", ifndefPos != std::string::npos ? ifndefPos : 0);
const bool hasGuard =
ifndefPos != std::string::npos && definePos != std::string::npos &&
@@ -553,7 +594,7 @@ std::string rewriteHLSLWithWrappers(
return result;
}
-IncludeGuardInfo detectIncludeGuard(const std::filesystem::path& sourcePath)
+IncludeGuardInfo detectIncludeGuard(const fs::path& sourcePath)
{
IncludeGuardInfo info;
@@ -570,8 +611,8 @@ IncludeGuardInfo detectIncludeGuard(const std::filesystem::path& sourcePath)
lines.push_back(line);
}
- size_t ifndefIndex = std::numeric_limits<size_t>::max();
- for (size_t i = 0; i < lines.size(); ++i)
+ std::size_t ifndefIndex = std::numeric_limits<std::size_t>::max();
+ for (std::size_t i = 0; i < lines.size(); ++i)
{
std::string trimmed = trim(lines[i]);
if (trimmed.rfind("#ifndef", 0) == 0)
@@ -595,7 +636,7 @@ IncludeGuardInfo detectIncludeGuard(const std::filesystem::path& sourcePath)
return info;
}
- for (size_t i = ifndefIndex + 1; i < lines.size(); ++i)
+ for (std::size_t i = ifndefIndex + 1; i < lines.size(); ++i)
{
std::string trimmed = trim(lines[i]);
if (trimmed.rfind("#define", 0) == 0)
@@ -618,7 +659,7 @@ IncludeGuardInfo detectIncludeGuard(const std::filesystem::path& sourcePath)
return info;
}
- for (size_t i = lines.size(); i-- > 0;)
+ for (std::size_t i = lines.size(); i-- > 0;)
{
std::string trimmed = trim(lines[i]);
if (trimmed.rfind("#endif", 0) == 0)
@@ -646,8 +687,8 @@ int main(int argc, char** argv)
return 1;
}
- std::filesystem::path modulePath = std::filesystem::absolute(argv[1]);
- if (!std::filesystem::exists(modulePath))
+ fs::path modulePath = fs::absolute(argv[1]);
+ if (!fs::exists(modulePath))
{
std::cerr << "Module not found: " << modulePath << std::endl;
return 1;
@@ -658,73 +699,72 @@ int main(int argc, char** argv)
std::cerr << "Expected a .slang file: " << modulePath << std::endl;
return 1;
}
-
std::string moduleName = modulePath.stem().string();
std::string searchPath = modulePath.has_parent_path()
? modulePath.parent_path().string()
- : std::filesystem::current_path().string();
+ : fs::current_path().string();
- std::filesystem::path outputPath = modulePath;
+ fs::path outputPath = modulePath;
outputPath.replace_extension(".hlsl");
IncludeGuardInfo includeGuard = detectIncludeGuard(modulePath);
// 1. Session Creation
- Slang::ComPtr<slang::IGlobalSession> globalSession;
- checkResult("slang::createGlobalSession", slang::createGlobalSession(globalSession.writeRef()));
+ ComPtr<IGlobalSession> globalSession;
+ checkResult("createGlobalSession", createGlobalSession(globalSession.writeRef()));
// 2. Target Configuration
- slang::TargetDesc targetDesc = {};
+ TargetDesc targetDesc = {};
targetDesc.format = SLANG_HLSL;
targetDesc.profile = globalSession->findProfile("lib_6_6");
targetDesc.flags = SLANG_TARGET_FLAG_GENERATE_WHOLE_PROGRAM;
- std::vector<slang::CompilerOptionEntry> targetOptions;
+ std::vector<CompilerOptionEntry> targetOptions;
{
- slang::CompilerOptionEntry entry = {};
- entry.name = slang::CompilerOptionName::NoHLSLBinding;
+ CompilerOptionEntry entry = {};
+ entry.name = CompilerOptionName::NoHLSLBinding;
entry.value.intValue0 = 1;
targetOptions.push_back(entry);
}
{
- slang::CompilerOptionEntry entry = {};
- entry.name = slang::CompilerOptionName::NoHLSLPackConstantBufferElements;
+ CompilerOptionEntry entry = {};
+ entry.name = CompilerOptionName::NoHLSLPackConstantBufferElements;
entry.value.intValue0 = 1;
targetOptions.push_back(entry);
}
targetDesc.compilerOptionEntries = targetOptions.data();
targetDesc.compilerOptionEntryCount = static_cast<uint32_t>(targetOptions.size());
- slang::SessionDesc sessionDesc = {};
+ SessionDesc sessionDesc = {};
sessionDesc.targets = &targetDesc;
sessionDesc.targetCount = 1;
const char* searchPaths[] = { searchPath.c_str() };
sessionDesc.searchPaths = searchPaths;
sessionDesc.searchPathCount = 1;
- std::vector<slang::CompilerOptionEntry> sessionOptions;
+ std::vector<CompilerOptionEntry> sessionOptions;
{
- slang::CompilerOptionEntry entry = {};
- entry.name = slang::CompilerOptionName::NoMangle;
+ CompilerOptionEntry entry = {};
+ entry.name = CompilerOptionName::NoMangle;
entry.value.intValue0 = 1;
sessionOptions.push_back(entry);
}
{
- slang::CompilerOptionEntry entry = {};
- entry.name = slang::CompilerOptionName::DisableNonEssentialValidations;
+ CompilerOptionEntry entry = {};
+ entry.name = CompilerOptionName::DisableNonEssentialValidations;
entry.value.intValue0 = 1;
sessionOptions.push_back(entry);
}
sessionDesc.compilerOptionEntries = sessionOptions.data();
sessionDesc.compilerOptionEntryCount = static_cast<uint32_t>(sessionOptions.size());
- Slang::ComPtr<slang::ISession> session;
+ ComPtr<ISession> session;
checkResult("IGlobalSession::createSession", globalSession->createSession(sessionDesc, session.writeRef()));
// 3. Module Loading (from the supplied Slang source file)
- Slang::ComPtr<slang::IModule> libraryModule;
+ ComPtr<IModule> libraryModule;
{
- Slang::ComPtr<slang::IBlob> diagnosticsBlob;
+ ComPtr<IBlob> diagnosticsBlob;
libraryModule = session->loadModule(moduleName.c_str(), diagnosticsBlob.writeRef());
const std::string diagnosticsContext = "loadModule: " + moduleName;
printDiagnostics(diagnosticsContext.c_str(), diagnosticsBlob);
@@ -737,7 +777,7 @@ int main(int argc, char** argv)
// 4. Discover top-level functions to treat as entry points
std::vector<FunctionInfo> functions;
std::unordered_set<std::string> seenNames;
- slang::DeclReflection* moduleReflection = libraryModule->getModuleReflection();
+ DeclReflection* moduleReflection = libraryModule->getModuleReflection();
if (!moduleReflection)
{
std::cerr << "Failed to retrieve reflection data for module '" << moduleName << "'."
@@ -753,7 +793,7 @@ int main(int argc, char** argv)
}
// 5. Compile the translation unit with whole-program emission
- Slang::ComPtr<slang::ICompileRequest> compileRequest;
+ ComPtr<ICompileRequest> compileRequest;
checkResult(
"ISession::createCompileRequest",
session->createCompileRequest(compileRequest.writeRef()));
@@ -789,11 +829,11 @@ int main(int argc, char** argv)
}
SlangResult compileResult = compileRequest->compile();
- Slang::ComPtr<slang::IBlob> compileDiagnostics;
+ ComPtr<IBlob> compileDiagnostics;
compileRequest->getDiagnosticOutputBlob(compileDiagnostics.writeRef());
checkResult("ICompileRequest::compile", compileResult, compileDiagnostics);
- Slang::ComPtr<slang::IBlob> targetCodeBlob;
+ ComPtr<IBlob> targetCodeBlob;
checkResult(
"ICompileRequest::getTargetCodeBlob",
compileRequest->getTargetCodeBlob(0, targetCodeBlob.writeRef()));
@@ -806,7 +846,11 @@ int main(int argc, char** argv)
std::string hlslSource(
static_cast<const char*>(targetCodeBlob->getBufferPointer()),
- static_cast<size_t>(targetCodeBlob->getBufferSize()));
+ static_cast<std::size_t>(targetCodeBlob->getBufferSize()));
+
+ fs::path rawOutputPath = outputPath;
+ rawOutputPath.replace_extension(".raw.hlsl");
+ writeTextFile(rawOutputPath, hlslSource);
std::string finalHlsl = rewriteHLSLWithWrappers(hlslSource, functions);
@@ -846,19 +890,8 @@ int main(int argc, char** argv)
}
// 6. Write HLSL output to a sibling .hlsl file
- std::ofstream outputFile(outputPath, std::ios::binary);
- if (!outputFile)
- {
- std::cerr << "Failed to open output path for writing: " << outputPath << std::endl;
- return 1;
- }
-
- outputFile.write(finalHlsl.data(), static_cast<std::streamsize>(finalHlsl.size()));
- outputFile.close();
-
- if (!outputFile)
+ if (!writeTextFile(outputPath, finalHlsl))
{
- std::cerr << "Failed to write HLSL output to " << outputPath << std::endl;
return 1;
}