summaryrefslogtreecommitdiff
path: root/source/slang/slang-vm.cpp
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2025-04-28 11:42:22 -0700
committerGitHub <noreply@github.com>2025-04-28 11:42:22 -0700
commitc39c29bf4c52a85d7c83cc8b66ae45e265f9e078 (patch)
tree969339828d49d7db92ed9294a17bd34cc021db84 /source/slang/slang-vm.cpp
parent8f6c6e333c06ae1c3b9f00396563c14a2ae09b4d (diff)
Add Slang Byte Code generation and interpreter. (#6896)
* Add Slang Byte Code generation and interpreter. * Fix compile issues. * format code * More compile fix. * Fix clang issue. * Fix more clang issues. * Another clang fix. * Fix clang issues. * Fix another clang issue. * Fix wasm build. * Update building.md * Fix test-server. * Fix compile error. * Fix bug. --------- Co-authored-by: slangbot <186143334+slangbot@users.noreply.github.com>
Diffstat (limited to 'source/slang/slang-vm.cpp')
-rw-r--r--source/slang/slang-vm.cpp254
1 files changed, 254 insertions, 0 deletions
diff --git a/source/slang/slang-vm.cpp b/source/slang/slang-vm.cpp
new file mode 100644
index 000000000..05d53acb8
--- /dev/null
+++ b/source/slang/slang-vm.cpp
@@ -0,0 +1,254 @@
+#include "slang-vm.h"
+
+#include "core/slang-blob.h"
+#include "slang-vm-inst-impl.h"
+
+namespace Slang
+{
+
+// Our VM insts need to be 8-byte aligned, so we can replace the opcode with function pointers and
+// sectionId with data pointers.
+static_assert(sizeof(VMOperand) % 8 == 0);
+static_assert(sizeof(VMInstHeader) % 8 == 0);
+static_assert(sizeof(VMOperand) == sizeof(VMExecOperand));
+static_assert(sizeof(VMInstHeader) == sizeof(VMExecInstHeader));
+
+ISlangUnknown* ByteCodeInterpreter::getInterface(const Guid& guid)
+{
+ if (guid == ISlangUnknown::getTypeGuid() || guid == IByteCodeRunner::getTypeGuid())
+ return static_cast<IByteCodeRunner*>(this);
+
+ return nullptr;
+}
+
+SlangResult ByteCodeInterpreter::prepareModuleForExecution()
+{
+ m_stringLits.clear();
+ m_stringLits.setCount(m_moduleView.stringCount);
+ for (uint32_t i = 0; i < m_moduleView.stringCount; i++)
+ {
+ auto strOffset = m_moduleView.stringOffsets[i];
+ const char* str = (const char*)m_moduleView.constants + strOffset;
+ m_stringLits[i] = str;
+ }
+ m_stringLitsPtr = m_stringLits.getBuffer();
+
+ m_functions.setCount(m_moduleView.functionCount);
+ for (uint32_t i = 0; i < m_moduleView.functionCount; i++)
+ {
+ auto func = m_moduleView.getFunction(i);
+ auto& exeFunc = m_functions[i];
+ exeFunc.m_codeBuffer.setCount(func.header->codeSize / sizeof(uint64_t));
+ exeFunc.m_header = func.header;
+ for (uint32_t j = 0; j < func.header->parameterCount; j++)
+ {
+ exeFunc.m_parameterOffsets.add(func.header->getParameterOffset(j));
+ }
+ exeFunc.m_parameterOffsets.add(func.header->parameterSizeInBytes);
+
+ // Copy the code into the executable function buffer
+ memcpy(exeFunc.m_codeBuffer.getBuffer(), func.functionCode, func.header->codeSize);
+
+ // Replace the instruction headers with function pointers
+ for (auto inst : exeFunc)
+ {
+ VMInstHeader* instHeader = reinterpret_cast<VMInstHeader*>(inst);
+ auto handler = mapInstToFunction(instHeader, &m_moduleView, m_extInstHandlers);
+ if (!handler)
+ {
+ StringBuilder instStr;
+ printVMInst(instStr, &m_moduleView, instHeader);
+ reportError(
+ "Cannot find execution handler for instruction %s\n",
+ instStr.toString().getBuffer());
+ return SLANG_FAIL;
+ }
+ inst->functionPtr = handler;
+ for (uint32_t operandIdx = 0; operandIdx < instHeader->operandCount; operandIdx++)
+ {
+ auto& operand = instHeader->getOperand(operandIdx);
+ auto& execOpernad = inst->getOperand(operandIdx);
+ switch (operand.sectionId)
+ {
+ case kSlangByteCodeSectionConstants:
+ execOpernad.section = &m_moduleView.constants;
+ break;
+ case kSlangByteCodeSectionInsts:
+ execOpernad.section = (uint8_t**)&m_currentFuncCode;
+ break;
+ case kSlangByteCodeSectionWorkingSet:
+ execOpernad.section = (uint8_t**)&m_currentWorkingSet;
+ break;
+ case kSlangByteCodeSectionStrings:
+ execOpernad.section = (uint8_t**)&m_stringLitsPtr;
+ execOpernad.offset *= sizeof(const char*);
+ break;
+ }
+ }
+ }
+ }
+
+ return SLANG_OK;
+}
+
+SLANG_NO_THROW SlangResult SLANG_MCALL ByteCodeInterpreter::loadModule(IBlob* moduleBlob)
+{
+ m_stack.reserve(128);
+ m_workingSetBuffer.reserve(1024 * 1024); // Reserve 1MB for working set
+ m_currentWorkingSet = m_workingSetBuffer.getBuffer();
+
+ m_errorBuilder.clear();
+ m_code.addRange((uint8_t*)(moduleBlob->getBufferPointer()), moduleBlob->getBufferSize());
+ SLANG_RETURN_ON_FAIL(
+ initVMModule(m_code.getBuffer(), (uint32_t)moduleBlob->getBufferSize(), &m_moduleView));
+ SLANG_RETURN_ON_FAIL(prepareModuleForExecution());
+ return SLANG_OK;
+}
+
+SLANG_NO_THROW void SLANG_MCALL ByteCodeInterpreter::getErrorString(slang::IBlob** outBlob)
+{
+ *outBlob = StringBlob::moveCreate(m_errorBuilder.produceString()).detach();
+ m_errorBuilder.clear();
+}
+
+SLANG_NO_THROW int SLANG_MCALL ByteCodeInterpreter::findFunctionByName(const char* name)
+{
+ for (uint32_t i = 0; i < m_moduleView.functionCount; i++)
+ {
+ auto func = m_moduleView.getFunction(i);
+ if (UnownedStringSlice(func.name) == name)
+ {
+ return (int)i;
+ }
+ }
+ return -1; // Function not found
+}
+
+SLANG_NO_THROW SlangResult SLANG_MCALL
+ByteCodeInterpreter::getFunctionInfo(uint32_t index, slang::ByteCodeFuncInfo* outInfo)
+{
+ if (index >= m_moduleView.functionCount)
+ {
+ return SLANG_FAIL;
+ }
+ auto func = m_moduleView.getFunction(index);
+ outInfo->parameterCount = func.header->parameterCount;
+ outInfo->returnValueSize = func.header->returnValueSizeInBytes;
+ return SLANG_OK;
+}
+
+SLANG_NO_THROW SlangResult SLANG_MCALL
+ByteCodeInterpreter::selectFunctionByIndex(uint32_t functionIndex)
+{
+ if (functionIndex >= m_moduleView.functionCount)
+ {
+ reportError(
+ "Function index %u out of range [0, %u)",
+ functionIndex,
+ m_moduleView.functionCount);
+ return SLANG_FAIL;
+ }
+ auto func = m_moduleView.getFunction(functionIndex);
+ m_currentFuncCode = m_functions[functionIndex].m_codeBuffer.getBuffer();
+ m_currentInst = reinterpret_cast<VMExecInstHeader*>(m_currentFuncCode);
+ m_workingSetBuffer.setCount(func.header->workingSetSizeInBytes / sizeof(uint64_t));
+ m_currentWorkingSet = m_workingSetBuffer.getBuffer();
+ return SLANG_OK;
+}
+
+SLANG_NO_THROW SlangResult SLANG_MCALL
+ByteCodeInterpreter::execute(void* argumentData, size_t argumentSize)
+{
+ if (!m_currentInst)
+ {
+ reportError("No function selected for execution");
+ return SLANG_FAIL;
+ }
+ if (!m_currentWorkingSet)
+ {
+ reportError("No working set allocated for execution");
+ return SLANG_FAIL;
+ }
+ if ((uint8_t*)m_currentWorkingSet + argumentSize >
+ (uint8_t*)(m_workingSetBuffer.getBuffer() + m_workingSetBuffer.getCount()))
+ {
+ reportError("Argument size exceeds working set.");
+ return SLANG_FAIL;
+ }
+ // Copy the arguments into the working set
+ if (argumentData && argumentSize > 0)
+ {
+ memcpy(m_currentWorkingSet, argumentData, argumentSize);
+ }
+ m_returnValSize = 0;
+ while (m_currentInst)
+ {
+ auto nextInst = m_currentInst->getNextInst();
+ auto currentInst = m_currentInst;
+ m_currentInst = nextInst;
+ currentInst->functionPtr(this, currentInst, m_extInstHandlerUserData);
+ }
+ return SLANG_OK;
+}
+
+ByteCodeInterpreter::ByteCodeInterpreter()
+{
+ m_printCallback = defaultPrintCallback;
+ m_printCallbackUserData = this;
+}
+
+SLANG_NO_THROW SlangResult SLANG_MCALL
+ByteCodeInterpreter::setPrintCallback(slang::VMPrintFunc callback, void* userData)
+{
+ m_printCallback = callback;
+ m_printCallbackUserData = userData;
+ return SLANG_OK;
+}
+
+void ByteCodeInterpreter::defaultPrintCallback(const char* str, void* userData)
+{
+ SLANG_UNUSED(userData);
+ printf("%s", str);
+}
+
+ExecutableFunction::InstIterator ExecutableFunction::begin()
+{
+ ExecutableFunction::InstIterator iter;
+ iter.codePtr = (uint8_t*)m_codeBuffer.getBuffer();
+ return iter;
+}
+
+ExecutableFunction::InstIterator ExecutableFunction::end()
+{
+ ExecutableFunction::InstIterator iter;
+ iter.codePtr = (uint8_t*)(m_codeBuffer.getBuffer() + m_codeBuffer.getCount());
+ return iter;
+}
+
+
+} // namespace Slang
+
+
+SLANG_EXTERN_C SLANG_API SlangResult slang_createByteCodeRunner(
+ const slang::ByteCodeRunnerDesc* desc,
+ slang::IByteCodeRunner** outByteCodeRunner)
+{
+ SLANG_UNUSED(desc);
+ Slang::RefPtr<Slang::ByteCodeInterpreter> runner = new Slang::ByteCodeInterpreter();
+ *outByteCodeRunner = static_cast<slang::IByteCodeRunner*>(runner.detach());
+ return SLANG_OK;
+}
+
+SLANG_EXTERN_C SLANG_API SlangResult
+slang_disassembleByteCode(slang::IBlob* moduleBlob, slang::IBlob** outDisassemblyBlob)
+{
+ Slang::VMModuleView moduleView;
+ SLANG_RETURN_ON_FAIL(Slang::initVMModule(
+ (uint8_t*)moduleBlob->getBufferPointer(),
+ (uint32_t)moduleBlob->getBufferSize(),
+ &moduleView));
+ Slang::StringBuilder sb;
+ sb << moduleView;
+ *outDisassemblyBlob = Slang::StringBlob::moveCreate(sb.produceString()).detach();
+ return SLANG_OK;
+}