diff options
| author | yum <yum.food.vr@gmail.com> | 2025-10-11 13:36:52 -0700 |
|---|---|---|
| committer | yum <yum.food.vr@gmail.com> | 2025-10-11 13:36:52 -0700 |
| commit | d1dc6831813fcfd1a5dde4c8ab155f2c72b78695 (patch) | |
| tree | 668109672f09bee2aedfaee0314493fb0c8200e2 | |
| parent | d1ecda540ddd6e9ab53f7981de65c3e435c1132c (diff) | |
clean up dead code
| -rw-r--r-- | main.cc | 416 |
1 files changed, 26 insertions, 390 deletions
@@ -7,7 +7,6 @@ #include <sstream> #include <string> #include <string_view> -#include <unordered_map> #include <unordered_set> #include <vector> @@ -28,8 +27,6 @@ using ::slang::IModule; using ::slang::ISession; using ::slang::SessionDesc; using ::slang::TargetDesc; -using ::slang::TypeReflection; -using ::slang::VariableReflection; template <typename T> using ComPtr = ::Slang::ComPtr<T>; @@ -114,17 +111,9 @@ void addCompilerOption(std::vector<CompilerOptionEntry>& options, CompilerOption options.push_back(entry); } -struct ParameterInfo -{ - std::string name; - std::string typeName; -}; - struct FunctionInfo { std::string name; - std::string returnType; - std::vector<ParameterInfo> parameters; }; struct IncludeGuardInfo @@ -154,126 +143,6 @@ std::string trim(std::string_view text) return std::string(text.substr(start, end - start)); } -std::string getTypeName(TypeReflection* type) -{ - if (!type) - { - return {}; - } - - using Kind = TypeReflection::Kind; - using Scalar = TypeReflection::ScalarType; - - auto scalarToString = [](Scalar scalar) -> std::string - { - switch (scalar) - { - case Scalar::Float16: - case Scalar::Float32: - case Scalar::Float64: - return "float"; - case Scalar::Int8: - case Scalar::Int16: - case Scalar::Int32: - case Scalar::Int64: - return "int"; - case Scalar::UInt8: - case Scalar::UInt16: - case Scalar::UInt32: - case Scalar::UInt64: - return "uint"; - case Scalar::Bool: - return "bool"; - default: - return {}; - } - }; - - switch (type->getKind()) - { - case Kind::Scalar: - if (const std::string base = scalarToString(type->getScalarType()); !base.empty()) - { - return base; - } - break; - case Kind::Vector: - { - const std::string elementTypeName = getTypeName(type->getElementType()); - const std::size_t elementCount = type->getElementCount(); - if (!elementTypeName.empty() && elementCount > 0) - { - return elementTypeName + std::to_string(elementCount); - } - break; - } - case Kind::Matrix: - { - const std::string elementTypeName = getTypeName(type->getElementType()); - const unsigned rows = type->getRowCount(); - const unsigned cols = type->getColumnCount(); - if (!elementTypeName.empty() && rows > 0 && cols > 0) - { - return elementTypeName + std::to_string(rows) + "x" + std::to_string(cols); - } - break; - } - default: - break; - } - - if (const char* simpleName = type->getName()) - { - if (simpleName[0] != '\0') - { - return simpleName; - } - } - - ComPtr<ISlangBlob> fullNameBlob; - if (SLANG_SUCCEEDED(type->getFullName(fullNameBlob.writeRef())) && fullNameBlob && - fullNameBlob->getBufferSize() > 0) - { - const char* buffer = static_cast<const char*>(fullNameBlob->getBufferPointer()); - std::string fullName(buffer, buffer + fullNameBlob->getBufferSize()); - - auto parseTemplateType = [&](std::string_view name) -> std::string - { - if (name.substr(0, 7) == "vector<") - { - std::size_t commaPos = name.find(','); - std::size_t endPos = name.rfind('>'); - if (commaPos != std::string_view::npos && endPos != std::string_view::npos && - commaPos + 1 < endPos) - { - std::string base = trim(name.substr(7, commaPos - 7)); - std::string count = trim(name.substr(commaPos + 1, endPos - commaPos - 1)); - return base + count; - } - } - if (name.substr(0, 7) == "matrix<") - { - std::size_t firstComma = name.find(','); - std::size_t secondComma = name.find(',', firstComma + 1); - std::size_t endPos = name.rfind('>'); - if (firstComma != std::string_view::npos && secondComma != std::string_view::npos && - endPos != std::string_view::npos) - { - std::string base = trim(name.substr(7, firstComma - 7)); - std::string rows = trim(name.substr(firstComma + 1, secondComma - firstComma - 1)); - std::string cols = trim(name.substr(secondComma + 1, endPos - secondComma - 1)); - return base + rows + "x" + cols; - } - } - return std::string(name); - }; - - return parseTemplateType(fullName); - } - - return {}; -} - bool isTopLevelFunction(DeclReflection* functionDecl) { if (!functionDecl) @@ -326,46 +195,8 @@ void collectFunctionInfos( if (*name && seenNames.insert(name).second && isTopLevelFunction(decl) && isPublic) { - FunctionInfo info; - info.name = name; - - if (TypeReflection* returnType = functionReflection->getReturnType()) - { - info.returnType = getTypeName(returnType); - } - if (info.returnType.empty()) - { - info.returnType = "void"; - } - - const unsigned paramCount = functionReflection->getParameterCount(); - info.parameters.reserve(paramCount); - for (unsigned i = 0; i < paramCount; ++i) - { - VariableReflection* paramReflection = - functionReflection->getParameterByIndex(i); - ParameterInfo paramInfo; - if (auto* typeReflection = paramReflection->getType()) - { - paramInfo.typeName = getTypeName(typeReflection); - } - if (paramInfo.typeName.empty()) - { - paramInfo.typeName = "auto"; - } - if (const char* paramName = paramReflection->getName()) - { - paramInfo.name = paramName; - } - if (paramInfo.name.empty()) - { - paramInfo.name = "param" + std::to_string(i); - } - info.parameters.push_back(std::move(paramInfo)); - } - - std::cerr << "Discovered entry point: " << info.name << std::endl; - functions.push_back(std::move(info)); + std::cerr << "Discovered entry point: " << name << std::endl; + functions.push_back({name}); } } } @@ -389,224 +220,6 @@ void collectFunctionInfos( } } -struct EntryPointField -{ - int bufferIndex = 0; - std::string fieldName; - std::string baseName; -}; - -std::string rewriteHLSLWithWrappers( - const std::string& originalHlsl, - const std::vector<FunctionInfo>& functions) -{ - std::string result = originalHlsl; - - std::unordered_map<int, std::vector<EntryPointField>> bufferIndexToFields; - std::unordered_map<std::string, std::vector<EntryPointField>> baseNameToFields; - - const std::string structPrefix = "struct EntryPointParams_"; - std::size_t searchPos = 0; - while (true) - { - std::size_t structPos = result.find(structPrefix, searchPos); - if (structPos == std::string::npos) - { - break; - } - - std::size_t indexPos = structPos + structPrefix.size(); - std::size_t indexEnd = indexPos; - while (indexEnd < result.size() && std::isdigit(static_cast<unsigned char>(result[indexEnd]))) - { - ++indexEnd; - } - if (indexEnd == indexPos) - { - searchPos = indexEnd; - continue; - } - - int bufferIndex = std::stoi(result.substr(indexPos, indexEnd - indexPos)); - - std::size_t braceOpen = result.find('{', indexEnd); - if (braceOpen == std::string::npos) - { - break; - } - std::size_t braceClose = result.find("};", braceOpen); - if (braceClose == std::string::npos) - { - break; - } - - std::size_t fieldPos = braceOpen + 1; - while (fieldPos < braceClose) - { - std::size_t semicolon = result.find(';', fieldPos); - if (semicolon == std::string::npos || semicolon > braceClose) - { - break; - } - - std::string line = trim(std::string_view(result).substr(fieldPos, semicolon - fieldPos)); - if (!line.empty()) - { - std::size_t lastSpace = line.find_last_of(" \t"); - if (lastSpace != std::string::npos && lastSpace + 1 < line.size()) - { - std::string fieldName = line.substr(lastSpace + 1); - std::size_t bracketPos = fieldName.find('['); - if (bracketPos != std::string::npos) - { - fieldName = fieldName.substr(0, bracketPos); - } - - std::string baseName = fieldName; - std::size_t underscorePos = baseName.rfind('_'); - if (underscorePos != std::string::npos) - { - baseName = baseName.substr(0, underscorePos); - } - - EntryPointField field{bufferIndex, fieldName, baseName}; - bufferIndexToFields[bufferIndex].push_back(field); - baseNameToFields[baseName].push_back(field); - } - } - - fieldPos = semicolon + 1; - } - - searchPos = braceClose; - } - - const std::string attributeToken = "[shader(\"dispatch\")]export"; - std::size_t searchFrom = 0; - for (const FunctionInfo& func : functions) - { - std::size_t attrPos = result.find(attributeToken, searchFrom); - if (attrPos != std::string::npos) - { - std::size_t attrEnd = attrPos + attributeToken.size(); - if (attrEnd < result.size() && result[attrEnd] == '\r') - { - ++attrEnd; - } - if (attrEnd < result.size() && result[attrEnd] == '\n') - { - ++attrEnd; - } - result.erase(attrPos, attrEnd - attrPos); - searchFrom = attrPos; - } - - std::size_t namePos = result.find(func.name + "(", searchFrom); - if (namePos == std::string::npos) - { - namePos = result.find(func.name + "("); - } - if (namePos != std::string::npos) - { - const std::string entryName = "__slang_entry_" + func.name; - result.replace(namePos, func.name.size(), entryName); - searchFrom = namePos + entryName.size(); - } - } - - std::ostringstream wrapperBuilder; - wrapperBuilder << "\n"; - - for (std::size_t functionIndex = 0; functionIndex < functions.size(); ++functionIndex) - { - const FunctionInfo& func = functions[functionIndex]; - const std::string entryName = "__slang_entry_" + func.name; - - std::string parameterList; - for (std::size_t i = 0; i < func.parameters.size(); ++i) - { - if (i > 0) - { - parameterList += ", "; - } - parameterList += func.parameters[i].typeName + " " + func.parameters[i].name; - } - - wrapperBuilder << func.returnType << " " << func.name << "(" << parameterList << ")\n"; - wrapperBuilder << "{\n"; - - std::unordered_set<std::string> emittedAssignments; - - for (std::size_t paramIndex = 0; paramIndex < func.parameters.size(); ++paramIndex) - { - const ParameterInfo& param = func.parameters[paramIndex]; - - auto baseIt = baseNameToFields.find(param.name); - if (baseIt != baseNameToFields.end()) - { - for (const EntryPointField& field : baseIt->second) - { - std::string assignment = - " entryPointParams_" + std::to_string(field.bufferIndex) + "." + - field.fieldName + " = " + param.name + ";\n"; - if (emittedAssignments.insert(assignment).second) - { - wrapperBuilder << assignment; - } - } - } - - auto bufferIt = bufferIndexToFields.find(static_cast<int>(functionIndex)); - if (bufferIt != bufferIndexToFields.end() && paramIndex < bufferIt->second.size()) - { - const EntryPointField& field = bufferIt->second[paramIndex]; - std::string assignment = - " entryPointParams_" + std::to_string(field.bufferIndex) + "." + - field.fieldName + " = " + param.name + ";\n"; - if (emittedAssignments.insert(assignment).second) - { - wrapperBuilder << assignment; - } - } - } - - if (func.returnType == "void") - { - wrapperBuilder << " " << entryName << "();\n"; - wrapperBuilder << " return;\n"; - } - else - { - wrapperBuilder << " return " << entryName << "();\n"; - } - - wrapperBuilder << "}\n\n"; - } - - const std::string wrappers = wrapperBuilder.str(); - if (!wrappers.empty()) - { - const std::size_t endifPos = result.rfind("#endif"); - const std::size_t ifndefPos = result.find("#ifndef"); - const std::size_t definePos = result.find("#define", ifndefPos != std::string::npos ? ifndefPos : 0); - - const bool hasGuard = - ifndefPos != std::string::npos && definePos != std::string::npos && - endifPos != std::string::npos && ifndefPos < definePos && definePos < endifPos; - - if (hasGuard) - { - result.insert(endifPos, wrappers); - } - else - { - result += wrappers; - } - } - - return result; -} - IncludeGuardInfo detectIncludeGuard(const fs::path& sourcePath) { IncludeGuardInfo info; @@ -731,6 +344,7 @@ int main(int argc, char** argv) addCompilerOption(options, CompilerOptionName::NoHLSLBinding); addCompilerOption(options, CompilerOptionName::NoMangle); addCompilerOption(options, CompilerOptionName::NoHLSLPackConstantBufferElements); + addCompilerOption(options, CompilerOptionName::NoEntryPointUniformParamTransform); std::vector<CompilerOptionEntry> targetOptions(options); targetDesc.compilerOptionEntries = targetOptions.data(); @@ -838,7 +452,29 @@ int main(int argc, char** argv) rawOutputPath.replace_extension(".raw.hlsl"); writeTextFile(rawOutputPath, hlslSource); - std::string finalHlsl = rewriteHLSLWithWrappers(hlslSource, functions); + // Remove [shader("dispatch")]export lines + std::string finalHlsl = hlslSource; + const std::string shaderExportToken = "[shader(\"dispatch\")]export"; + std::size_t pos = 0; + while ((pos = finalHlsl.find(shaderExportToken, pos)) != std::string::npos) + { + std::size_t lineEnd = pos + shaderExportToken.size(); + // Skip optional whitespace and newline + while (lineEnd < finalHlsl.size() && + (finalHlsl[lineEnd] == ' ' || finalHlsl[lineEnd] == '\t')) + { + ++lineEnd; + } + if (lineEnd < finalHlsl.size() && finalHlsl[lineEnd] == '\r') + { + ++lineEnd; + } + if (lineEnd < finalHlsl.size() && finalHlsl[lineEnd] == '\n') + { + ++lineEnd; + } + finalHlsl.erase(pos, lineEnd - pos); + } IncludeGuardInfo includeGuard = detectIncludeGuard(modulePath); if (includeGuard.present) |
