summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--main.cc416
1 files changed, 26 insertions, 390 deletions
diff --git a/main.cc b/main.cc
index 2329ce8..aa09d31 100644
--- a/main.cc
+++ b/main.cc
@@ -7,7 +7,6 @@
#include <sstream>
#include <string>
#include <string_view>
-#include <unordered_map>
#include <unordered_set>
#include <vector>
@@ -28,8 +27,6 @@ using ::slang::IModule;
using ::slang::ISession;
using ::slang::SessionDesc;
using ::slang::TargetDesc;
-using ::slang::TypeReflection;
-using ::slang::VariableReflection;
template <typename T>
using ComPtr = ::Slang::ComPtr<T>;
@@ -114,17 +111,9 @@ void addCompilerOption(std::vector<CompilerOptionEntry>& options, CompilerOption
options.push_back(entry);
}
-struct ParameterInfo
-{
- std::string name;
- std::string typeName;
-};
-
struct FunctionInfo
{
std::string name;
- std::string returnType;
- std::vector<ParameterInfo> parameters;
};
struct IncludeGuardInfo
@@ -154,126 +143,6 @@ std::string trim(std::string_view text)
return std::string(text.substr(start, end - start));
}
-std::string getTypeName(TypeReflection* type)
-{
- if (!type)
- {
- return {};
- }
-
- using Kind = TypeReflection::Kind;
- using Scalar = TypeReflection::ScalarType;
-
- auto scalarToString = [](Scalar scalar) -> std::string
- {
- switch (scalar)
- {
- case Scalar::Float16:
- case Scalar::Float32:
- case Scalar::Float64:
- return "float";
- case Scalar::Int8:
- case Scalar::Int16:
- case Scalar::Int32:
- case Scalar::Int64:
- return "int";
- case Scalar::UInt8:
- case Scalar::UInt16:
- case Scalar::UInt32:
- case Scalar::UInt64:
- return "uint";
- case Scalar::Bool:
- return "bool";
- default:
- return {};
- }
- };
-
- switch (type->getKind())
- {
- case Kind::Scalar:
- if (const std::string base = scalarToString(type->getScalarType()); !base.empty())
- {
- return base;
- }
- break;
- case Kind::Vector:
- {
- const std::string elementTypeName = getTypeName(type->getElementType());
- const std::size_t elementCount = type->getElementCount();
- if (!elementTypeName.empty() && elementCount > 0)
- {
- return elementTypeName + std::to_string(elementCount);
- }
- break;
- }
- case Kind::Matrix:
- {
- const std::string elementTypeName = getTypeName(type->getElementType());
- const unsigned rows = type->getRowCount();
- const unsigned cols = type->getColumnCount();
- if (!elementTypeName.empty() && rows > 0 && cols > 0)
- {
- return elementTypeName + std::to_string(rows) + "x" + std::to_string(cols);
- }
- break;
- }
- default:
- break;
- }
-
- if (const char* simpleName = type->getName())
- {
- if (simpleName[0] != '\0')
- {
- return simpleName;
- }
- }
-
- ComPtr<ISlangBlob> fullNameBlob;
- if (SLANG_SUCCEEDED(type->getFullName(fullNameBlob.writeRef())) && fullNameBlob &&
- fullNameBlob->getBufferSize() > 0)
- {
- const char* buffer = static_cast<const char*>(fullNameBlob->getBufferPointer());
- std::string fullName(buffer, buffer + fullNameBlob->getBufferSize());
-
- auto parseTemplateType = [&](std::string_view name) -> std::string
- {
- if (name.substr(0, 7) == "vector<")
- {
- std::size_t commaPos = name.find(',');
- std::size_t endPos = name.rfind('>');
- if (commaPos != std::string_view::npos && endPos != std::string_view::npos &&
- commaPos + 1 < endPos)
- {
- std::string base = trim(name.substr(7, commaPos - 7));
- std::string count = trim(name.substr(commaPos + 1, endPos - commaPos - 1));
- return base + count;
- }
- }
- if (name.substr(0, 7) == "matrix<")
- {
- std::size_t firstComma = name.find(',');
- std::size_t secondComma = name.find(',', firstComma + 1);
- std::size_t endPos = name.rfind('>');
- if (firstComma != std::string_view::npos && secondComma != std::string_view::npos &&
- endPos != std::string_view::npos)
- {
- std::string base = trim(name.substr(7, firstComma - 7));
- std::string rows = trim(name.substr(firstComma + 1, secondComma - firstComma - 1));
- std::string cols = trim(name.substr(secondComma + 1, endPos - secondComma - 1));
- return base + rows + "x" + cols;
- }
- }
- return std::string(name);
- };
-
- return parseTemplateType(fullName);
- }
-
- return {};
-}
-
bool isTopLevelFunction(DeclReflection* functionDecl)
{
if (!functionDecl)
@@ -326,46 +195,8 @@ void collectFunctionInfos(
if (*name && seenNames.insert(name).second && isTopLevelFunction(decl) && isPublic)
{
- FunctionInfo info;
- info.name = name;
-
- if (TypeReflection* returnType = functionReflection->getReturnType())
- {
- info.returnType = getTypeName(returnType);
- }
- if (info.returnType.empty())
- {
- info.returnType = "void";
- }
-
- const unsigned paramCount = functionReflection->getParameterCount();
- info.parameters.reserve(paramCount);
- for (unsigned i = 0; i < paramCount; ++i)
- {
- VariableReflection* paramReflection =
- functionReflection->getParameterByIndex(i);
- ParameterInfo paramInfo;
- if (auto* typeReflection = paramReflection->getType())
- {
- paramInfo.typeName = getTypeName(typeReflection);
- }
- if (paramInfo.typeName.empty())
- {
- paramInfo.typeName = "auto";
- }
- if (const char* paramName = paramReflection->getName())
- {
- paramInfo.name = paramName;
- }
- if (paramInfo.name.empty())
- {
- paramInfo.name = "param" + std::to_string(i);
- }
- info.parameters.push_back(std::move(paramInfo));
- }
-
- std::cerr << "Discovered entry point: " << info.name << std::endl;
- functions.push_back(std::move(info));
+ std::cerr << "Discovered entry point: " << name << std::endl;
+ functions.push_back({name});
}
}
}
@@ -389,224 +220,6 @@ void collectFunctionInfos(
}
}
-struct EntryPointField
-{
- int bufferIndex = 0;
- std::string fieldName;
- std::string baseName;
-};
-
-std::string rewriteHLSLWithWrappers(
- const std::string& originalHlsl,
- const std::vector<FunctionInfo>& functions)
-{
- std::string result = originalHlsl;
-
- std::unordered_map<int, std::vector<EntryPointField>> bufferIndexToFields;
- std::unordered_map<std::string, std::vector<EntryPointField>> baseNameToFields;
-
- const std::string structPrefix = "struct EntryPointParams_";
- std::size_t searchPos = 0;
- while (true)
- {
- std::size_t structPos = result.find(structPrefix, searchPos);
- if (structPos == std::string::npos)
- {
- break;
- }
-
- std::size_t indexPos = structPos + structPrefix.size();
- std::size_t indexEnd = indexPos;
- while (indexEnd < result.size() && std::isdigit(static_cast<unsigned char>(result[indexEnd])))
- {
- ++indexEnd;
- }
- if (indexEnd == indexPos)
- {
- searchPos = indexEnd;
- continue;
- }
-
- int bufferIndex = std::stoi(result.substr(indexPos, indexEnd - indexPos));
-
- std::size_t braceOpen = result.find('{', indexEnd);
- if (braceOpen == std::string::npos)
- {
- break;
- }
- std::size_t braceClose = result.find("};", braceOpen);
- if (braceClose == std::string::npos)
- {
- break;
- }
-
- std::size_t fieldPos = braceOpen + 1;
- while (fieldPos < braceClose)
- {
- std::size_t semicolon = result.find(';', fieldPos);
- if (semicolon == std::string::npos || semicolon > braceClose)
- {
- break;
- }
-
- std::string line = trim(std::string_view(result).substr(fieldPos, semicolon - fieldPos));
- if (!line.empty())
- {
- std::size_t lastSpace = line.find_last_of(" \t");
- if (lastSpace != std::string::npos && lastSpace + 1 < line.size())
- {
- std::string fieldName = line.substr(lastSpace + 1);
- std::size_t bracketPos = fieldName.find('[');
- if (bracketPos != std::string::npos)
- {
- fieldName = fieldName.substr(0, bracketPos);
- }
-
- std::string baseName = fieldName;
- std::size_t underscorePos = baseName.rfind('_');
- if (underscorePos != std::string::npos)
- {
- baseName = baseName.substr(0, underscorePos);
- }
-
- EntryPointField field{bufferIndex, fieldName, baseName};
- bufferIndexToFields[bufferIndex].push_back(field);
- baseNameToFields[baseName].push_back(field);
- }
- }
-
- fieldPos = semicolon + 1;
- }
-
- searchPos = braceClose;
- }
-
- const std::string attributeToken = "[shader(\"dispatch\")]export";
- std::size_t searchFrom = 0;
- for (const FunctionInfo& func : functions)
- {
- std::size_t attrPos = result.find(attributeToken, searchFrom);
- if (attrPos != std::string::npos)
- {
- std::size_t attrEnd = attrPos + attributeToken.size();
- if (attrEnd < result.size() && result[attrEnd] == '\r')
- {
- ++attrEnd;
- }
- if (attrEnd < result.size() && result[attrEnd] == '\n')
- {
- ++attrEnd;
- }
- result.erase(attrPos, attrEnd - attrPos);
- searchFrom = attrPos;
- }
-
- std::size_t namePos = result.find(func.name + "(", searchFrom);
- if (namePos == std::string::npos)
- {
- namePos = result.find(func.name + "(");
- }
- if (namePos != std::string::npos)
- {
- const std::string entryName = "__slang_entry_" + func.name;
- result.replace(namePos, func.name.size(), entryName);
- searchFrom = namePos + entryName.size();
- }
- }
-
- std::ostringstream wrapperBuilder;
- wrapperBuilder << "\n";
-
- for (std::size_t functionIndex = 0; functionIndex < functions.size(); ++functionIndex)
- {
- const FunctionInfo& func = functions[functionIndex];
- const std::string entryName = "__slang_entry_" + func.name;
-
- std::string parameterList;
- for (std::size_t i = 0; i < func.parameters.size(); ++i)
- {
- if (i > 0)
- {
- parameterList += ", ";
- }
- parameterList += func.parameters[i].typeName + " " + func.parameters[i].name;
- }
-
- wrapperBuilder << func.returnType << " " << func.name << "(" << parameterList << ")\n";
- wrapperBuilder << "{\n";
-
- std::unordered_set<std::string> emittedAssignments;
-
- for (std::size_t paramIndex = 0; paramIndex < func.parameters.size(); ++paramIndex)
- {
- const ParameterInfo& param = func.parameters[paramIndex];
-
- auto baseIt = baseNameToFields.find(param.name);
- if (baseIt != baseNameToFields.end())
- {
- for (const EntryPointField& field : baseIt->second)
- {
- std::string assignment =
- " entryPointParams_" + std::to_string(field.bufferIndex) + "." +
- field.fieldName + " = " + param.name + ";\n";
- if (emittedAssignments.insert(assignment).second)
- {
- wrapperBuilder << assignment;
- }
- }
- }
-
- auto bufferIt = bufferIndexToFields.find(static_cast<int>(functionIndex));
- if (bufferIt != bufferIndexToFields.end() && paramIndex < bufferIt->second.size())
- {
- const EntryPointField& field = bufferIt->second[paramIndex];
- std::string assignment =
- " entryPointParams_" + std::to_string(field.bufferIndex) + "." +
- field.fieldName + " = " + param.name + ";\n";
- if (emittedAssignments.insert(assignment).second)
- {
- wrapperBuilder << assignment;
- }
- }
- }
-
- if (func.returnType == "void")
- {
- wrapperBuilder << " " << entryName << "();\n";
- wrapperBuilder << " return;\n";
- }
- else
- {
- wrapperBuilder << " return " << entryName << "();\n";
- }
-
- wrapperBuilder << "}\n\n";
- }
-
- const std::string wrappers = wrapperBuilder.str();
- if (!wrappers.empty())
- {
- const std::size_t endifPos = result.rfind("#endif");
- const std::size_t ifndefPos = result.find("#ifndef");
- const std::size_t definePos = result.find("#define", ifndefPos != std::string::npos ? ifndefPos : 0);
-
- const bool hasGuard =
- ifndefPos != std::string::npos && definePos != std::string::npos &&
- endifPos != std::string::npos && ifndefPos < definePos && definePos < endifPos;
-
- if (hasGuard)
- {
- result.insert(endifPos, wrappers);
- }
- else
- {
- result += wrappers;
- }
- }
-
- return result;
-}
-
IncludeGuardInfo detectIncludeGuard(const fs::path& sourcePath)
{
IncludeGuardInfo info;
@@ -731,6 +344,7 @@ int main(int argc, char** argv)
addCompilerOption(options, CompilerOptionName::NoHLSLBinding);
addCompilerOption(options, CompilerOptionName::NoMangle);
addCompilerOption(options, CompilerOptionName::NoHLSLPackConstantBufferElements);
+ addCompilerOption(options, CompilerOptionName::NoEntryPointUniformParamTransform);
std::vector<CompilerOptionEntry> targetOptions(options);
targetDesc.compilerOptionEntries = targetOptions.data();
@@ -838,7 +452,29 @@ int main(int argc, char** argv)
rawOutputPath.replace_extension(".raw.hlsl");
writeTextFile(rawOutputPath, hlslSource);
- std::string finalHlsl = rewriteHLSLWithWrappers(hlslSource, functions);
+ // Remove [shader("dispatch")]export lines
+ std::string finalHlsl = hlslSource;
+ const std::string shaderExportToken = "[shader(\"dispatch\")]export";
+ std::size_t pos = 0;
+ while ((pos = finalHlsl.find(shaderExportToken, pos)) != std::string::npos)
+ {
+ std::size_t lineEnd = pos + shaderExportToken.size();
+ // Skip optional whitespace and newline
+ while (lineEnd < finalHlsl.size() &&
+ (finalHlsl[lineEnd] == ' ' || finalHlsl[lineEnd] == '\t'))
+ {
+ ++lineEnd;
+ }
+ if (lineEnd < finalHlsl.size() && finalHlsl[lineEnd] == '\r')
+ {
+ ++lineEnd;
+ }
+ if (lineEnd < finalHlsl.size() && finalHlsl[lineEnd] == '\n')
+ {
+ ++lineEnd;
+ }
+ finalHlsl.erase(pos, lineEnd - pos);
+ }
IncludeGuardInfo includeGuard = detectIncludeGuard(modulePath);
if (includeGuard.present)