diff options
| author | yum <yum.food.vr@gmail.com> | 2025-10-11 11:27:19 -0700 |
|---|---|---|
| committer | yum <yum.food.vr@gmail.com> | 2025-10-11 11:27:19 -0700 |
| commit | 06780e36b2aeb1257607c89570fd99903508b82e (patch) | |
| tree | d8496afcb9a3e3808c41ea50f4bd3e490595faf0 /main.cc | |
| parent | 1db7ecb9c0235b9317b5c318685bbbfa8a2309d1 (diff) | |
smol cleanup
Diffstat (limited to 'main.cc')
| -rw-r--r-- | main.cc | 207 |
1 files changed, 120 insertions, 87 deletions
@@ -14,15 +14,35 @@ #include <slang.h> #include <slang-com-ptr.h> +namespace fs = std::filesystem; + +using ::slang::CompilerOptionEntry; +using ::slang::CompilerOptionName; +using ::slang::createGlobalSession; +using ::slang::DeclReflection; +using ::slang::FunctionReflection; +using ::slang::IBlob; +using ::slang::ICompileRequest; +using ::slang::IGlobalSession; +using ::slang::IModule; +using ::slang::ISession; +using ::slang::SessionDesc; +using ::slang::TargetDesc; +using ::slang::TypeReflection; +using ::slang::VariableReflection; + +template <typename T> +using ComPtr = ::Slang::ComPtr<T>; + // Print any diagnostics carried by a Slang blob with optional context information. -void printDiagnostics(const char* context, slang::IBlob* diagnostics) +void printDiagnostics(const char* context, IBlob* diagnostics) { if (!diagnostics) { return; } - size_t size = diagnostics->getBufferSize(); + std::size_t size = diagnostics->getBufferSize(); if (size == 0) { return; @@ -52,7 +72,7 @@ void printDiagnostics(const char* context, slang::IBlob* diagnostics) } // Helper to check Slang API results and surface diagnostic details when available. -void checkResult(const char* context, SlangResult res, slang::IBlob* diagnostics = nullptr) +void checkResult(const char* context, SlangResult res, IBlob* diagnostics = nullptr) { printDiagnostics(context, diagnostics); @@ -65,6 +85,27 @@ void checkResult(const char* context, SlangResult res, slang::IBlob* diagnostics } } +bool writeTextFile(const fs::path& path, std::string_view contents) +{ + std::ofstream file(path, std::ios::binary); + if (!file) + { + std::cerr << "Warning: Failed to open " << path << " for writing." << std::endl; + return false; + } + + file.write(contents.data(), static_cast<std::streamsize>(contents.size())); + file.close(); + + if (!file) + { + std::cerr << "Warning: Failed to write " << path << std::endl; + return false; + } + + return true; +} + struct ParameterInfo { std::string name; @@ -89,8 +130,8 @@ struct IncludeGuardInfo std::string trim(std::string_view text) { - size_t start = 0; - size_t end = text.size(); + std::size_t start = 0; + std::size_t end = text.size(); while (start < end && std::isspace(static_cast<unsigned char>(text[start]))) { @@ -105,15 +146,15 @@ std::string trim(std::string_view text) return std::string(text.substr(start, end - start)); } -std::string getTypeName(slang::TypeReflection* type) +std::string getTypeName(TypeReflection* type) { if (!type) { return {}; } - using Kind = slang::TypeReflection::Kind; - using Scalar = slang::TypeReflection::ScalarType; + using Kind = TypeReflection::Kind; + using Scalar = TypeReflection::ScalarType; auto scalarToString = [](Scalar scalar) -> std::string { @@ -151,7 +192,7 @@ std::string getTypeName(slang::TypeReflection* type) case Kind::Vector: { const std::string elementTypeName = getTypeName(type->getElementType()); - const size_t elementCount = type->getElementCount(); + const std::size_t elementCount = type->getElementCount(); if (!elementTypeName.empty() && elementCount > 0) { return elementTypeName + std::to_string(elementCount); @@ -181,7 +222,7 @@ std::string getTypeName(slang::TypeReflection* type) } } - Slang::ComPtr<ISlangBlob> fullNameBlob; + ComPtr<ISlangBlob> fullNameBlob; if (SLANG_SUCCEEDED(type->getFullName(fullNameBlob.writeRef())) && fullNameBlob && fullNameBlob->getBufferSize() > 0) { @@ -192,8 +233,8 @@ std::string getTypeName(slang::TypeReflection* type) { if (name.substr(0, 7) == "vector<") { - size_t commaPos = name.find(','); - size_t endPos = name.rfind('>'); + std::size_t commaPos = name.find(','); + std::size_t endPos = name.rfind('>'); if (commaPos != std::string_view::npos && endPos != std::string_view::npos && commaPos + 1 < endPos) { @@ -204,9 +245,9 @@ std::string getTypeName(slang::TypeReflection* type) } if (name.substr(0, 7) == "matrix<") { - size_t firstComma = name.find(','); - size_t secondComma = name.find(',', firstComma + 1); - size_t endPos = name.rfind('>'); + std::size_t firstComma = name.find(','); + std::size_t secondComma = name.find(',', firstComma + 1); + std::size_t endPos = name.rfind('>'); if (firstComma != std::string_view::npos && secondComma != std::string_view::npos && endPos != std::string_view::npos) { @@ -225,15 +266,15 @@ std::string getTypeName(slang::TypeReflection* type) return {}; } -bool isTopLevelFunction(slang::DeclReflection* functionDecl) +bool isTopLevelFunction(DeclReflection* functionDecl) { if (!functionDecl) { return false; } - using Kind = slang::DeclReflection::Kind; - for (slang::DeclReflection* parent = functionDecl->getParent(); parent; + using Kind = DeclReflection::Kind; + for (DeclReflection* parent = functionDecl->getParent(); parent; parent = parent->getParent()) { switch (parent->getKind()) @@ -253,7 +294,7 @@ bool isTopLevelFunction(slang::DeclReflection* functionDecl) // Recursively gather function declarations defined in the supplied Slang module. void collectFunctionInfos( - slang::DeclReflection* decl, + DeclReflection* decl, std::vector<FunctionInfo>& functions, std::unordered_set<std::string>& seenNames) { @@ -262,7 +303,7 @@ void collectFunctionInfos( return; } - using Kind = slang::DeclReflection::Kind; + using Kind = DeclReflection::Kind; switch (decl->getKind()) { @@ -276,7 +317,7 @@ void collectFunctionInfos( FunctionInfo info; info.name = name; - if (slang::TypeReflection* returnType = functionReflection->getReturnType()) + if (TypeReflection* returnType = functionReflection->getReturnType()) { info.returnType = getTypeName(returnType); } @@ -289,7 +330,7 @@ void collectFunctionInfos( info.parameters.reserve(paramCount); for (unsigned i = 0; i < paramCount; ++i) { - slang::VariableReflection* paramReflection = + VariableReflection* paramReflection = functionReflection->getParameterByIndex(i); ParameterInfo paramInfo; if (auto* typeReflection = paramReflection->getType()) @@ -352,17 +393,17 @@ std::string rewriteHLSLWithWrappers( std::unordered_map<std::string, std::vector<EntryPointField>> baseNameToFields; const std::string structPrefix = "struct EntryPointParams_"; - size_t searchPos = 0; + std::size_t searchPos = 0; while (true) { - size_t structPos = result.find(structPrefix, searchPos); + std::size_t structPos = result.find(structPrefix, searchPos); if (structPos == std::string::npos) { break; } - size_t indexPos = structPos + structPrefix.size(); - size_t indexEnd = indexPos; + std::size_t indexPos = structPos + structPrefix.size(); + std::size_t indexEnd = indexPos; while (indexEnd < result.size() && std::isdigit(static_cast<unsigned char>(result[indexEnd]))) { ++indexEnd; @@ -375,21 +416,21 @@ std::string rewriteHLSLWithWrappers( int bufferIndex = std::stoi(result.substr(indexPos, indexEnd - indexPos)); - size_t braceOpen = result.find('{', indexEnd); + std::size_t braceOpen = result.find('{', indexEnd); if (braceOpen == std::string::npos) { break; } - size_t braceClose = result.find("};", braceOpen); + std::size_t braceClose = result.find("};", braceOpen); if (braceClose == std::string::npos) { break; } - size_t fieldPos = braceOpen + 1; + std::size_t fieldPos = braceOpen + 1; while (fieldPos < braceClose) { - size_t semicolon = result.find(';', fieldPos); + std::size_t semicolon = result.find(';', fieldPos); if (semicolon == std::string::npos || semicolon > braceClose) { break; @@ -398,18 +439,18 @@ std::string rewriteHLSLWithWrappers( std::string line = trim(std::string_view(result).substr(fieldPos, semicolon - fieldPos)); if (!line.empty()) { - size_t lastSpace = line.find_last_of(" \t"); + std::size_t lastSpace = line.find_last_of(" \t"); if (lastSpace != std::string::npos && lastSpace + 1 < line.size()) { std::string fieldName = line.substr(lastSpace + 1); - size_t bracketPos = fieldName.find('['); + std::size_t bracketPos = fieldName.find('['); if (bracketPos != std::string::npos) { fieldName = fieldName.substr(0, bracketPos); } std::string baseName = fieldName; - size_t underscorePos = baseName.rfind('_'); + std::size_t underscorePos = baseName.rfind('_'); if (underscorePos != std::string::npos) { baseName = baseName.substr(0, underscorePos); @@ -428,13 +469,13 @@ std::string rewriteHLSLWithWrappers( } const std::string attributeToken = "[shader(\"dispatch\")]export"; - size_t searchFrom = 0; + std::size_t searchFrom = 0; for (const FunctionInfo& func : functions) { - size_t attrPos = result.find(attributeToken, searchFrom); + std::size_t attrPos = result.find(attributeToken, searchFrom); if (attrPos != std::string::npos) { - size_t attrEnd = attrPos + attributeToken.size(); + std::size_t attrEnd = attrPos + attributeToken.size(); if (attrEnd < result.size() && result[attrEnd] == '\r') { ++attrEnd; @@ -447,7 +488,7 @@ std::string rewriteHLSLWithWrappers( searchFrom = attrPos; } - size_t namePos = result.find(func.name + "(", searchFrom); + std::size_t namePos = result.find(func.name + "(", searchFrom); if (namePos == std::string::npos) { namePos = result.find(func.name + "("); @@ -463,13 +504,13 @@ std::string rewriteHLSLWithWrappers( std::ostringstream wrapperBuilder; wrapperBuilder << "\n"; - for (size_t functionIndex = 0; functionIndex < functions.size(); ++functionIndex) + for (std::size_t functionIndex = 0; functionIndex < functions.size(); ++functionIndex) { const FunctionInfo& func = functions[functionIndex]; const std::string entryName = "__slang_entry_" + func.name; std::string parameterList; - for (size_t i = 0; i < func.parameters.size(); ++i) + for (std::size_t i = 0; i < func.parameters.size(); ++i) { if (i > 0) { @@ -483,7 +524,7 @@ std::string rewriteHLSLWithWrappers( std::unordered_set<std::string> emittedAssignments; - for (size_t paramIndex = 0; paramIndex < func.parameters.size(); ++paramIndex) + for (std::size_t paramIndex = 0; paramIndex < func.parameters.size(); ++paramIndex) { const ParameterInfo& param = func.parameters[paramIndex]; @@ -532,9 +573,9 @@ std::string rewriteHLSLWithWrappers( const std::string wrappers = wrapperBuilder.str(); if (!wrappers.empty()) { - const size_t endifPos = result.rfind("#endif"); - const size_t ifndefPos = result.find("#ifndef"); - const size_t definePos = result.find("#define", ifndefPos != std::string::npos ? ifndefPos : 0); + const std::size_t endifPos = result.rfind("#endif"); + const std::size_t ifndefPos = result.find("#ifndef"); + const std::size_t definePos = result.find("#define", ifndefPos != std::string::npos ? ifndefPos : 0); const bool hasGuard = ifndefPos != std::string::npos && definePos != std::string::npos && @@ -553,7 +594,7 @@ std::string rewriteHLSLWithWrappers( return result; } -IncludeGuardInfo detectIncludeGuard(const std::filesystem::path& sourcePath) +IncludeGuardInfo detectIncludeGuard(const fs::path& sourcePath) { IncludeGuardInfo info; @@ -570,8 +611,8 @@ IncludeGuardInfo detectIncludeGuard(const std::filesystem::path& sourcePath) lines.push_back(line); } - size_t ifndefIndex = std::numeric_limits<size_t>::max(); - for (size_t i = 0; i < lines.size(); ++i) + std::size_t ifndefIndex = std::numeric_limits<std::size_t>::max(); + for (std::size_t i = 0; i < lines.size(); ++i) { std::string trimmed = trim(lines[i]); if (trimmed.rfind("#ifndef", 0) == 0) @@ -595,7 +636,7 @@ IncludeGuardInfo detectIncludeGuard(const std::filesystem::path& sourcePath) return info; } - for (size_t i = ifndefIndex + 1; i < lines.size(); ++i) + for (std::size_t i = ifndefIndex + 1; i < lines.size(); ++i) { std::string trimmed = trim(lines[i]); if (trimmed.rfind("#define", 0) == 0) @@ -618,7 +659,7 @@ IncludeGuardInfo detectIncludeGuard(const std::filesystem::path& sourcePath) return info; } - for (size_t i = lines.size(); i-- > 0;) + for (std::size_t i = lines.size(); i-- > 0;) { std::string trimmed = trim(lines[i]); if (trimmed.rfind("#endif", 0) == 0) @@ -646,8 +687,8 @@ int main(int argc, char** argv) return 1; } - std::filesystem::path modulePath = std::filesystem::absolute(argv[1]); - if (!std::filesystem::exists(modulePath)) + fs::path modulePath = fs::absolute(argv[1]); + if (!fs::exists(modulePath)) { std::cerr << "Module not found: " << modulePath << std::endl; return 1; @@ -658,73 +699,72 @@ int main(int argc, char** argv) std::cerr << "Expected a .slang file: " << modulePath << std::endl; return 1; } - std::string moduleName = modulePath.stem().string(); std::string searchPath = modulePath.has_parent_path() ? modulePath.parent_path().string() - : std::filesystem::current_path().string(); + : fs::current_path().string(); - std::filesystem::path outputPath = modulePath; + fs::path outputPath = modulePath; outputPath.replace_extension(".hlsl"); IncludeGuardInfo includeGuard = detectIncludeGuard(modulePath); // 1. Session Creation - Slang::ComPtr<slang::IGlobalSession> globalSession; - checkResult("slang::createGlobalSession", slang::createGlobalSession(globalSession.writeRef())); + ComPtr<IGlobalSession> globalSession; + checkResult("createGlobalSession", createGlobalSession(globalSession.writeRef())); // 2. Target Configuration - slang::TargetDesc targetDesc = {}; + TargetDesc targetDesc = {}; targetDesc.format = SLANG_HLSL; targetDesc.profile = globalSession->findProfile("lib_6_6"); targetDesc.flags = SLANG_TARGET_FLAG_GENERATE_WHOLE_PROGRAM; - std::vector<slang::CompilerOptionEntry> targetOptions; + std::vector<CompilerOptionEntry> targetOptions; { - slang::CompilerOptionEntry entry = {}; - entry.name = slang::CompilerOptionName::NoHLSLBinding; + CompilerOptionEntry entry = {}; + entry.name = CompilerOptionName::NoHLSLBinding; entry.value.intValue0 = 1; targetOptions.push_back(entry); } { - slang::CompilerOptionEntry entry = {}; - entry.name = slang::CompilerOptionName::NoHLSLPackConstantBufferElements; + CompilerOptionEntry entry = {}; + entry.name = CompilerOptionName::NoHLSLPackConstantBufferElements; entry.value.intValue0 = 1; targetOptions.push_back(entry); } targetDesc.compilerOptionEntries = targetOptions.data(); targetDesc.compilerOptionEntryCount = static_cast<uint32_t>(targetOptions.size()); - slang::SessionDesc sessionDesc = {}; + SessionDesc sessionDesc = {}; sessionDesc.targets = &targetDesc; sessionDesc.targetCount = 1; const char* searchPaths[] = { searchPath.c_str() }; sessionDesc.searchPaths = searchPaths; sessionDesc.searchPathCount = 1; - std::vector<slang::CompilerOptionEntry> sessionOptions; + std::vector<CompilerOptionEntry> sessionOptions; { - slang::CompilerOptionEntry entry = {}; - entry.name = slang::CompilerOptionName::NoMangle; + CompilerOptionEntry entry = {}; + entry.name = CompilerOptionName::NoMangle; entry.value.intValue0 = 1; sessionOptions.push_back(entry); } { - slang::CompilerOptionEntry entry = {}; - entry.name = slang::CompilerOptionName::DisableNonEssentialValidations; + CompilerOptionEntry entry = {}; + entry.name = CompilerOptionName::DisableNonEssentialValidations; entry.value.intValue0 = 1; sessionOptions.push_back(entry); } sessionDesc.compilerOptionEntries = sessionOptions.data(); sessionDesc.compilerOptionEntryCount = static_cast<uint32_t>(sessionOptions.size()); - Slang::ComPtr<slang::ISession> session; + ComPtr<ISession> session; checkResult("IGlobalSession::createSession", globalSession->createSession(sessionDesc, session.writeRef())); // 3. Module Loading (from the supplied Slang source file) - Slang::ComPtr<slang::IModule> libraryModule; + ComPtr<IModule> libraryModule; { - Slang::ComPtr<slang::IBlob> diagnosticsBlob; + ComPtr<IBlob> diagnosticsBlob; libraryModule = session->loadModule(moduleName.c_str(), diagnosticsBlob.writeRef()); const std::string diagnosticsContext = "loadModule: " + moduleName; printDiagnostics(diagnosticsContext.c_str(), diagnosticsBlob); @@ -737,7 +777,7 @@ int main(int argc, char** argv) // 4. Discover top-level functions to treat as entry points std::vector<FunctionInfo> functions; std::unordered_set<std::string> seenNames; - slang::DeclReflection* moduleReflection = libraryModule->getModuleReflection(); + DeclReflection* moduleReflection = libraryModule->getModuleReflection(); if (!moduleReflection) { std::cerr << "Failed to retrieve reflection data for module '" << moduleName << "'." @@ -753,7 +793,7 @@ int main(int argc, char** argv) } // 5. Compile the translation unit with whole-program emission - Slang::ComPtr<slang::ICompileRequest> compileRequest; + ComPtr<ICompileRequest> compileRequest; checkResult( "ISession::createCompileRequest", session->createCompileRequest(compileRequest.writeRef())); @@ -789,11 +829,11 @@ int main(int argc, char** argv) } SlangResult compileResult = compileRequest->compile(); - Slang::ComPtr<slang::IBlob> compileDiagnostics; + ComPtr<IBlob> compileDiagnostics; compileRequest->getDiagnosticOutputBlob(compileDiagnostics.writeRef()); checkResult("ICompileRequest::compile", compileResult, compileDiagnostics); - Slang::ComPtr<slang::IBlob> targetCodeBlob; + ComPtr<IBlob> targetCodeBlob; checkResult( "ICompileRequest::getTargetCodeBlob", compileRequest->getTargetCodeBlob(0, targetCodeBlob.writeRef())); @@ -806,7 +846,11 @@ int main(int argc, char** argv) std::string hlslSource( static_cast<const char*>(targetCodeBlob->getBufferPointer()), - static_cast<size_t>(targetCodeBlob->getBufferSize())); + static_cast<std::size_t>(targetCodeBlob->getBufferSize())); + + fs::path rawOutputPath = outputPath; + rawOutputPath.replace_extension(".raw.hlsl"); + writeTextFile(rawOutputPath, hlslSource); std::string finalHlsl = rewriteHLSLWithWrappers(hlslSource, functions); @@ -846,19 +890,8 @@ int main(int argc, char** argv) } // 6. Write HLSL output to a sibling .hlsl file - std::ofstream outputFile(outputPath, std::ios::binary); - if (!outputFile) - { - std::cerr << "Failed to open output path for writing: " << outputPath << std::endl; - return 1; - } - - outputFile.write(finalHlsl.data(), static_cast<std::streamsize>(finalHlsl.size())); - outputFile.close(); - - if (!outputFile) + if (!writeTextFile(outputPath, finalHlsl)) { - std::cerr << "Failed to write HLSL output to " << outputPath << std::endl; return 1; } |
