From 4cbf98dd43ebb4795fe36faa68b858602a4044f3 Mon Sep 17 00:00:00 2001 From: yum Date: Fri, 10 Oct 2025 22:53:26 -0700 Subject: initial commit --- .gitignore | 3 + CMakeLists.txt | 45 ++++ README.md | 24 ++ build.ps1 | 10 + main.cc | 743 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++ test.slang | 24 ++ 6 files changed, 849 insertions(+) create mode 100644 .gitignore create mode 100644 CMakeLists.txt create mode 100644 README.md create mode 100644 build.ps1 create mode 100644 main.cc create mode 100644 test.slang diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..85b091c --- /dev/null +++ b/.gitignore @@ -0,0 +1,3 @@ +*.sw[po] +build + diff --git a/CMakeLists.txt b/CMakeLists.txt new file mode 100644 index 0000000..25032e6 --- /dev/null +++ b/CMakeLists.txt @@ -0,0 +1,45 @@ +cmake_minimum_required(VERSION 3.20) +project(modular_slang_demo LANGUAGES CXX) + +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD_REQUIRED ON) +set(CMAKE_CXX_EXTENSIONS OFF) + +set(SLANG_SDK_ROOT "${CMAKE_CURRENT_SOURCE_DIR}/slang/build/slang-2025.18.2-windows-x86_64") +if(NOT EXISTS "${SLANG_SDK_ROOT}") + message(FATAL_ERROR "Expected Slang SDK directory at ${SLANG_SDK_ROOT}") +endif() + +add_executable(modular_slang main.cc) +set_target_properties(modular_slang PROPERTIES + RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/bin" +) + +target_include_directories(modular_slang PRIVATE + "${SLANG_SDK_ROOT}/include" +) + +target_link_libraries(modular_slang PRIVATE + "${SLANG_SDK_ROOT}/lib/slang.lib" +) + +if(MSVC) + target_compile_options(modular_slang PRIVATE /W4 /EHsc) +else() + target_compile_options(modular_slang PRIVATE -Wall -Wextra -pedantic) +endif() + +add_custom_command(TARGET modular_slang POST_BUILD + COMMAND ${CMAKE_COMMAND} -E copy_if_different + "${SLANG_SDK_ROOT}/bin/slang.dll" + "$/slang.dll" + COMMAND ${CMAKE_COMMAND} -E copy_if_different + "${SLANG_SDK_ROOT}/bin/slang-rt.dll" + "$/slang-rt.dll" + COMMAND ${CMAKE_COMMAND} -E copy_if_different + "${SLANG_SDK_ROOT}/bin/gfx.dll" + "$/gfx.dll" + COMMAND ${CMAKE_COMMAND} -E copy_if_different + "${CMAKE_CURRENT_SOURCE_DIR}/test.slang" + "$/test.slang" +) diff --git a/README.md b/README.md new file mode 100644 index 0000000..00ff565 --- /dev/null +++ b/README.md @@ -0,0 +1,24 @@ +## modular slang + +the idea here is just to make something like slangc.exe, but which doesn't +require an entrypoint. + +build instructions + +```bash +git clone https://github.com/shader-slang/slang +git submodule update --init --recursive -j 32 +# wait, the previous command takes a while +mkdir slang/build +cd slang/build +# run this part in powershell, from ./slang/build +cmake.exe .. +cmake.exe --build . -j +# the previous command will take a long fucking time +# switch back to top level of repo +cd ../.. +# do this in wsl2 or powershell. Showing wsl2/bash syntax +# you'd do this for each incremental build. +powershell.exe ./build.ps1 && ./build/bin/Release/modular_slange.exe ./test.slang +``` + diff --git a/build.ps1 b/build.ps1 new file mode 100644 index 0000000..95d2683 --- /dev/null +++ b/build.ps1 @@ -0,0 +1,10 @@ +if (Test-Path -Path ./build) { + rm -Recurse ./build +} +mkdir ./build + +cmake -S . -B build -G "Visual Studio 17 2022" -A x64 +if ($LASTEXITCODE -ne 0) { exit $LASTEXITCODE } +cmake --build build --config Release +if ($LASTEXITCODE -ne 0) { exit $LASTEXITCODE } + diff --git a/main.cc b/main.cc new file mode 100644 index 0000000..2521e07 --- /dev/null +++ b/main.cc @@ -0,0 +1,743 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +// Print any diagnostics carried by a Slang blob with optional context information. +void printDiagnostics(const char* context, slang::IBlob* diagnostics) +{ + if (!diagnostics) + { + return; + } + + size_t size = diagnostics->getBufferSize(); + if (size == 0) + { + return; + } + + std::string_view text(static_cast(diagnostics->getBufferPointer()), size); + if (!text.empty() && text.back() == '\0') + { + text.remove_suffix(1); + } + + if (text.empty()) + { + return; + } + + if (context && *context) + { + std::cerr << context << " diagnostics:" << std::endl; + } + + std::cerr.write(text.data(), text.size()); + if (text.back() != '\n') + { + std::cerr << std::endl; + } +} + +// Helper to check Slang API results and surface diagnostic details when available. +void checkResult(const char* context, SlangResult res, slang::IBlob* diagnostics = nullptr) +{ + printDiagnostics(context, diagnostics); + + if (SLANG_FAILED(res)) + { + std::cerr << (context && *context ? context : "Slang call") + << " failed with SlangResult " << res + << " (0x" << std::hex << res << std::dec << ')' << std::endl; + exit(1); + } +} + +struct ParameterInfo +{ + std::string name; + std::string typeName; +}; + +struct FunctionInfo +{ + std::string name; + std::string returnType; + std::vector parameters; +}; + +std::string trim(std::string_view text) +{ + size_t start = 0; + size_t end = text.size(); + + while (start < end && std::isspace(static_cast(text[start]))) + { + ++start; + } + + while (end > start && std::isspace(static_cast(text[end - 1]))) + { + --end; + } + + return std::string(text.substr(start, end - start)); +} + +std::string getTypeName(slang::TypeReflection* type) +{ + if (!type) + { + return {}; + } + + using Kind = slang::TypeReflection::Kind; + using Scalar = slang::TypeReflection::ScalarType; + + auto scalarToString = [](Scalar scalar) -> std::string + { + switch (scalar) + { + case Scalar::Float16: + case Scalar::Float32: + case Scalar::Float64: + return "float"; + case Scalar::Int8: + case Scalar::Int16: + case Scalar::Int32: + case Scalar::Int64: + return "int"; + case Scalar::UInt8: + case Scalar::UInt16: + case Scalar::UInt32: + case Scalar::UInt64: + return "uint"; + case Scalar::Bool: + return "bool"; + default: + return {}; + } + }; + + switch (type->getKind()) + { + case Kind::Scalar: + if (const std::string base = scalarToString(type->getScalarType()); !base.empty()) + { + return base; + } + break; + case Kind::Vector: + { + const std::string elementTypeName = getTypeName(type->getElementType()); + const size_t elementCount = type->getElementCount(); + if (!elementTypeName.empty() && elementCount > 0) + { + return elementTypeName + std::to_string(elementCount); + } + break; + } + case Kind::Matrix: + { + const std::string elementTypeName = getTypeName(type->getElementType()); + const unsigned rows = type->getRowCount(); + const unsigned cols = type->getColumnCount(); + if (!elementTypeName.empty() && rows > 0 && cols > 0) + { + return elementTypeName + std::to_string(rows) + "x" + std::to_string(cols); + } + break; + } + default: + break; + } + + if (const char* simpleName = type->getName()) + { + if (simpleName[0] != '\0') + { + return simpleName; + } + } + + Slang::ComPtr fullNameBlob; + if (SLANG_SUCCEEDED(type->getFullName(fullNameBlob.writeRef())) && fullNameBlob && + fullNameBlob->getBufferSize() > 0) + { + const char* buffer = static_cast(fullNameBlob->getBufferPointer()); + std::string fullName(buffer, buffer + fullNameBlob->getBufferSize()); + + auto parseTemplateType = [&](std::string_view name) -> std::string + { + if (name.substr(0, 7) == "vector<") + { + size_t commaPos = name.find(','); + size_t endPos = name.rfind('>'); + if (commaPos != std::string_view::npos && endPos != std::string_view::npos && + commaPos + 1 < endPos) + { + std::string base = trim(name.substr(7, commaPos - 7)); + std::string count = trim(name.substr(commaPos + 1, endPos - commaPos - 1)); + return base + count; + } + } + if (name.substr(0, 7) == "matrix<") + { + size_t firstComma = name.find(','); + size_t secondComma = name.find(',', firstComma + 1); + size_t endPos = name.rfind('>'); + if (firstComma != std::string_view::npos && secondComma != std::string_view::npos && + endPos != std::string_view::npos) + { + std::string base = trim(name.substr(7, firstComma - 7)); + std::string rows = trim(name.substr(firstComma + 1, secondComma - firstComma - 1)); + std::string cols = trim(name.substr(secondComma + 1, endPos - secondComma - 1)); + return base + rows + "x" + cols; + } + } + return std::string(name); + }; + + return parseTemplateType(fullName); + } + + return {}; +} + +bool isTopLevelFunction(slang::DeclReflection* functionDecl) +{ + if (!functionDecl) + { + return false; + } + + using Kind = slang::DeclReflection::Kind; + for (slang::DeclReflection* parent = functionDecl->getParent(); parent; + parent = parent->getParent()) + { + switch (parent->getKind()) + { + case Kind::Module: + case Kind::Namespace: + return true; + case Kind::Generic: + continue; + default: + return false; + } + } + + return false; +} + +// Recursively gather function declarations defined in the supplied Slang module. +void collectFunctionInfos( + slang::DeclReflection* decl, + std::vector& functions, + std::unordered_set& seenNames) +{ + if (!decl) + { + return; + } + + using Kind = slang::DeclReflection::Kind; + + switch (decl->getKind()) + { + case Kind::Func: + if (auto* functionReflection = decl->asFunction()) + { + if (const char* name = functionReflection->getName()) + { + if (*name && seenNames.insert(name).second && isTopLevelFunction(decl)) + { + FunctionInfo info; + info.name = name; + + if (slang::TypeReflection* returnType = functionReflection->getReturnType()) + { + info.returnType = getTypeName(returnType); + } + if (info.returnType.empty()) + { + info.returnType = "void"; + } + + const unsigned paramCount = functionReflection->getParameterCount(); + info.parameters.reserve(paramCount); + for (unsigned i = 0; i < paramCount; ++i) + { + slang::VariableReflection* paramReflection = + functionReflection->getParameterByIndex(i); + ParameterInfo paramInfo; + if (auto* typeReflection = paramReflection->getType()) + { + paramInfo.typeName = getTypeName(typeReflection); + } + if (paramInfo.typeName.empty()) + { + paramInfo.typeName = "auto"; + } + if (const char* paramName = paramReflection->getName()) + { + paramInfo.name = paramName; + } + if (paramInfo.name.empty()) + { + paramInfo.name = "param" + std::to_string(i); + } + info.parameters.push_back(std::move(paramInfo)); + } + + functions.push_back(std::move(info)); + } + } + } + break; + case Kind::Generic: + if (auto* genericDecl = decl->asGeneric()) + { + collectFunctionInfos( + genericDecl->getInnerDecl(), + functions, + seenNames); + } + break; + default: + break; + } + + for (auto* child : decl->getChildren()) + { + collectFunctionInfos(child, functions, seenNames); + } +} + +struct EntryPointField +{ + int bufferIndex = 0; + std::string fieldName; + std::string baseName; +}; + +std::string rewriteHLSLWithWrappers( + const std::string& originalHlsl, + const std::vector& functions) +{ + std::string result = originalHlsl; + + std::unordered_map> bufferIndexToFields; + std::unordered_map> baseNameToFields; + + const std::string structPrefix = "struct EntryPointParams_"; + size_t searchPos = 0; + while (true) + { + size_t structPos = result.find(structPrefix, searchPos); + if (structPos == std::string::npos) + { + break; + } + + size_t indexPos = structPos + structPrefix.size(); + size_t indexEnd = indexPos; + while (indexEnd < result.size() && std::isdigit(static_cast(result[indexEnd]))) + { + ++indexEnd; + } + if (indexEnd == indexPos) + { + searchPos = indexEnd; + continue; + } + + int bufferIndex = std::stoi(result.substr(indexPos, indexEnd - indexPos)); + + size_t braceOpen = result.find('{', indexEnd); + if (braceOpen == std::string::npos) + { + break; + } + size_t braceClose = result.find("};", braceOpen); + if (braceClose == std::string::npos) + { + break; + } + + size_t fieldPos = braceOpen + 1; + while (fieldPos < braceClose) + { + size_t semicolon = result.find(';', fieldPos); + if (semicolon == std::string::npos || semicolon > braceClose) + { + break; + } + + std::string line = trim(std::string_view(result).substr(fieldPos, semicolon - fieldPos)); + if (!line.empty()) + { + size_t lastSpace = line.find_last_of(" \t"); + if (lastSpace != std::string::npos && lastSpace + 1 < line.size()) + { + std::string fieldName = line.substr(lastSpace + 1); + size_t bracketPos = fieldName.find('['); + if (bracketPos != std::string::npos) + { + fieldName = fieldName.substr(0, bracketPos); + } + + std::string baseName = fieldName; + size_t underscorePos = baseName.rfind('_'); + if (underscorePos != std::string::npos) + { + baseName = baseName.substr(0, underscorePos); + } + + EntryPointField field{bufferIndex, fieldName, baseName}; + bufferIndexToFields[bufferIndex].push_back(field); + baseNameToFields[baseName].push_back(field); + } + } + + fieldPos = semicolon + 1; + } + + searchPos = braceClose; + } + + const std::string attributeToken = "[shader(\"dispatch\")]export"; + size_t searchFrom = 0; + for (const FunctionInfo& func : functions) + { + size_t attrPos = result.find(attributeToken, searchFrom); + if (attrPos != std::string::npos) + { + size_t attrEnd = attrPos + attributeToken.size(); + if (attrEnd < result.size() && result[attrEnd] == '\r') + { + ++attrEnd; + } + if (attrEnd < result.size() && result[attrEnd] == '\n') + { + ++attrEnd; + } + result.erase(attrPos, attrEnd - attrPos); + searchFrom = attrPos; + } + + size_t namePos = result.find(func.name + "(", searchFrom); + if (namePos == std::string::npos) + { + namePos = result.find(func.name + "("); + } + if (namePos != std::string::npos) + { + const std::string entryName = "__slang_entry_" + func.name; + result.replace(namePos, func.name.size(), entryName); + searchFrom = namePos + entryName.size(); + } + } + + std::ostringstream wrapperBuilder; + wrapperBuilder << "\n"; + + for (size_t functionIndex = 0; functionIndex < functions.size(); ++functionIndex) + { + const FunctionInfo& func = functions[functionIndex]; + const std::string entryName = "__slang_entry_" + func.name; + + std::string parameterList; + for (size_t i = 0; i < func.parameters.size(); ++i) + { + if (i > 0) + { + parameterList += ", "; + } + parameterList += func.parameters[i].typeName + " " + func.parameters[i].name; + } + + wrapperBuilder << func.returnType << " " << func.name << "(" << parameterList << ")\n"; + wrapperBuilder << "{\n"; + + std::unordered_set emittedAssignments; + + for (size_t paramIndex = 0; paramIndex < func.parameters.size(); ++paramIndex) + { + const ParameterInfo& param = func.parameters[paramIndex]; + + auto baseIt = baseNameToFields.find(param.name); + if (baseIt != baseNameToFields.end()) + { + for (const EntryPointField& field : baseIt->second) + { + std::string assignment = + " entryPointParams_" + std::to_string(field.bufferIndex) + "." + + field.fieldName + " = " + param.name + ";\n"; + if (emittedAssignments.insert(assignment).second) + { + wrapperBuilder << assignment; + } + } + } + + auto bufferIt = bufferIndexToFields.find(static_cast(functionIndex)); + if (bufferIt != bufferIndexToFields.end() && paramIndex < bufferIt->second.size()) + { + const EntryPointField& field = bufferIt->second[paramIndex]; + std::string assignment = + " entryPointParams_" + std::to_string(field.bufferIndex) + "." + + field.fieldName + " = " + param.name + ";\n"; + if (emittedAssignments.insert(assignment).second) + { + wrapperBuilder << assignment; + } + } + } + + if (func.returnType == "void") + { + wrapperBuilder << " " << entryName << "();\n"; + wrapperBuilder << " return;\n"; + } + else + { + wrapperBuilder << " return " << entryName << "();\n"; + } + + wrapperBuilder << "}\n\n"; + } + + const std::string wrappers = wrapperBuilder.str(); + if (!wrappers.empty()) + { + const size_t endifPos = result.rfind("#endif"); + const size_t ifndefPos = result.find("#ifndef"); + const size_t definePos = result.find("#define", ifndefPos != std::string::npos ? ifndefPos : 0); + + const bool hasGuard = + ifndefPos != std::string::npos && definePos != std::string::npos && + endifPos != std::string::npos && ifndefPos < definePos && definePos < endifPos; + + if (hasGuard) + { + result.insert(endifPos, wrappers); + } + else + { + result += wrappers; + } + } + + return result; +} + +int main(int argc, char** argv) +{ + if (argc < 2) + { + std::cerr << "Usage: " << (argc > 0 ? argv[0] : "modular_slang") << " " << std::endl; + return 1; + } + + std::filesystem::path modulePath = std::filesystem::absolute(argv[1]); + if (!std::filesystem::exists(modulePath)) + { + std::cerr << "Module not found: " << modulePath << std::endl; + return 1; + } + + if (modulePath.extension() != ".slang") + { + std::cerr << "Expected a .slang file: " << modulePath << std::endl; + return 1; + } + + std::string moduleName = modulePath.stem().string(); + std::string searchPath = modulePath.has_parent_path() + ? modulePath.parent_path().string() + : std::filesystem::current_path().string(); + + std::filesystem::path outputPath = modulePath; + outputPath.replace_extension(".hlsl"); + + // 1. Session Creation + Slang::ComPtr globalSession; + checkResult("slang::createGlobalSession", slang::createGlobalSession(globalSession.writeRef())); + + // 2. Target Configuration + slang::TargetDesc targetDesc = {}; + targetDesc.format = SLANG_HLSL; + targetDesc.profile = globalSession->findProfile("lib_6_6"); + targetDesc.flags = SLANG_TARGET_FLAG_GENERATE_WHOLE_PROGRAM; + + std::vector targetOptions; + { + slang::CompilerOptionEntry entry = {}; + entry.name = slang::CompilerOptionName::NoHLSLBinding; + entry.value.intValue0 = 1; + targetOptions.push_back(entry); + } + { + slang::CompilerOptionEntry entry = {}; + entry.name = slang::CompilerOptionName::NoHLSLPackConstantBufferElements; + entry.value.intValue0 = 1; + targetOptions.push_back(entry); + } + targetDesc.compilerOptionEntries = targetOptions.data(); + targetDesc.compilerOptionEntryCount = static_cast(targetOptions.size()); + + slang::SessionDesc sessionDesc = {}; + sessionDesc.targets = &targetDesc; + sessionDesc.targetCount = 1; + const char* searchPaths[] = { searchPath.c_str() }; + sessionDesc.searchPaths = searchPaths; + sessionDesc.searchPathCount = 1; + + std::vector sessionOptions; + { + slang::CompilerOptionEntry entry = {}; + entry.name = slang::CompilerOptionName::NoMangle; + entry.value.intValue0 = 1; + sessionOptions.push_back(entry); + } + { + slang::CompilerOptionEntry entry = {}; + entry.name = slang::CompilerOptionName::DisableNonEssentialValidations; + entry.value.intValue0 = 1; + sessionOptions.push_back(entry); + } + sessionDesc.compilerOptionEntries = sessionOptions.data(); + sessionDesc.compilerOptionEntryCount = static_cast(sessionOptions.size()); + + Slang::ComPtr session; + checkResult("IGlobalSession::createSession", globalSession->createSession(sessionDesc, session.writeRef())); + + // 3. Module Loading (from the supplied Slang source file) + Slang::ComPtr libraryModule; + { + Slang::ComPtr diagnosticsBlob; + libraryModule = session->loadModule(moduleName.c_str(), diagnosticsBlob.writeRef()); + const std::string diagnosticsContext = "loadModule: " + moduleName; + printDiagnostics(diagnosticsContext.c_str(), diagnosticsBlob); + if (!libraryModule) + { + std::cerr << "Failed to load module '" << moduleName << "'." << std::endl; + return 1; + } + } + // 4. Discover top-level functions to treat as entry points + std::vector functions; + std::unordered_set seenNames; + slang::DeclReflection* moduleReflection = libraryModule->getModuleReflection(); + if (!moduleReflection) + { + std::cerr << "Failed to retrieve reflection data for module '" << moduleName << "'." + << std::endl; + return 1; + } + collectFunctionInfos(moduleReflection, functions, seenNames); + + if (functions.empty()) + { + std::cerr << "No functions found in module '" << moduleName << "'." << std::endl; + return 1; + } + + // 5. Compile the translation unit with whole-program emission + Slang::ComPtr compileRequest; + checkResult( + "ISession::createCompileRequest", + session->createCompileRequest(compileRequest.writeRef())); + + compileRequest->setCodeGenTarget(SLANG_HLSL); + compileRequest->setTargetProfile(0, targetDesc.profile); + compileRequest->setTargetFlags(0, SLANG_TARGET_FLAG_GENERATE_WHOLE_PROGRAM); + + SlangCompileFlags compileFlags = compileRequest->getCompileFlags(); + compileFlags |= SLANG_COMPILE_FLAG_NO_MANGLING; + compileRequest->setCompileFlags(compileFlags); + + compileRequest->addSearchPath(searchPath.c_str()); + + const int translationUnitIndex = compileRequest->addTranslationUnit( + SLANG_SOURCE_LANGUAGE_SLANG, + moduleName.c_str()); + compileRequest->addTranslationUnitSourceFile( + translationUnitIndex, + modulePath.string().c_str()); + + for (const FunctionInfo& func : functions) + { + const int entryPointIndex = compileRequest->addEntryPoint( + translationUnitIndex, + func.name.c_str(), + SLANG_STAGE_DISPATCH); + if (entryPointIndex < 0) + { + std::cerr << "Failed to register entry point '" << func.name << "'." << std::endl; + return 1; + } + } + + SlangResult compileResult = compileRequest->compile(); + Slang::ComPtr compileDiagnostics; + compileRequest->getDiagnosticOutputBlob(compileDiagnostics.writeRef()); + checkResult("ICompileRequest::compile", compileResult, compileDiagnostics); + + Slang::ComPtr targetCodeBlob; + checkResult( + "ICompileRequest::getTargetCodeBlob", + compileRequest->getTargetCodeBlob(0, targetCodeBlob.writeRef())); + + if (!targetCodeBlob || targetCodeBlob->getBufferSize() == 0) + { + std::cerr << "No HLSL was generated for module '" << moduleName << "'." << std::endl; + return 1; + } + + std::string hlslSource( + static_cast(targetCodeBlob->getBufferPointer()), + static_cast(targetCodeBlob->getBufferSize())); + + const std::string finalHlsl = rewriteHLSLWithWrappers(hlslSource, functions); + + // 6. Write HLSL output to a sibling .hlsl file + std::ofstream outputFile(outputPath, std::ios::binary); + if (!outputFile) + { + std::cerr << "Failed to open output path for writing: " << outputPath << std::endl; + return 1; + } + + outputFile.write(finalHlsl.data(), static_cast(finalHlsl.size())); + outputFile.close(); + + if (!outputFile) + { + std::cerr << "Failed to write HLSL output to " << outputPath << std::endl; + return 1; + } + + // Also stream to standard output to retain previous behavior + std::cout.write(finalHlsl.data(), static_cast(finalHlsl.size())); + if (finalHlsl.empty() || finalHlsl.back() != '\n') + { + std::cout << std::endl; + } + + std::cerr << "Generated HLSL written to " << outputPath << std::endl; + + return 0; +} diff --git a/test.slang b/test.slang new file mode 100644 index 0000000..324e53f --- /dev/null +++ b/test.slang @@ -0,0 +1,24 @@ +#ifndef __CUSTOM31_INC +#define __CUSTOM31_INC + +[Differentiable] +public float3 cart_to_cyl(uniform float3 xyz) { + float radius = length(xyz.xy); + float theta = atan2(xyz.y, xyz.x); + return float3(radius, theta, xyz.z); +} + +public float3x3 cart_to_cyl_jacobian(uniform float3 xyz, uniform float3 n) { + DifferentialPair dp_x = diffPair(xyz, float3(1, 0, 0)); + DifferentialPair dp_y = diffPair(xyz, float3(0, 1, 0)); + DifferentialPair dp_z = diffPair(xyz, float3(0, 0, 1)); + + DifferentialPair dp_x_out = fwd_diff(cart_to_cyl)(dp_x); + DifferentialPair dp_y_out = fwd_diff(cart_to_cyl)(dp_y); + DifferentialPair dp_z_out = fwd_diff(cart_to_cyl)(dp_z); + + return float3x3(dp_x_out.d, dp_y_out.d, dp_z_out.d); +} + +#endif // __CUSTOM31_INC + -- cgit v1.2.3