summaryrefslogtreecommitdiffstats
path: root/source/slang/slang-vm-inst-impl.cpp
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2025-04-28 11:42:22 -0700
committerGitHub <noreply@github.com>2025-04-28 11:42:22 -0700
commitc39c29bf4c52a85d7c83cc8b66ae45e265f9e078 (patch)
tree969339828d49d7db92ed9294a17bd34cc021db84 /source/slang/slang-vm-inst-impl.cpp
parent8f6c6e333c06ae1c3b9f00396563c14a2ae09b4d (diff)
Add Slang Byte Code generation and interpreter. (#6896)
* Add Slang Byte Code generation and interpreter. * Fix compile issues. * format code * More compile fix. * Fix clang issue. * Fix more clang issues. * Another clang fix. * Fix clang issues. * Fix another clang issue. * Fix wasm build. * Update building.md * Fix test-server. * Fix compile error. * Fix bug. --------- Co-authored-by: slangbot <186143334+slangbot@users.noreply.github.com>
Diffstat (limited to 'source/slang/slang-vm-inst-impl.cpp')
-rw-r--r--source/slang/slang-vm-inst-impl.cpp1066
1 files changed, 1066 insertions, 0 deletions
diff --git a/source/slang/slang-vm-inst-impl.cpp b/source/slang/slang-vm-inst-impl.cpp
new file mode 100644
index 000000000..304ea09a8
--- /dev/null
+++ b/source/slang/slang-vm-inst-impl.cpp
@@ -0,0 +1,1066 @@
+#include "slang-vm-inst-impl.h"
+
+#include "slang-vm.h"
+
+using namespace slang;
+
+namespace Slang
+{
+ByteCodeInterpreter* convert(IByteCodeRunner* runner)
+{
+ return static_cast<ByteCodeInterpreter*>(runner);
+}
+
+#define SIMPLE_BINARY_SCALAR_FUNC(name, op) \
+ struct name##ScalarFunc \
+ { \
+ template<typename TR, typename T1, typename T2> \
+ static void run(TR* dst, const T1* src1, const T2* src2) \
+ { \
+ *dst = (*src1)op(*src2); \
+ } \
+ }
+
+SIMPLE_BINARY_SCALAR_FUNC(Add, +);
+SIMPLE_BINARY_SCALAR_FUNC(Sub, -);
+SIMPLE_BINARY_SCALAR_FUNC(Mul, *);
+SIMPLE_BINARY_SCALAR_FUNC(Div, /);
+SIMPLE_BINARY_SCALAR_FUNC(And, &&);
+SIMPLE_BINARY_SCALAR_FUNC(Or, ||);
+SIMPLE_BINARY_SCALAR_FUNC(BitAnd, &);
+SIMPLE_BINARY_SCALAR_FUNC(BitOr, |);
+SIMPLE_BINARY_SCALAR_FUNC(BitXor, ^);
+SIMPLE_BINARY_SCALAR_FUNC(Shl, <<);
+SIMPLE_BINARY_SCALAR_FUNC(Shr, >>);
+SIMPLE_BINARY_SCALAR_FUNC(Less, <);
+SIMPLE_BINARY_SCALAR_FUNC(Leq, <=);
+SIMPLE_BINARY_SCALAR_FUNC(Greater, >);
+SIMPLE_BINARY_SCALAR_FUNC(Geq, >=);
+SIMPLE_BINARY_SCALAR_FUNC(Equal, ==);
+SIMPLE_BINARY_SCALAR_FUNC(Neq, !=);
+
+template<typename TR, typename T1, typename T2>
+void scalarMod(TR* dst, const T1* src1, const T2* src2)
+{
+ *dst = *src1 % *src2;
+}
+
+template<>
+void scalarMod<float, float, float>(float* dst, const float* src1, const float* src2)
+{
+ *dst = fmodf(*src1, *src2);
+}
+
+template<>
+void scalarMod<double, double, double>(double* dst, const double* src1, const double* src2)
+{
+ *dst = fmod(*src1, *src2);
+}
+
+struct ModScalarFunc
+{
+ template<typename TR, typename T1, typename T2>
+ static void run(TR* dst, const T1* src1, const T2* src2)
+ {
+ scalarMod<TR, T1, T2>(dst, src1, src2);
+ }
+};
+
+#define SIMPLE_UNARY_SCALAR_FUNC(name, op) \
+ struct name##ScalarFunc \
+ { \
+ template<typename TR, typename T1> \
+ static void run(TR* dst, const T1* src1) \
+ { \
+ *dst = op(*src1); \
+ } \
+ }
+SIMPLE_UNARY_SCALAR_FUNC(Neg, -);
+SIMPLE_UNARY_SCALAR_FUNC(Not, !);
+SIMPLE_UNARY_SCALAR_FUNC(BitNot, ~);
+
+template<typename ScalarFunc, typename TR, typename T1, typename T2, int elementCount>
+struct BinaryVectorFunc
+{
+ static void run(IByteCodeRunner* context, VMExecInstHeader* inst, void* userData)
+ {
+ SLANG_UNUSED(context);
+ SLANG_UNUSED(userData);
+ TR* dst = (TR*)inst->getOperand(0).getPtr();
+ T1* src1 = (T1*)inst->getOperand(1).getPtr();
+ T2* src2 = (T2*)inst->getOperand(2).getPtr();
+ for (int i = 0; i < elementCount; ++i)
+ {
+ ScalarFunc::template run<TR, T1, T2>(&dst[i], &src1[i], &src2[i]);
+ }
+ }
+};
+
+template<typename ScalarFunc, typename TR, typename T1, typename T2>
+struct GeneralBinaryVectorFunc
+{
+ static void run(IByteCodeRunner* context, VMExecInstHeader* inst, void* userData)
+ {
+ SLANG_UNUSED(context);
+ SLANG_UNUSED(userData);
+ TR* dst = (TR*)inst->getOperand(0).getPtr();
+ T1* src1 = (T1*)inst->getOperand(1).getPtr();
+ T2* src2 = (T2*)inst->getOperand(2).getPtr();
+ ArithmeticExtCode arithExtCode;
+ memcpy(&arithExtCode, &inst->opcodeExtension, sizeof(arithExtCode));
+ for (uint32_t i = 0; i < arithExtCode.vectorSize; ++i)
+ {
+ ScalarFunc::template run<TR, T1, T2>(&dst[i], &src1[i], &src2[i]);
+ }
+ }
+};
+
+template<typename Func, typename TR, typename T1 = TR, typename T2 = TR>
+VMExtFunction binaryArithmeticInstHandler(int elementCount)
+{
+ switch (elementCount)
+ {
+ case 0:
+ case 1:
+ return BinaryVectorFunc<Func, TR, T1, T2, 1>::run;
+ case 2:
+ return BinaryVectorFunc<Func, TR, T1, T2, 2>::run;
+ case 3:
+ return BinaryVectorFunc<Func, TR, T1, T2, 3>::run;
+ case 4:
+ return BinaryVectorFunc<Func, TR, T1, T2, 4>::run;
+ case 6:
+ return BinaryVectorFunc<Func, TR, T1, T2, 6>::run;
+ case 8:
+ return BinaryVectorFunc<Func, TR, T1, T2, 8>::run;
+ case 9:
+ return BinaryVectorFunc<Func, TR, T1, T2, 9>::run;
+ case 10:
+ return BinaryVectorFunc<Func, TR, T1, T2, 10>::run;
+ case 12:
+ return BinaryVectorFunc<Func, TR, T1, T2, 12>::run;
+ case 16:
+ return BinaryVectorFunc<Func, TR, T1, T2, 16>::run;
+ default:
+ return GeneralBinaryVectorFunc<Func, TR, T1, T2>::run;
+ }
+}
+
+template<typename Func>
+VMExtFunction binaryArithmeticInstHandler(uint32_t extCode)
+{
+ ArithmeticExtCode arithExtCode;
+ memcpy(&arithExtCode, &extCode, sizeof(arithExtCode));
+ switch (arithExtCode.scalarType)
+ {
+ case kSlangByteCodeScalarTypeSignedInt:
+ switch (arithExtCode.scalarBitWidth)
+ {
+ case 0:
+ return binaryArithmeticInstHandler<Func, int8_t>(arithExtCode.vectorSize);
+ case 1:
+ return binaryArithmeticInstHandler<Func, int16_t>(arithExtCode.vectorSize);
+ case 2:
+ return binaryArithmeticInstHandler<Func, int32_t>(arithExtCode.vectorSize);
+ case 3:
+ return binaryArithmeticInstHandler<Func, int64_t>(arithExtCode.vectorSize);
+ }
+ case kSlangByteCodeScalarTypeUnsignedInt:
+ switch (arithExtCode.scalarBitWidth)
+ {
+ case 0:
+ return binaryArithmeticInstHandler<Func, uint8_t>(arithExtCode.vectorSize);
+ case 1:
+ return binaryArithmeticInstHandler<Func, uint16_t>(arithExtCode.vectorSize);
+ case 2:
+ return binaryArithmeticInstHandler<Func, uint32_t>(arithExtCode.vectorSize);
+ case 3:
+ return binaryArithmeticInstHandler<Func, uint64_t>(arithExtCode.vectorSize);
+ }
+ case kSlangByteCodeScalarTypeFloat:
+ switch (arithExtCode.scalarBitWidth)
+ {
+ case 2:
+ return binaryArithmeticInstHandler<Func, float>(arithExtCode.vectorSize);
+ case 3:
+ return binaryArithmeticInstHandler<Func, double>(arithExtCode.vectorSize);
+ default:
+ return nullptr; // Unsupported scalar bit width
+ }
+ }
+ return nullptr;
+}
+
+template<typename Func>
+VMExtFunction binaryArithmeticLogicalInstHandler(uint32_t extCode)
+{
+ ArithmeticExtCode arithExtCode;
+ memcpy(&arithExtCode, &extCode, sizeof(arithExtCode));
+ switch (arithExtCode.scalarBitWidth)
+ {
+ case 0:
+ return binaryArithmeticInstHandler<Func, uint8_t>(arithExtCode.vectorSize);
+ case 1:
+ return binaryArithmeticInstHandler<Func, uint16_t>(arithExtCode.vectorSize);
+ case 2:
+ return binaryArithmeticInstHandler<Func, uint32_t>(arithExtCode.vectorSize);
+ case 3:
+ return binaryArithmeticInstHandler<Func, uint64_t>(arithExtCode.vectorSize);
+ }
+ return nullptr;
+}
+
+template<typename Func>
+VMExtFunction binaryArithmeticIntInstHandler(uint32_t extCode)
+{
+ ArithmeticExtCode arithExtCode;
+ memcpy(&arithExtCode, &extCode, sizeof(arithExtCode));
+ switch (arithExtCode.scalarType)
+ {
+ case kSlangByteCodeScalarTypeSignedInt:
+ switch (arithExtCode.scalarBitWidth)
+ {
+ case 0:
+ return binaryArithmeticInstHandler<Func, int8_t>(arithExtCode.vectorSize);
+ case 1:
+ return binaryArithmeticInstHandler<Func, int16_t>(arithExtCode.vectorSize);
+ case 2:
+ return binaryArithmeticInstHandler<Func, int32_t>(arithExtCode.vectorSize);
+ case 3:
+ return binaryArithmeticInstHandler<Func, int64_t>(arithExtCode.vectorSize);
+ }
+ case kSlangByteCodeScalarTypeUnsignedInt:
+ switch (arithExtCode.scalarBitWidth)
+ {
+ case 0:
+ return binaryArithmeticInstHandler<Func, uint8_t>(arithExtCode.vectorSize);
+ case 1:
+ return binaryArithmeticInstHandler<Func, uint16_t>(arithExtCode.vectorSize);
+ case 2:
+ return binaryArithmeticInstHandler<Func, uint32_t>(arithExtCode.vectorSize);
+ case 3:
+ return binaryArithmeticInstHandler<Func, uint64_t>(arithExtCode.vectorSize);
+ }
+ }
+ return nullptr;
+}
+
+template<typename Func>
+VMExtFunction binaryArithmeticCompareInstHandler(uint32_t extCode)
+{
+ ArithmeticExtCode arithExtCode;
+ memcpy(&arithExtCode, &extCode, sizeof(arithExtCode));
+ switch (arithExtCode.scalarType)
+ {
+ case kSlangByteCodeScalarTypeSignedInt:
+ switch (arithExtCode.scalarBitWidth)
+ {
+ case 0:
+ return binaryArithmeticInstHandler<Func, uint32_t, int8_t, int8_t>(
+ arithExtCode.vectorSize);
+ case 1:
+ return binaryArithmeticInstHandler<Func, uint32_t, int16_t, int16_t>(
+ arithExtCode.vectorSize);
+ case 2:
+ return binaryArithmeticInstHandler<Func, uint32_t, int32_t, int32_t>(
+ arithExtCode.vectorSize);
+ case 3:
+ return binaryArithmeticInstHandler<Func, uint32_t, int64_t, int64_t>(
+ arithExtCode.vectorSize);
+ }
+ case kSlangByteCodeScalarTypeUnsignedInt:
+ switch (arithExtCode.scalarBitWidth)
+ {
+ case 0:
+ return binaryArithmeticInstHandler<Func, uint32_t, uint8_t, uint8_t>(
+ arithExtCode.vectorSize);
+ case 1:
+ return binaryArithmeticInstHandler<Func, uint32_t, uint16_t, uint16_t>(
+ arithExtCode.vectorSize);
+ case 2:
+ return binaryArithmeticInstHandler<Func, uint32_t, uint32_t, uint32_t>(
+ arithExtCode.vectorSize);
+ case 3:
+ return binaryArithmeticInstHandler<Func, uint32_t, uint64_t, uint64_t>(
+ arithExtCode.vectorSize);
+ }
+ case kSlangByteCodeScalarTypeFloat:
+ switch (arithExtCode.scalarBitWidth)
+ {
+ case 2:
+ return binaryArithmeticInstHandler<Func, uint32_t, float, float>(
+ arithExtCode.vectorSize);
+ case 3:
+ return binaryArithmeticInstHandler<Func, uint32_t, double, double>(
+ arithExtCode.vectorSize);
+ default:
+ return nullptr; // Unsupported scalar bit width
+ }
+ }
+ return nullptr;
+}
+
+////////
+template<typename ScalarFunc, typename TR, typename T1, int elementCount>
+struct UnaryVectorFunc
+{
+ static void run(IByteCodeRunner* context, VMExecInstHeader* inst, void* userData)
+ {
+ SLANG_UNUSED(context);
+ SLANG_UNUSED(userData);
+ TR* dst = (TR*)inst->getOperand(0).getPtr();
+ T1* src1 = (T1*)inst->getOperand(1).getPtr();
+ for (int i = 0; i < elementCount; ++i)
+ {
+ ScalarFunc::template run<TR, T1>(&dst[i], &src1[i]);
+ }
+ }
+};
+
+template<typename ScalarFunc, typename TR, typename T1>
+struct GeneralUnaryVectorFunc
+{
+ static void run(IByteCodeRunner* context, VMExecInstHeader* inst, void* userData)
+ {
+ SLANG_UNUSED(context);
+ SLANG_UNUSED(userData);
+ TR* dst = (TR*)inst->getOperand(0).getPtr();
+ T1* src1 = (T1*)inst->getOperand(1).getPtr();
+ ArithmeticExtCode arithExtCode;
+ memcpy(&arithExtCode, &inst->opcodeExtension, sizeof(arithExtCode));
+ for (uint32_t i = 0; i < arithExtCode.vectorSize; ++i)
+ {
+ ScalarFunc::template run<TR, T1>(&dst[i], &src1[i]);
+ }
+ }
+};
+
+template<typename Func, typename TR, typename T1 = TR>
+VMExtFunction unaryArithmeticInstHandler(int elementCount)
+{
+ switch (elementCount)
+ {
+ case 0:
+ case 1:
+ return UnaryVectorFunc<Func, TR, T1, 1>::run;
+ case 2:
+ return UnaryVectorFunc<Func, TR, T1, 2>::run;
+ case 3:
+ return UnaryVectorFunc<Func, TR, T1, 3>::run;
+ case 4:
+ return UnaryVectorFunc<Func, TR, T1, 4>::run;
+ case 6:
+ return UnaryVectorFunc<Func, TR, T1, 6>::run;
+ case 8:
+ return UnaryVectorFunc<Func, TR, T1, 8>::run;
+ case 9:
+ return UnaryVectorFunc<Func, TR, T1, 9>::run;
+ case 10:
+ return UnaryVectorFunc<Func, TR, T1, 10>::run;
+ case 12:
+ return UnaryVectorFunc<Func, TR, T1, 12>::run;
+ case 16:
+ return UnaryVectorFunc<Func, TR, T1, 16>::run;
+ default:
+ return GeneralUnaryVectorFunc<Func, TR, T1>::run;
+ }
+}
+
+template<typename Func>
+VMExtFunction unaryArithmeticLogicalInstHandler(uint32_t extCode)
+{
+ ArithmeticExtCode arithExtCode;
+ memcpy(&arithExtCode, &extCode, sizeof(arithExtCode));
+ switch (arithExtCode.scalarBitWidth)
+ {
+ case 0:
+ return unaryArithmeticInstHandler<Func, uint8_t>(arithExtCode.vectorSize);
+ case 1:
+ return unaryArithmeticInstHandler<Func, uint16_t>(arithExtCode.vectorSize);
+ case 2:
+ return unaryArithmeticInstHandler<Func, uint32_t>(arithExtCode.vectorSize);
+ case 3:
+ return unaryArithmeticInstHandler<Func, uint64_t>(arithExtCode.vectorSize);
+ }
+ return nullptr;
+}
+
+template<typename Func>
+VMExtFunction unaryArithmeticIntInstHandler(uint32_t extCode)
+{
+ ArithmeticExtCode arithExtCode;
+ memcpy(&arithExtCode, &extCode, sizeof(arithExtCode));
+ switch (arithExtCode.scalarType)
+ {
+ case kSlangByteCodeScalarTypeSignedInt:
+ switch (arithExtCode.scalarBitWidth)
+ {
+ case 0:
+ return unaryArithmeticInstHandler<Func, int8_t>(arithExtCode.vectorSize);
+ case 1:
+ return unaryArithmeticInstHandler<Func, int16_t>(arithExtCode.vectorSize);
+ case 2:
+ return unaryArithmeticInstHandler<Func, int32_t>(arithExtCode.vectorSize);
+ case 3:
+ return unaryArithmeticInstHandler<Func, int64_t>(arithExtCode.vectorSize);
+ }
+ case kSlangByteCodeScalarTypeUnsignedInt:
+ switch (arithExtCode.scalarBitWidth)
+ {
+ case 0:
+ return unaryArithmeticInstHandler<Func, uint8_t>(arithExtCode.vectorSize);
+ case 1:
+ return unaryArithmeticInstHandler<Func, uint16_t>(arithExtCode.vectorSize);
+ case 2:
+ return unaryArithmeticInstHandler<Func, uint32_t>(arithExtCode.vectorSize);
+ case 3:
+ return unaryArithmeticInstHandler<Func, uint64_t>(arithExtCode.vectorSize);
+ }
+ }
+ return nullptr;
+}
+
+template<typename Func>
+VMExtFunction negInstHandler(uint32_t extCode)
+{
+ ArithmeticExtCode arithExtCode;
+ memcpy(&arithExtCode, &extCode, sizeof(arithExtCode));
+ switch (arithExtCode.scalarType)
+ {
+ case kSlangByteCodeScalarTypeSignedInt:
+ case kSlangByteCodeScalarTypeUnsignedInt:
+ switch (arithExtCode.scalarBitWidth)
+ {
+ case 0:
+ return unaryArithmeticInstHandler<Func, int8_t>(arithExtCode.vectorSize);
+ case 1:
+ return unaryArithmeticInstHandler<Func, int16_t>(arithExtCode.vectorSize);
+ case 2:
+ return unaryArithmeticInstHandler<Func, int32_t>(arithExtCode.vectorSize);
+ case 3:
+ return unaryArithmeticInstHandler<Func, int64_t>(arithExtCode.vectorSize);
+ }
+ case kSlangByteCodeScalarTypeFloat:
+ switch (arithExtCode.scalarBitWidth)
+ {
+ case 2:
+ return unaryArithmeticInstHandler<Func, float>(arithExtCode.vectorSize);
+ case 3:
+ return unaryArithmeticInstHandler<Func, double>(arithExtCode.vectorSize);
+ default:
+ return nullptr; // Unsupported scalar bit width
+ }
+ }
+ return nullptr;
+}
+
+static void nopHandler(IByteCodeRunner*, VMExecInstHeader*, void*) {}
+
+void callHandler(IByteCodeRunner* inCtx, VMExecInstHeader* inst, void*)
+{
+ auto ctx = convert(inCtx);
+ auto funcId = inst->getOperand(1).offset;
+ auto& func = ctx->m_functions[funcId];
+ auto funcHeader = func.m_header;
+
+ // Alloc working set.
+ ctx->pushFrame(funcHeader->workingSetSizeInBytes);
+
+ // Save current instruction pointer.
+ auto& stackFrame = ctx->m_stack.getLast();
+ stackFrame.m_currentInst = inst;
+ stackFrame.m_currentFuncCode = ctx->m_currentFuncCode;
+ auto newWorkingSetPtr = (uint8_t*)ctx->m_currentWorkingSet;
+ auto callerWorkingSetPtr =
+ (uint8_t*)(ctx->m_workingSetBuffer.getBuffer() + stackFrame.m_workingSetOffset);
+
+ // Set working set pointer to the caller's working set.
+ ctx->m_currentWorkingSet = callerWorkingSetPtr;
+
+ // Copy arguments to the callee's working set.
+ for (uint32_t i = 0; i < funcHeader->parameterCount; ++i)
+ {
+ auto dst = newWorkingSetPtr + func.m_parameterOffsets[i];
+ auto src = (uint8_t*)inst->getOperand(i + 2).getPtr();
+
+ // func.m_parameterOffsets should be initialized to contain parameterCount+1 elements,
+ // where the last element is the total size of the parameters.
+ auto nextParamOffset = func.m_parameterOffsets[i + 1];
+ memcpy(dst, src, nextParamOffset - func.m_parameterOffsets[i]);
+ }
+ ctx->m_currentWorkingSet = newWorkingSetPtr;
+ ctx->m_currentFuncCode = func.m_codeBuffer.getBuffer();
+ ctx->m_currentInst = (VMExecInstHeader*)func.m_codeBuffer.getBuffer();
+}
+
+static void retHandler(IByteCodeRunner* inCtx, VMExecInstHeader* inst, void*)
+{
+ auto ctx = convert(inCtx);
+ if (inst->opcodeExtension != 0)
+ {
+ void* resultPtr = nullptr;
+ if (ctx->m_stack.getCount())
+ {
+ auto callInst = ctx->m_stack.getLast().m_currentInst;
+ auto callerWorkingSetPtr = (uint8_t*)(ctx->m_workingSetBuffer.getBuffer() +
+ ctx->m_stack.getLast().m_workingSetOffset);
+ resultPtr = callerWorkingSetPtr + callInst->getOperand(0).offset;
+ }
+ else
+ {
+ // If there is no stack frame, we assume the result is stored in the return register.
+ ctx->m_returnRegister.setCount(inst->opcodeExtension);
+ resultPtr = ctx->m_returnRegister.getBuffer();
+ ctx->m_returnValSize = inst->opcodeExtension;
+ }
+ memcpy(resultPtr, inst->getOperand(0).getPtr(), inst->opcodeExtension);
+ }
+
+ // If we are returning from a main function, there is nothing to pop from the stack frame,
+ // and we should stop execution.
+ if (ctx->m_stack.getCount() == 0)
+ {
+ ctx->m_currentInst = nullptr;
+ return;
+ }
+
+ // Pop the working set.
+ ctx->popFrame();
+}
+
+static void jumpHandler(IByteCodeRunner* inCtx, VMExecInstHeader* inst, void*)
+{
+ auto ctx = convert(inCtx);
+ ctx->m_currentInst = (VMExecInstHeader*)inst->getOperand(0).getPtr();
+}
+
+static void jumpIfHandler(IByteCodeRunner* inCtx, VMExecInstHeader* inst, void*)
+{
+ auto ctx = convert(inCtx);
+
+ auto cond = *(uint32_t*)inst->getOperand(0).getPtr();
+ if (cond)
+ {
+ ctx->m_currentInst = (VMExecInstHeader*)inst->getOperand(1).getPtr();
+ }
+ else
+ {
+ ctx->m_currentInst = (VMExecInstHeader*)inst->getOperand(2).getPtr();
+ }
+}
+
+static void getWorkingSetPtrHandler(IByteCodeRunner* inCtx, VMExecInstHeader* inst, void*)
+{
+ auto ctx = convert(inCtx);
+ auto dst = (void**)inst->getOperand(0).getPtr();
+ auto ptr = (uint8_t*)ctx->m_currentWorkingSet + inst->opcodeExtension;
+ *dst = ptr;
+}
+
+static void getElementPtrHandler(IByteCodeRunner* ctx, VMExecInstHeader* inst, void*)
+{
+ SLANG_UNUSED(ctx);
+ auto dst = (void**)inst->getOperand(0).getPtr();
+ auto basePtr = *(uint8_t**)inst->getOperand(1).getPtr();
+ auto elementIndex = *(uint32_t*)inst->getOperand(2).getPtr();
+ *dst = (uint8_t*)basePtr + elementIndex * inst->opcodeExtension;
+}
+
+static void getElementHandler(IByteCodeRunner* ctx, VMExecInstHeader* inst, void*)
+{
+ SLANG_UNUSED(ctx);
+ auto dst = (void*)inst->getOperand(0).getPtr();
+ auto basePtr = (uint8_t*)inst->getOperand(1).getPtr();
+ auto elementIndex = *(uint32_t*)inst->getOperand(2).getPtr();
+ memcpy(dst, basePtr + elementIndex * inst->opcodeExtension, inst->opcodeExtension);
+}
+
+static void offsetPtrHandler(IByteCodeRunner* ctx, VMExecInstHeader* inst, void*)
+{
+ SLANG_UNUSED(ctx);
+ auto dst = (void**)inst->getOperand(0).getPtr();
+ auto basePtr = *(uint8_t**)inst->getOperand(1).getPtr();
+ auto offset = *(int32_t*)inst->getOperand(2).getPtr();
+ *dst = basePtr + offset * inst->opcodeExtension;
+}
+
+void loadHandler8(IByteCodeRunner* ctx, VMExecInstHeader* inst, void*)
+{
+ SLANG_UNUSED(ctx);
+ auto dst = (uint8_t*)inst->getOperand(0).getPtr();
+ auto src = *(uint8_t**)inst->getOperand(1).getPtr();
+ *dst = *src;
+}
+void loadHandler16(IByteCodeRunner* ctx, VMExecInstHeader* inst, void*)
+{
+ SLANG_UNUSED(ctx);
+ auto dst = (uint16_t*)inst->getOperand(0).getPtr();
+ auto src = *(uint16_t**)inst->getOperand(1).getPtr();
+ *dst = *src;
+}
+void loadHandler32(IByteCodeRunner* ctx, VMExecInstHeader* inst, void*)
+{
+ SLANG_UNUSED(ctx);
+ auto dst = (uint32_t*)inst->getOperand(0).getPtr();
+ auto src = *(uint32_t**)inst->getOperand(1).getPtr();
+ *dst = *src;
+}
+void loadHandler64(IByteCodeRunner* ctx, VMExecInstHeader* inst, void*)
+{
+ SLANG_UNUSED(ctx);
+ auto dst = (uint64_t*)inst->getOperand(0).getPtr();
+ auto src = *(uint64_t**)inst->getOperand(1).getPtr();
+ *dst = *src;
+}
+
+void generalLoadHandler(IByteCodeRunner* ctx, VMExecInstHeader* inst, void*)
+{
+ SLANG_UNUSED(ctx);
+ auto dst = (uint8_t*)inst->getOperand(0).getPtr();
+ auto src = *(uint8_t**)inst->getOperand(1).getPtr();
+ memcpy(dst, src, inst->opcodeExtension);
+}
+
+VMExtFunction getLoadHandler(uint32_t extCode)
+{
+ switch (extCode)
+ {
+ case 1:
+ return loadHandler8;
+ case 2:
+ return loadHandler16;
+ case 4:
+ return loadHandler32;
+ case 8:
+ return loadHandler64;
+ default:
+ return generalLoadHandler;
+ }
+}
+
+void storeHandler8(IByteCodeRunner* ctx, VMExecInstHeader* inst, void*)
+{
+ SLANG_UNUSED(ctx);
+ auto dst = *(uint8_t**)inst->getOperand(0).getPtr();
+ auto src = (uint8_t*)inst->getOperand(1).getPtr();
+ *dst = *src;
+}
+
+void storeHandler16(IByteCodeRunner* ctx, VMExecInstHeader* inst, void*)
+{
+ SLANG_UNUSED(ctx);
+ auto dst = *(uint16_t**)inst->getOperand(0).getPtr();
+ auto src = (uint16_t*)inst->getOperand(1).getPtr();
+ *dst = *src;
+}
+
+void storeHandler32(IByteCodeRunner* ctx, VMExecInstHeader* inst, void*)
+{
+ SLANG_UNUSED(ctx);
+ auto dst = *(uint32_t**)inst->getOperand(0).getPtr();
+ auto src = (uint32_t*)inst->getOperand(1).getPtr();
+ *dst = *src;
+}
+
+void storeHandler64(IByteCodeRunner* ctx, VMExecInstHeader* inst, void*)
+{
+ SLANG_UNUSED(ctx);
+ auto dst = *(uint64_t**)inst->getOperand(0).getPtr();
+ auto src = (uint64_t*)inst->getOperand(1).getPtr();
+ *dst = *src;
+}
+
+void generalStoreHandler(IByteCodeRunner* ctx, VMExecInstHeader* inst, void*)
+{
+ SLANG_UNUSED(ctx);
+ auto dst = *(uint8_t**)inst->getOperand(0).getPtr();
+ auto src = (uint8_t*)inst->getOperand(1).getPtr();
+ memcpy(dst, src, inst->opcodeExtension);
+}
+
+VMExtFunction getStoreHandler(uint32_t extCode)
+{
+ switch (extCode)
+ {
+ case 1:
+ return storeHandler8;
+ case 2:
+ return storeHandler16;
+ case 4:
+ return storeHandler32;
+ case 8:
+ return storeHandler64;
+ default:
+ return generalStoreHandler;
+ }
+}
+
+void copyHandler8(IByteCodeRunner* ctx, VMExecInstHeader* inst, void*)
+{
+ SLANG_UNUSED(ctx);
+ auto dst = (uint8_t*)inst->getOperand(0).getPtr();
+ auto src = (uint8_t*)inst->getOperand(1).getPtr();
+ *dst = *src;
+}
+
+void copyHandler16(IByteCodeRunner* ctx, VMExecInstHeader* inst, void*)
+{
+ SLANG_UNUSED(ctx);
+ auto dst = (uint16_t*)inst->getOperand(0).getPtr();
+ auto src = (uint16_t*)inst->getOperand(1).getPtr();
+ *dst = *src;
+}
+
+void copyHandler32(IByteCodeRunner* ctx, VMExecInstHeader* inst, void*)
+{
+ SLANG_UNUSED(ctx);
+ auto dst = (uint32_t*)inst->getOperand(0).getPtr();
+ auto src = (uint32_t*)inst->getOperand(1).getPtr();
+ *dst = *src;
+}
+
+void copyHandler64(IByteCodeRunner* ctx, VMExecInstHeader* inst, void*)
+{
+ SLANG_UNUSED(ctx);
+ auto dst = (uint64_t*)inst->getOperand(0).getPtr();
+ auto src = (uint64_t*)inst->getOperand(1).getPtr();
+ *dst = *src;
+}
+
+void generalCopyHandler(IByteCodeRunner* ctx, VMExecInstHeader* inst, void*)
+{
+ SLANG_UNUSED(ctx);
+ auto dst = (uint8_t*)inst->getOperand(0).getPtr();
+ auto src = (uint8_t*)inst->getOperand(1).getPtr();
+ memcpy(dst, src, inst->opcodeExtension);
+}
+
+VMExtFunction getCopyHandler(uint32_t extCode)
+{
+ switch (extCode)
+ {
+ case 1:
+ return copyHandler8;
+ case 2:
+ return copyHandler16;
+ case 4:
+ return copyHandler32;
+ case 8:
+ return copyHandler64;
+ default:
+ return generalCopyHandler;
+ }
+}
+
+template<typename T>
+void swizzleHandler(IByteCodeRunner* ctx, VMExecInstHeader* inst, void* userData)
+{
+ SLANG_UNUSED(ctx);
+ SLANG_UNUSED(userData);
+ auto dst = (T*)inst->getOperand(0).getPtr();
+ auto src = (T*)inst->getOperand(1).getPtr();
+ for (uint32_t i = 2; i < inst->operandCount; ++i)
+ {
+ dst[i - 2] = src[inst->getOperand(i).offset];
+ }
+}
+
+VMExtFunction getSwizzleHandler(uint32_t extCode)
+{
+ ArithmeticExtCode arithExtCode;
+ memcpy(&arithExtCode, &extCode, sizeof(arithExtCode));
+ switch (arithExtCode.scalarType)
+ {
+ case kSlangByteCodeScalarTypeSignedInt:
+ switch (arithExtCode.scalarBitWidth)
+ {
+ case 0:
+ return swizzleHandler<int8_t>;
+ case 1:
+ return swizzleHandler<int16_t>;
+ case 2:
+ return swizzleHandler<int32_t>;
+ case 3:
+ return swizzleHandler<int64_t>;
+ }
+ case kSlangByteCodeScalarTypeUnsignedInt:
+ switch (arithExtCode.scalarBitWidth)
+ {
+ case 0:
+ return swizzleHandler<uint8_t>;
+ case 1:
+ return swizzleHandler<uint16_t>;
+ case 2:
+ return swizzleHandler<uint32_t>;
+ case 3:
+ return swizzleHandler<uint64_t>;
+ }
+ case kSlangByteCodeScalarTypeFloat:
+ switch (arithExtCode.scalarBitWidth)
+ {
+ case 2:
+ return swizzleHandler<float>;
+ case 3:
+ return swizzleHandler<double>;
+ default:
+ return nullptr; // Unsupported scalar bit width
+ }
+ }
+ return nullptr;
+}
+
+template<typename To, typename From, int vectorSize>
+void castHandler(IByteCodeRunner* ctx, VMExecInstHeader* inst, void*)
+{
+ SLANG_UNUSED(ctx);
+ To* dst = (To*)inst->getOperand(0).getPtr();
+ From* src = (From*)inst->getOperand(1).getPtr();
+ for (int i = 0; i < vectorSize; ++i)
+ {
+ dst[i] = static_cast<To>(src[i]);
+ }
+}
+
+template<typename From, int vectorSize>
+VMExtFunction getCastHandler(uint32_t extCode)
+{
+ ArithmeticExtCode arithExtCode;
+ memcpy(&arithExtCode, &extCode, sizeof(arithExtCode));
+ switch (arithExtCode.scalarType)
+ {
+ case kSlangByteCodeScalarTypeSignedInt:
+ switch (arithExtCode.scalarBitWidth)
+ {
+ case 0:
+ return castHandler<uint8_t, From, vectorSize>;
+ case 1:
+ return castHandler<uint16_t, From, vectorSize>;
+ case 2:
+ return castHandler<uint32_t, From, vectorSize>;
+ case 3:
+ return castHandler<uint64_t, From, vectorSize>;
+ }
+ case kSlangByteCodeScalarTypeUnsignedInt:
+ switch (arithExtCode.scalarBitWidth)
+ {
+ case 0:
+ return castHandler<uint8_t, From, vectorSize>;
+ case 1:
+ return castHandler<uint16_t, From, vectorSize>;
+ case 2:
+ return castHandler<uint32_t, From, vectorSize>;
+ case 3:
+ return castHandler<uint64_t, From, vectorSize>;
+ }
+ case kSlangByteCodeScalarTypeFloat:
+ switch (arithExtCode.scalarBitWidth)
+ {
+ case 2:
+ return castHandler<float, From, vectorSize>;
+ case 3:
+ return castHandler<double, From, vectorSize>;
+ default:
+ return nullptr; // Unsupported scalar bit width
+ }
+ }
+ return nullptr;
+}
+
+template<int vectorSize>
+VMExtFunction getCastHandler(uint32_t extCode)
+{
+ uint32_t fromExtCode = extCode >> 16;
+ ArithmeticExtCode arithExtCode;
+ memcpy(&arithExtCode, &fromExtCode, sizeof(arithExtCode));
+ switch (arithExtCode.scalarType)
+ {
+ case kSlangByteCodeScalarTypeSignedInt:
+ switch (arithExtCode.scalarBitWidth)
+ {
+ case 0:
+ return getCastHandler<uint8_t, vectorSize>(extCode);
+ case 1:
+ return getCastHandler<uint16_t, vectorSize>(extCode);
+ case 2:
+ return getCastHandler<uint32_t, vectorSize>(extCode);
+ case 3:
+ return getCastHandler<uint64_t, vectorSize>(extCode);
+ }
+ case kSlangByteCodeScalarTypeUnsignedInt:
+ switch (arithExtCode.scalarBitWidth)
+ {
+ case 0:
+ return getCastHandler<uint8_t, vectorSize>(extCode);
+ case 1:
+ return getCastHandler<uint16_t, vectorSize>(extCode);
+ case 2:
+ return getCastHandler<uint32_t, vectorSize>(extCode);
+ case 3:
+ return getCastHandler<uint64_t, vectorSize>(extCode);
+ }
+ case kSlangByteCodeScalarTypeFloat:
+ switch (arithExtCode.scalarBitWidth)
+ {
+ case 2:
+ return getCastHandler<float, vectorSize>(extCode);
+ case 3:
+ return getCastHandler<double, vectorSize>(extCode);
+ default:
+ return nullptr; // Unsupported scalar bit width
+ }
+ }
+ return nullptr;
+}
+
+VMExtFunction getCastHandler(uint32_t extCode)
+{
+ uint32_t fromExtCode = extCode >> 16;
+ ArithmeticExtCode arithExtCode;
+ memcpy(&arithExtCode, &fromExtCode, sizeof(arithExtCode));
+ switch (arithExtCode.vectorSize)
+ {
+ case 0:
+ case 1:
+ return getCastHandler<1>(extCode);
+ case 2:
+ return getCastHandler<2>(extCode);
+ case 3:
+ return getCastHandler<3>(extCode);
+ case 4:
+ return getCastHandler<4>(extCode);
+ case 6:
+ return getCastHandler<6>(extCode);
+ case 8:
+ return getCastHandler<8>(extCode);
+ case 9:
+ return getCastHandler<9>(extCode);
+ case 12:
+ return getCastHandler<12>(extCode);
+ case 16:
+ return getCastHandler<16>(extCode);
+ }
+ return nullptr;
+}
+
+void printHandler(IByteCodeRunner* inCtx, VMExecInstHeader* inst, void* userData)
+{
+ auto ctx = convert(inCtx);
+ SLANG_UNUSED(userData);
+ const char* formatString = nullptr;
+ formatString = *(const char**)inst->getOperand(0).getPtr();
+
+ List<List<uint8_t>> args;
+ List<const void*> argPtrs;
+ for (uint32_t i = 1; i < inst->operandCount; ++i)
+ {
+ auto& arg = inst->getOperand(i);
+ List<uint8_t> data;
+ data.setCount(arg.size);
+ memcpy(data.getBuffer(), arg.getPtr(), arg.size);
+ args.add(data);
+ }
+ for (auto& arg : args)
+ {
+ argPtrs.add(arg.getBuffer());
+ }
+ auto result =
+ StringUtil::makeStringWithFormatFromArgArray(formatString, argPtrs.getArrayView());
+ ctx->m_printCallback(result.getBuffer(), ctx->m_printCallbackUserData);
+}
+
+
+VMExtFunction mapInstToFunction(
+ VMInstHeader* instHeader,
+ VMModuleView* module,
+ Dictionary<String, slang::VMExtFunction>& extInstHandlers)
+{
+ switch (instHeader->opcode)
+ {
+ case VMOp::Nop:
+ return nopHandler;
+ case VMOp::Add:
+ return binaryArithmeticInstHandler<AddScalarFunc>(instHeader->opcodeExtension);
+ case VMOp::Sub:
+ return binaryArithmeticInstHandler<SubScalarFunc>(instHeader->opcodeExtension);
+ case VMOp::Mul:
+ return binaryArithmeticInstHandler<MulScalarFunc>(instHeader->opcodeExtension);
+ case VMOp::Div:
+ return binaryArithmeticInstHandler<DivScalarFunc>(instHeader->opcodeExtension);
+ case VMOp::Rem:
+ return binaryArithmeticInstHandler<ModScalarFunc>(instHeader->opcodeExtension);
+ case VMOp::And:
+ return binaryArithmeticLogicalInstHandler<AndScalarFunc>(instHeader->opcodeExtension);
+ case VMOp::Or:
+ return binaryArithmeticLogicalInstHandler<OrScalarFunc>(instHeader->opcodeExtension);
+ case VMOp::BitAnd:
+ return binaryArithmeticLogicalInstHandler<BitAndScalarFunc>(instHeader->opcodeExtension);
+ case VMOp::BitOr:
+ return binaryArithmeticLogicalInstHandler<BitOrScalarFunc>(instHeader->opcodeExtension);
+ case VMOp::BitXor:
+ return binaryArithmeticLogicalInstHandler<BitXorScalarFunc>(instHeader->opcodeExtension);
+ case VMOp::Shl:
+ return binaryArithmeticIntInstHandler<ShlScalarFunc>(instHeader->opcodeExtension);
+ case VMOp::Shr:
+ return binaryArithmeticIntInstHandler<ShrScalarFunc>(instHeader->opcodeExtension);
+ case VMOp::Less:
+ return binaryArithmeticCompareInstHandler<LessScalarFunc>(instHeader->opcodeExtension);
+ case VMOp::Leq:
+ return binaryArithmeticCompareInstHandler<LeqScalarFunc>(instHeader->opcodeExtension);
+ case VMOp::Greater:
+ return binaryArithmeticCompareInstHandler<GreaterScalarFunc>(instHeader->opcodeExtension);
+ case VMOp::Geq:
+ return binaryArithmeticCompareInstHandler<GeqScalarFunc>(instHeader->opcodeExtension);
+ case VMOp::Equal:
+ return binaryArithmeticCompareInstHandler<EqualScalarFunc>(instHeader->opcodeExtension);
+ case VMOp::Neq:
+ return binaryArithmeticCompareInstHandler<NeqScalarFunc>(instHeader->opcodeExtension);
+ case VMOp::Neg:
+ return negInstHandler<NegScalarFunc>(instHeader->opcodeExtension);
+ case VMOp::Not:
+ return unaryArithmeticLogicalInstHandler<NotScalarFunc>(instHeader->opcodeExtension);
+ case VMOp::BitNot:
+ return unaryArithmeticIntInstHandler<BitNotScalarFunc>(instHeader->opcodeExtension);
+ case VMOp::Ret:
+ return retHandler;
+ case VMOp::Call:
+ return callHandler;
+ case VMOp::Jump:
+ return jumpHandler;
+ case VMOp::JumpIf:
+ return jumpIfHandler;
+ case VMOp::Load:
+ return getLoadHandler(instHeader->opcodeExtension);
+ case VMOp::Store:
+ return getStoreHandler(instHeader->opcodeExtension);
+ case VMOp::Copy:
+ return getCopyHandler(instHeader->opcodeExtension);
+ case VMOp::GetWorkingSetPtr:
+ return getWorkingSetPtrHandler;
+ case VMOp::GetElementPtr:
+ return getElementPtrHandler;
+ case VMOp::OffsetPtr:
+ return offsetPtrHandler;
+ case VMOp::GetElement:
+ return getElementHandler;
+ case VMOp::Swizzle:
+ return getSwizzleHandler(instHeader->opcodeExtension);
+ case VMOp::Cast:
+ return getCastHandler(instHeader->opcodeExtension);
+ case VMOp::CallExt:
+ {
+ if (instHeader->getOperand(0).offset >= module->stringCount)
+ return nullptr;
+ auto funcName = (const char*)module->constants +
+ module->stringOffsets[instHeader->getOperand(0).offset];
+ VMExtFunction handler = nullptr;
+ if (!extInstHandlers.tryGetValue(funcName, handler))
+ return nullptr;
+ return handler;
+ }
+ case VMOp::Print:
+ return printHandler;
+ }
+ return VMExtFunction();
+}
+
+} // namespace Slang