summaryrefslogtreecommitdiffstats
path: root/source/slang/slang-vm-bytecode.cpp
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2025-04-28 11:42:22 -0700
committerGitHub <noreply@github.com>2025-04-28 11:42:22 -0700
commitc39c29bf4c52a85d7c83cc8b66ae45e265f9e078 (patch)
tree969339828d49d7db92ed9294a17bd34cc021db84 /source/slang/slang-vm-bytecode.cpp
parent8f6c6e333c06ae1c3b9f00396563c14a2ae09b4d (diff)
Add Slang Byte Code generation and interpreter. (#6896)
* Add Slang Byte Code generation and interpreter. * Fix compile issues. * format code * More compile fix. * Fix clang issue. * Fix more clang issues. * Another clang fix. * Fix clang issues. * Fix another clang issue. * Fix wasm build. * Update building.md * Fix test-server. * Fix compile error. * Fix bug. --------- Co-authored-by: slangbot <186143334+slangbot@users.noreply.github.com>
Diffstat (limited to 'source/slang/slang-vm-bytecode.cpp')
-rw-r--r--source/slang/slang-vm-bytecode.cpp424
1 files changed, 424 insertions, 0 deletions
diff --git a/source/slang/slang-vm-bytecode.cpp b/source/slang/slang-vm-bytecode.cpp
new file mode 100644
index 000000000..1eafa4e57
--- /dev/null
+++ b/source/slang/slang-vm-bytecode.cpp
@@ -0,0 +1,424 @@
+#include "slang-vm-bytecode.h"
+
+#include "core/slang-blob.h"
+#include "core/slang-stream.h"
+#include "core/slang-string-escape-util.h"
+
+using namespace slang;
+
+namespace Slang
+{
+static SlangResult consumeFourCC(MemoryStreamBase& stream, uint32_t expected)
+{
+ uint32_t fourCC = 0;
+ size_t bytesRead = 0;
+ SLANG_RETURN_ON_FAIL(stream.read(&fourCC, sizeof(fourCC), bytesRead));
+ if (fourCC != expected)
+ {
+ return SLANG_FAIL;
+ }
+ return SLANG_OK;
+}
+
+template<typename T>
+static SlangResult readValue(MemoryStreamBase& stream, T& value)
+{
+ size_t bytesRead = 0;
+ SLANG_RETURN_ON_FAIL(stream.read(&value, sizeof(T), bytesRead));
+ if (bytesRead != sizeof(T))
+ {
+ return SLANG_FAIL; // Not enough data
+ }
+ return SLANG_OK;
+}
+
+static SlangResult readUInt32(MemoryStreamBase& stream, uint32_t& value)
+{
+ return readValue(stream, value);
+}
+
+SlangResult initVMModule(uint8_t* code, uint32_t codeSize, VMModuleView* moduleView)
+{
+ MemoryStreamBase stream(FileAccess::Read, code, codeSize);
+ moduleView->code = code;
+
+ // Check the FourCC
+ SLANG_RETURN_ON_FAIL(consumeFourCC(stream, kSlangByteCodeFourCC));
+
+ // Check the version
+ uint32_t version;
+ size_t bytesRead = 0;
+ SLANG_RETURN_ON_FAIL(stream.read(&version, sizeof(version), bytesRead));
+ if (version > kSlangByteCodeVersion)
+ {
+ return SLANG_FAIL; // Unsupported version
+ }
+
+ // Read the function section
+ SLANG_RETURN_ON_FAIL(consumeFourCC(stream, kSlangByteCodeFunctionsFourCC));
+ uint32_t functionSectionSize = 0;
+ SLANG_RETURN_ON_FAIL(readUInt32(stream, functionSectionSize));
+ auto funcDataStart = stream.getPosition();
+ if (functionSectionSize < sizeof(uint32_t)) // At least the function count
+ {
+ return SLANG_FAIL; // Invalid section size
+ }
+
+ SLANG_RETURN_ON_FAIL(readUInt32(stream, moduleView->functionCount));
+ moduleView->functionOffsets = reinterpret_cast<uint32_t*>(code + stream.getPosition());
+
+ stream.seek(SeekOrigin::Start, funcDataStart + functionSectionSize);
+
+ // Read the kernel blob section
+ SLANG_RETURN_ON_FAIL(consumeFourCC(stream, kSlangByteCodeKernelBlobFourCC));
+ SLANG_RETURN_ON_FAIL(readUInt32(stream, moduleView->kernelBlobSize));
+ if (moduleView->kernelBlobSize > codeSize - stream.getPosition())
+ {
+ return SLANG_FAIL; // Invalid kernel blob size
+ }
+ moduleView->kernelBlob = code + stream.getPosition();
+ stream.seek(SeekOrigin::Current, moduleView->kernelBlobSize);
+
+ // Read the constants section
+ SLANG_RETURN_ON_FAIL(consumeFourCC(stream, kSlangByteCodeConstantsFourCC));
+ SLANG_RETURN_ON_FAIL(readUInt32(stream, moduleView->constantBlobSize));
+ if (moduleView->constantBlobSize < sizeof(uint32_t)) // At least the constant count
+ {
+ return SLANG_FAIL; // Invalid section size
+ }
+ SLANG_RETURN_ON_FAIL(readUInt32(stream, moduleView->stringCount));
+ moduleView->stringOffsets = reinterpret_cast<uint32_t*>(code + stream.getPosition());
+ stream.seek(SeekOrigin::Current, moduleView->stringCount * sizeof(uint32_t));
+ moduleView->constants = code + stream.getPosition();
+
+ for (uint32_t i = 0; i < moduleView->functionCount; i++)
+ {
+ auto functionStart = code + moduleView->functionOffsets[i];
+ auto header = (VMFuncHeader*)(functionStart);
+ VMFunctionView functionView;
+ functionView.moduleView = moduleView;
+ functionView.header = (VMFuncHeader*)(functionStart);
+ functionView.paramOffsets = (uint32_t*)(functionStart + sizeof(VMFuncHeader));
+ functionView.name = (const char*)moduleView->constants +
+ moduleView->stringOffsets[functionView.header->name.offset];
+ functionView.functionCode =
+ (uint8_t*)functionView.paramOffsets + sizeof(uint32_t) * header->parameterCount;
+ functionView.functionCodeEnd = functionView.functionCode + functionView.header->codeSize;
+ moduleView->functionViews.add(functionView);
+ }
+ return SLANG_OK;
+}
+
+StringBuilder& operator<<(StringBuilder& sb, VMOp op)
+{
+ switch (op)
+ {
+ case VMOp::Add:
+ sb << "add";
+ break;
+ case VMOp::Sub:
+ sb << "sub";
+ break;
+ case VMOp::Mul:
+ sb << "mul";
+ break;
+ case VMOp::Div:
+ sb << "div";
+ break;
+ case VMOp::Rem:
+ sb << "rem";
+ break;
+ case VMOp::And:
+ sb << "and";
+ break;
+ case VMOp::Or:
+ sb << "or";
+ break;
+ case VMOp::BitXor:
+ sb << "bitxor";
+ break;
+ case VMOp::BitNot:
+ sb << "bitnot";
+ break;
+ case VMOp::Shl:
+ sb << "shl";
+ break;
+ case VMOp::Shr:
+ sb << "shr";
+ break;
+ case VMOp::Equal:
+ sb << "equal";
+ break;
+ case VMOp::Neq:
+ sb << "neq";
+ break;
+ case VMOp::Less:
+ sb << "less";
+ break;
+ case VMOp::Leq:
+ sb << "leq";
+ break;
+ case VMOp::Greater:
+ sb << "greater";
+ break;
+ case VMOp::Geq:
+ sb << "geq";
+ break;
+ case VMOp::Nop:
+ sb << "nop";
+ break;
+ case VMOp::Neg:
+ sb << "neg";
+ break;
+ case VMOp::Not:
+ sb << "not";
+ break;
+ case VMOp::Jump:
+ sb << "jump";
+ break;
+ case VMOp::JumpIf:
+ sb << "jumpif";
+ break;
+ case VMOp::Dispatch:
+ sb << "dispatch";
+ break;
+ case VMOp::Load:
+ sb << "load";
+ break;
+ case VMOp::Store:
+ sb << "store";
+ break;
+ case VMOp::Copy:
+ sb << "copy";
+ break;
+ case VMOp::GetWorkingSetPtr:
+ sb << "get_working_set_ptr";
+ break;
+ case VMOp::GetElementPtr:
+ sb << "get_element_ptr";
+ break;
+ case VMOp::OffsetPtr:
+ sb << "offset_ptr";
+ break;
+ case VMOp::GetElement:
+ sb << "get_element";
+ break;
+ case VMOp::Cast:
+ sb << "cast";
+ break;
+ case VMOp::CallExt:
+ sb << "call_ext";
+ break;
+ case VMOp::Call:
+ sb << "call";
+ break;
+ case VMOp::Swizzle:
+ sb << "swizzle";
+ break;
+ case VMOp::Ret:
+ sb << "ret";
+ break;
+ case VMOp::Print:
+ sb << "print";
+ break;
+ default:
+ sb << "unknown_op(" << static_cast<uint32_t>(op) << ")";
+ break;
+ }
+ return sb;
+}
+
+StringBuilder& operator<<(StringBuilder& sb, ArithmeticExtCode extCode)
+{
+ switch (extCode.scalarType)
+ {
+ case kSlangByteCodeScalarTypeSignedInt:
+ sb << "i";
+ break;
+ case kSlangByteCodeScalarTypeUnsignedInt:
+ sb << "u";
+ break;
+ case kSlangByteCodeScalarTypeFloat:
+ sb << "f";
+ break;
+ default:
+ sb << "x";
+ break;
+ }
+ sb << (8 << extCode.scalarBitWidth);
+ if (extCode.vectorSize > 1)
+ {
+ sb << "v" << extCode.vectorSize;
+ }
+ return sb;
+}
+
+void printVMInst(StringBuilder& sb, VMModuleView* moduleView, VMInstHeader* inst)
+{
+ auto lenBeforeOpCode = sb.getLength();
+ sb << inst->opcode;
+ if (inst->opcodeExtension != 0)
+ {
+ switch (inst->opcode)
+ {
+ case VMOp::Add:
+ case VMOp::Sub:
+ case VMOp::Mul:
+ case VMOp::Div:
+ case VMOp::Rem:
+ case VMOp::And:
+ case VMOp::Or:
+ case VMOp::BitXor:
+ case VMOp::BitNot:
+ case VMOp::BitAnd:
+ case VMOp::BitOr:
+ case VMOp::Neg:
+ case VMOp::Not:
+ case VMOp::Shl:
+ case VMOp::Shr:
+ case VMOp::Equal:
+ case VMOp::Neq:
+ case VMOp::Less:
+ case VMOp::Leq:
+ case VMOp::Greater:
+ case VMOp::Geq:
+ {
+ ArithmeticExtCode extCode;
+ memcpy(&extCode, &inst->opcodeExtension, sizeof(extCode));
+ sb << "." << extCode;
+ }
+ break;
+ case VMOp::Cast:
+ {
+ ArithmeticExtCode extCode;
+ memcpy(&extCode, &inst->opcodeExtension, sizeof(extCode));
+ sb << "." << extCode;
+ uint32_t fromCode = inst->opcodeExtension >> 16;
+ memcpy(&extCode, &fromCode, sizeof(extCode));
+ sb << "." << extCode;
+ }
+ break;
+ default:
+ sb << "." << inst->opcodeExtension;
+ break;
+ }
+ }
+ auto opCodeLength = (int)(sb.getLength() - lenBeforeOpCode);
+ static const int kOpCodeColumnWidth = 20;
+ if (opCodeLength < kOpCodeColumnWidth)
+ {
+ for (int i = 0; i < kOpCodeColumnWidth - opCodeLength; i++)
+ {
+ sb << " ";
+ }
+ }
+ else
+ {
+ sb << " ";
+ }
+ for (uint32_t i = 0; i < inst->operandCount; i++)
+ {
+ if (i > 0)
+ sb << ", ";
+ auto operand = inst->getOperand(i);
+ switch (operand.sectionId)
+ {
+ case kSlangByteCodeSectionConstants:
+ switch (operand.getType())
+ {
+ case OperandDataType::Int32:
+ {
+ int32_t val;
+ moduleView->getConstant<int32_t>(operand, val);
+ sb << "i32(" << val << ")";
+ continue;
+ }
+ case OperandDataType::Int64:
+ {
+ int64_t val;
+ moduleView->getConstant<int64_t>(operand, val);
+ sb << "i64(" << val << ")";
+ continue;
+ }
+ case OperandDataType::Float32:
+ {
+ float val;
+ moduleView->getConstant<float>(operand, val);
+ sb << "f32(" << val << ")";
+ continue;
+ }
+ case OperandDataType::Float64:
+ {
+ double val;
+ moduleView->getConstant<double>(operand, val);
+ sb << "f32(" << val << ")";
+ continue;
+ }
+ }
+ sb << "const:";
+ break;
+ case kSlangByteCodeSectionInsts:
+ sb << "inst:";
+ break;
+ case kSlangByteCodeSectionWorkingSet:
+ sb << "ws:";
+ break;
+ case kSlangByteCodeSectionImmediate:
+ sb << "!";
+ break;
+ case kSlangByteCodeSectionFuncs:
+ sb << moduleView->getFunction(operand.offset).name;
+ continue;
+ case kSlangByteCodeSectionStrings:
+ sb << "str:";
+ if (operand.offset < moduleView->stringCount)
+ {
+ auto str = StringEscapeUtil::escapeString(UnownedStringSlice(
+ ((char*)moduleView->constants + moduleView->stringOffsets[operand.offset])));
+ sb << str;
+ }
+ else
+ {
+ sb << "<invalid string index>";
+ }
+ continue;
+ default:
+ sb << "section(" << operand.sectionId << ")@";
+ break;
+ }
+ sb << String(inst->getOperand(i).offset, 16);
+ }
+}
+
+StringBuilder& operator<<(StringBuilder& sb, VMModuleView& module)
+{
+ static const int addrColumnSize = 6;
+ for (uint32_t i = 0; i < module.functionCount; i++)
+ {
+ auto f = module.getFunction(i);
+ sb << "func " << f.name << ":\n";
+ for (auto inst : f)
+ {
+ sb << " ";
+ auto loc = ((uint8_t*)inst - f.functionCode);
+ auto pos = sb.getLength();
+ sb << String((uint32_t)loc, 16) << ": ";
+ auto addrLength = (int)(sb.getLength() - pos);
+ for (int j = 0; j < addrColumnSize - addrLength; j++)
+ {
+ sb << " ";
+ }
+ printVMInst(sb, &module, inst);
+ sb << "\n";
+ }
+ }
+ return sb;
+}
+
+VMFunctionView VMModuleView::getFunction(Index index) const
+{
+ if (index >= functionCount)
+ return {};
+ return functionViews[index];
+}
+} // namespace Slang