summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authoryum <yum.food.vr@gmail.com>2025-10-11 13:04:57 -0700
committeryum <yum.food.vr@gmail.com>2025-10-11 13:04:57 -0700
commitd1ecda540ddd6e9ab53f7981de65c3e435c1132c (patch)
tree6907da6e0d4df248ce19ea825eccdf261f28ae9c
parent06780e36b2aeb1257607c89570fd99903508b82e (diff)
more stuff
-rw-r--r--README.md12
-rw-r--r--main.cc98
-rw-r--r--test.slang26
3 files changed, 51 insertions, 85 deletions
diff --git a/README.md b/README.md
index 00ff565..174ccd2 100644
--- a/README.md
+++ b/README.md
@@ -3,6 +3,17 @@
the idea here is just to make something like slangc.exe, but which doesn't
require an entrypoint.
+basics and gotchas
+
+- all functions marked `public` are exported as entrypoints.
+ - this causes slang to put their arguments in a cbuffer. we post-process them
+ to fix this.
+- any autodiff function calls within a module must *not* be to `public` APIs.
+ - this is a consequence of the prior point. any public API is seen by slang
+ as an "entrypoint" which accepts a cbuffer argument. the autodiff logic
+ does not handle this edge case, so it tries to pass arguments to a function
+ which takes no arguments.
+
build instructions
```bash
@@ -22,3 +33,4 @@ cd ../..
powershell.exe ./build.ps1 && ./build/bin/Release/modular_slange.exe ./test.slang
```
+
diff --git a/main.cc b/main.cc
index 7c67428..2329ce8 100644
--- a/main.cc
+++ b/main.cc
@@ -106,6 +106,14 @@ bool writeTextFile(const fs::path& path, std::string_view contents)
return true;
}
+void addCompilerOption(std::vector<CompilerOptionEntry>& options, CompilerOptionName name)
+{
+ CompilerOptionEntry entry = {};
+ entry.name = name;
+ entry.value.intValue0 = 1;
+ options.push_back(entry);
+}
+
struct ParameterInfo
{
std::string name;
@@ -312,7 +320,11 @@ void collectFunctionInfos(
{
if (const char* name = functionReflection->getName())
{
- if (*name && seenNames.insert(name).second && isTopLevelFunction(decl))
+ // Heuristic: functions that don't start with underscore are considered public
+ // (Slang convention: private/internal functions typically start with _)
+ bool isPublic = name[0] != '_';
+
+ if (*name && seenNames.insert(name).second && isTopLevelFunction(decl) && isPublic)
{
FunctionInfo info;
info.name = name;
@@ -352,6 +364,7 @@ void collectFunctionInfos(
info.parameters.push_back(std::move(paramInfo));
}
+ std::cerr << "Discovered entry point: " << info.name << std::endl;
functions.push_back(std::move(info));
}
}
@@ -681,24 +694,20 @@ IncludeGuardInfo detectIncludeGuard(const fs::path& sourcePath)
int main(int argc, char** argv)
{
- if (argc < 2)
- {
+ if (argc < 2) {
std::cerr << "Usage: " << (argc > 0 ? argv[0] : "modular_slang") << " <module.slang>" << std::endl;
return 1;
}
-
fs::path modulePath = fs::absolute(argv[1]);
- if (!fs::exists(modulePath))
- {
+ if (!fs::exists(modulePath)) {
std::cerr << "Module not found: " << modulePath << std::endl;
return 1;
}
-
- if (modulePath.extension() != ".slang")
- {
+ if (modulePath.extension() != ".slang") {
std::cerr << "Expected a .slang file: " << modulePath << std::endl;
return 1;
}
+
std::string moduleName = modulePath.stem().string();
std::string searchPath = modulePath.has_parent_path()
? modulePath.parent_path().string()
@@ -707,31 +716,23 @@ int main(int argc, char** argv)
fs::path outputPath = modulePath;
outputPath.replace_extension(".hlsl");
- IncludeGuardInfo includeGuard = detectIncludeGuard(modulePath);
-
- // 1. Session Creation
+ // Create session
ComPtr<IGlobalSession> globalSession;
checkResult("createGlobalSession", createGlobalSession(globalSession.writeRef()));
- // 2. Target Configuration
+ // Configure session and target
TargetDesc targetDesc = {};
targetDesc.format = SLANG_HLSL;
targetDesc.profile = globalSession->findProfile("lib_6_6");
targetDesc.flags = SLANG_TARGET_FLAG_GENERATE_WHOLE_PROGRAM;
- std::vector<CompilerOptionEntry> targetOptions;
- {
- CompilerOptionEntry entry = {};
- entry.name = CompilerOptionName::NoHLSLBinding;
- entry.value.intValue0 = 1;
- targetOptions.push_back(entry);
- }
- {
- CompilerOptionEntry entry = {};
- entry.name = CompilerOptionName::NoHLSLPackConstantBufferElements;
- entry.value.intValue0 = 1;
- targetOptions.push_back(entry);
- }
+ std::vector<CompilerOptionEntry> options;
+ addCompilerOption(options, CompilerOptionName::DisableNonEssentialValidations);
+ addCompilerOption(options, CompilerOptionName::NoHLSLBinding);
+ addCompilerOption(options, CompilerOptionName::NoMangle);
+ addCompilerOption(options, CompilerOptionName::NoHLSLPackConstantBufferElements);
+
+ std::vector<CompilerOptionEntry> targetOptions(options);
targetDesc.compilerOptionEntries = targetOptions.data();
targetDesc.compilerOptionEntryCount = static_cast<uint32_t>(targetOptions.size());
@@ -742,26 +743,15 @@ int main(int argc, char** argv)
sessionDesc.searchPaths = searchPaths;
sessionDesc.searchPathCount = 1;
- std::vector<CompilerOptionEntry> sessionOptions;
- {
- CompilerOptionEntry entry = {};
- entry.name = CompilerOptionName::NoMangle;
- entry.value.intValue0 = 1;
- sessionOptions.push_back(entry);
- }
- {
- CompilerOptionEntry entry = {};
- entry.name = CompilerOptionName::DisableNonEssentialValidations;
- entry.value.intValue0 = 1;
- sessionOptions.push_back(entry);
- }
+ std::vector<CompilerOptionEntry> sessionOptions(options);
sessionDesc.compilerOptionEntries = sessionOptions.data();
sessionDesc.compilerOptionEntryCount = static_cast<uint32_t>(sessionOptions.size());
ComPtr<ISession> session;
- checkResult("IGlobalSession::createSession", globalSession->createSession(sessionDesc, session.writeRef()));
+ checkResult("IGlobalSession::createSession",
+ globalSession->createSession(sessionDesc, session.writeRef()));
- // 3. Module Loading (from the supplied Slang source file)
+ // Load the "module" aka the library
ComPtr<IModule> libraryModule;
{
ComPtr<IBlob> diagnosticsBlob;
@@ -774,25 +764,23 @@ int main(int argc, char** argv)
return 1;
}
}
- // 4. Discover top-level functions to treat as entry points
+
+ // Discover top-level functions to treat as entry points
std::vector<FunctionInfo> functions;
std::unordered_set<std::string> seenNames;
DeclReflection* moduleReflection = libraryModule->getModuleReflection();
- if (!moduleReflection)
- {
+ if (!moduleReflection) {
std::cerr << "Failed to retrieve reflection data for module '" << moduleName << "'."
<< std::endl;
return 1;
}
collectFunctionInfos(moduleReflection, functions, seenNames);
-
- if (functions.empty())
- {
+ if (functions.empty()) {
std::cerr << "No functions found in module '" << moduleName << "'." << std::endl;
return 1;
}
- // 5. Compile the translation unit with whole-program emission
+ // Compile
ComPtr<ICompileRequest> compileRequest;
checkResult(
"ISession::createCompileRequest",
@@ -801,10 +789,8 @@ int main(int argc, char** argv)
compileRequest->setCodeGenTarget(SLANG_HLSL);
compileRequest->setTargetProfile(0, targetDesc.profile);
compileRequest->setTargetFlags(0, SLANG_TARGET_FLAG_GENERATE_WHOLE_PROGRAM);
-
- SlangCompileFlags compileFlags = compileRequest->getCompileFlags();
- compileFlags |= SLANG_COMPILE_FLAG_NO_MANGLING;
- compileRequest->setCompileFlags(compileFlags);
+ compileRequest->setMatrixLayoutMode(SLANG_MATRIX_LAYOUT_ROW_MAJOR);
+ compileRequest->setLineDirectiveMode(SLANG_LINE_DIRECTIVE_MODE_NONE);
compileRequest->addSearchPath(searchPath.c_str());
@@ -854,6 +840,7 @@ int main(int argc, char** argv)
std::string finalHlsl = rewriteHLSLWithWrappers(hlslSource, functions);
+ IncludeGuardInfo includeGuard = detectIncludeGuard(modulePath);
if (includeGuard.present)
{
const std::string guardIfndefToken = "#ifndef " + includeGuard.macro;
@@ -895,13 +882,6 @@ int main(int argc, char** argv)
return 1;
}
- // Also stream to standard output to retain previous behavior
- std::cout.write(finalHlsl.data(), static_cast<std::streamsize>(finalHlsl.size()));
- if (finalHlsl.empty() || finalHlsl.back() != '\n')
- {
- std::cout << std::endl;
- }
-
std::cerr << "Generated HLSL written to " << outputPath << std::endl;
return 0;
diff --git a/test.slang b/test.slang
deleted file mode 100644
index ee971ce..0000000
--- a/test.slang
+++ /dev/null
@@ -1,26 +0,0 @@
-#ifndef __CUSTOM31_INC
-#define __CUSTOM31_INC
-
-[Differentiable]
-public float3 c31_deform(uniform float3 xyz) {
- return float3(
- sin(xyz.x) * sin(xyz.z),
- xyz.y,
- sin(xyz.x) * sin(xyz.z)
- );
-}
-
-public float3x3 c31_deform_jacobian(uniform float3 xyz, uniform float3 n) {
- DifferentialPair<float3> dp_x = diffPair(xyz, float3(1, 0, 0));
- DifferentialPair<float3> dp_y = diffPair(xyz, float3(0, 1, 0));
- DifferentialPair<float3> dp_z = diffPair(xyz, float3(0, 0, 1));
-
- DifferentialPair<float3> dp_x_out = fwd_diff(c31_deform)(dp_x);
- DifferentialPair<float3> dp_y_out = fwd_diff(c31_deform)(dp_y);
- DifferentialPair<float3> dp_z_out = fwd_diff(c31_deform)(dp_z);
-
- return float3x3(dp_x_out.d, dp_y_out.d, dp_z_out.d);
-}
-
-#endif // __CUSTOM31_INC
-