summaryrefslogtreecommitdiffstats
path: root/main.cc
diff options
context:
space:
mode:
Diffstat (limited to 'main.cc')
-rw-r--r--main.cc134
1 files changed, 133 insertions, 1 deletions
diff --git a/main.cc b/main.cc
index 2521e07..3163fee 100644
--- a/main.cc
+++ b/main.cc
@@ -3,6 +3,7 @@
#include <fstream>
#include <iomanip>
#include <iostream>
+#include <limits>
#include <sstream>
#include <string>
#include <string_view>
@@ -77,6 +78,15 @@ struct FunctionInfo
std::vector<ParameterInfo> parameters;
};
+struct IncludeGuardInfo
+{
+ bool present = false;
+ std::string macro;
+ std::string ifndefLine;
+ std::string defineLine;
+ std::string endifLine;
+};
+
std::string trim(std::string_view text)
{
size_t start = 0;
@@ -543,6 +553,91 @@ std::string rewriteHLSLWithWrappers(
return result;
}
+IncludeGuardInfo detectIncludeGuard(const std::filesystem::path& sourcePath)
+{
+ IncludeGuardInfo info;
+
+ std::ifstream input(sourcePath);
+ if (!input)
+ {
+ return info;
+ }
+
+ std::vector<std::string> lines;
+ std::string line;
+ while (std::getline(input, line))
+ {
+ lines.push_back(line);
+ }
+
+ size_t ifndefIndex = std::numeric_limits<size_t>::max();
+ for (size_t i = 0; i < lines.size(); ++i)
+ {
+ std::string trimmed = trim(lines[i]);
+ if (trimmed.rfind("#ifndef", 0) == 0)
+ {
+ std::istringstream stream(trimmed);
+ std::string directive;
+ std::string macro;
+ stream >> directive >> macro;
+ if (!macro.empty())
+ {
+ info.macro = macro;
+ info.ifndefLine = lines[i];
+ ifndefIndex = i;
+ }
+ break;
+ }
+ }
+
+ if (info.macro.empty())
+ {
+ return info;
+ }
+
+ for (size_t i = ifndefIndex + 1; i < lines.size(); ++i)
+ {
+ std::string trimmed = trim(lines[i]);
+ if (trimmed.rfind("#define", 0) == 0)
+ {
+ std::istringstream stream(trimmed);
+ std::string directive;
+ std::string macro;
+ stream >> directive >> macro;
+ if (macro == info.macro)
+ {
+ info.defineLine = lines[i];
+ break;
+ }
+ }
+ }
+
+ if (info.defineLine.empty())
+ {
+ info = IncludeGuardInfo{};
+ return info;
+ }
+
+ for (size_t i = lines.size(); i-- > 0;)
+ {
+ std::string trimmed = trim(lines[i]);
+ if (trimmed.rfind("#endif", 0) == 0)
+ {
+ info.endifLine = lines[i];
+ break;
+ }
+ }
+
+ if (info.endifLine.empty())
+ {
+ info = IncludeGuardInfo{};
+ return info;
+ }
+
+ info.present = true;
+ return info;
+}
+
int main(int argc, char** argv)
{
if (argc < 2)
@@ -572,6 +667,8 @@ int main(int argc, char** argv)
std::filesystem::path outputPath = modulePath;
outputPath.replace_extension(".hlsl");
+ IncludeGuardInfo includeGuard = detectIncludeGuard(modulePath);
+
// 1. Session Creation
Slang::ComPtr<slang::IGlobalSession> globalSession;
checkResult("slang::createGlobalSession", slang::createGlobalSession(globalSession.writeRef()));
@@ -711,7 +808,42 @@ int main(int argc, char** argv)
static_cast<const char*>(targetCodeBlob->getBufferPointer()),
static_cast<size_t>(targetCodeBlob->getBufferSize()));
- const std::string finalHlsl = rewriteHLSLWithWrappers(hlslSource, functions);
+ std::string finalHlsl = rewriteHLSLWithWrappers(hlslSource, functions);
+
+ if (includeGuard.present)
+ {
+ const std::string guardIfndefToken = "#ifndef " + includeGuard.macro;
+ const std::string guardDefineToken = "#define " + includeGuard.macro;
+ const bool alreadyGuarded =
+ finalHlsl.find(guardIfndefToken) != std::string::npos &&
+ finalHlsl.find(guardDefineToken) != std::string::npos;
+
+ if (!alreadyGuarded)
+ {
+ std::string body = finalHlsl;
+ if (!body.empty() && body.back() != '\n')
+ {
+ body += '\n';
+ }
+
+ std::ostringstream wrapped;
+ wrapped << includeGuard.ifndefLine << '\n';
+ wrapped << includeGuard.defineLine << '\n';
+ wrapped << '\n';
+ wrapped << body;
+ if (!body.empty() && body.back() != '\n')
+ {
+ wrapped << '\n';
+ }
+ wrapped << includeGuard.endifLine;
+ if (!includeGuard.endifLine.empty() && includeGuard.endifLine.back() != '\n')
+ {
+ wrapped << '\n';
+ }
+
+ finalHlsl = wrapped.str();
+ }
+ }
// 6. Write HLSL output to a sibling .hlsl file
std::ofstream outputFile(outputPath, std::ios::binary);