diff options
| author | Yong He <yonghe@outlook.com> | 2025-04-28 11:42:22 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-04-28 11:42:22 -0700 |
| commit | c39c29bf4c52a85d7c83cc8b66ae45e265f9e078 (patch) | |
| tree | 969339828d49d7db92ed9294a17bd34cc021db84 /source | |
| parent | 8f6c6e333c06ae1c3b9f00396563c14a2ae09b4d (diff) | |
Add Slang Byte Code generation and interpreter. (#6896)
* Add Slang Byte Code generation and interpreter.
* Fix compile issues.
* format code
* More compile fix.
* Fix clang issue.
* Fix more clang issues.
* Another clang fix.
* Fix clang issues.
* Fix another clang issue.
* Fix wasm build.
* Update building.md
* Fix test-server.
* Fix compile error.
* Fix bug.
---------
Co-authored-by: slangbot <186143334+slangbot@users.noreply.github.com>
Diffstat (limited to 'source')
29 files changed, 3779 insertions, 18 deletions
diff --git a/source/compiler-core/slang-artifact-desc-util.cpp b/source/compiler-core/slang-artifact-desc-util.cpp
index c1575ba16..5f06853ec 100644
--- a/source/compiler-core/slang-artifact-desc-util.cpp
+++ b/source/compiler-core/slang-artifact-desc-util.cpp
@@ -325,6 +325,9 @@ SLANG_HIERARCHICAL_ENUM(ArtifactStyle, SLANG_ARTIFACT_STYLE, SLANG_ARTIFACT_STYL
         return Desc::make(Kind::Assembly, Payload::WGSL_SPIRV, Style::Kernel, 0);
     case SLANG_WGSL_SPIRV:
         return Desc::make(Kind::ObjectCode, Payload::WGSL_SPIRV, Style::Kernel, 0);
+
+    case SLANG_HOST_VM:
+        return Desc::make(Kind::ObjectCode, Payload::UniversalCPU, Style::Host, 0);
     default:
         break;
     }
diff --git a/source/core/slang-string-escape-util.cpp b/source/core/slang-string-escape-util.cpp
index 42e757b23..0645d94ba 100644
--- a/source/core/slang-string-escape-util.cpp
+++ b/source/core/slang-string-escape-util.cpp
@@ -1184,4 +1184,19 @@ StringEscapeUtil::Handler* StringEscapeUtil::getHandler(Style style)
     return SLANG_OK;
 }
 
+String StringEscapeUtil::escapeString(UnownedStringSlice input, StringEscapeUtil::Style style)
+{
+    StringBuilder sb;
+    auto handler = StringEscapeUtil::getHandler(style);
+    StringEscapeUtil::appendQuoted(handler, input, sb);
+    return sb.produceString();
+}
+
+String StringEscapeUtil::unescapeString(UnownedStringSlice input, StringEscapeUtil::Style style)
+{
+    StringBuilder sb;
+    auto handler = StringEscapeUtil::getHandler(style);
+    StringEscapeUtil::appendUnquoted(handler, input, sb);
+    return sb.produceString();
+}
 } // namespace Slang
diff --git a/source/core/slang-string-escape-util.h b/source/core/slang-string-escape-util.h
index 6e02d772d..ece8de79f 100644
--- a/source/core/slang-string-escape-util.h
+++ b/source/core/slang-string-escape-util.h
@@ -113,6 +113,9 @@ struct StringEscapeUtil
         Handler* handler,
         const UnownedStringSlice& slice,
         StringBuilder& out);
+
+    static String escapeString(UnownedStringSlice input, Style style = Style::Slang);
+    static String unescapeString(UnownedStringSlice input, Style style = Style::Slang);
 };
diff --git a/source/core/slang-string-util.cpp b/source/core/slang-string-util.cpp
index ba234e30b..65758fd14 100644
--- a/source/core/slang-string-util.cpp
+++ b/source/core/slang-string-util.cpp
@@ -390,6 +390,172 @@ UnownedStringSlice StringUtil::getAtInSplit(
     return builder;
 }
 
+/// Reads the next argument as a `T` by copying `sizeof(T)` bytes out of the
+/// storage `ptrToArgs[argIndex]` points at, then advances `argIndex`.
+/// Returns a value-initialized `T` when no arguments remain.
+template<typename T>
+static T readValue(ArrayView<const void*> ptrToArgs, Count& argIndex)
+{
+    if (argIndex < ptrToArgs.getCount())
+    {
+        T value;
+        memcpy(&value, ptrToArgs[argIndex], sizeof(T));
+        argIndex++;
+        return value;
+    }
+    return T();
+}
+
+/// Formats `format` printf-style, taking each argument from `ptrToArgs` rather
+/// than from C varargs: entry i points at the storage of the i-th argument.
+/// `%s` entries point at a `const char*`; integer conversions read an `int`
+/// (`int64_t` after `ll`); floating-point conversions read a `float` (`double`
+/// after an `l`/`L` modifier). Numeric conversions with no argument left format
+/// a zero value; `%s` with no argument left (or a null string) appends nothing.
+/// Returns an empty string when `format` is null.
+String StringUtil::makeStringWithFormatFromArgArray(
+    const char* format,
+    ArrayView<const void*> ptrToArgs)
+{
+    if (!format)
+    {
+        return String();
+    }
+    StringBuilder builder;
+    const char* ptr = format;
+    Count argIndex = 0;
+    auto consumeString = [&]()
+    {
+        if (argIndex < ptrToArgs.getCount())
+        {
+            const char* strPtr = *(const char**)ptrToArgs[argIndex];
+            argIndex++;
+            if (strPtr)
+            {
+                // Append the string to the builder
+                builder.append(strPtr);
+            }
+        }
+    };
+#define ADVANCE_PTR                     \
+    ptr++;                              \
+    if (!*ptr)                          \
+    {                                   \
+        return builder.produceString(); \
+    }
+
+    while (*ptr)
+    {
+        if (*ptr == '%')
+        {
+            const char* formatStart = ptr;
+            ADVANCE_PTR;
+            if (*ptr == 's')
+            {
+                // If we have a %s, then we want to append the data
+                consumeString();
+                // Move past the 's'
+                ADVANCE_PTR;
+                continue;
+            }
+            if (*ptr == '-')
+            {
+                // If we have a %- then we want to continue parsing format string.
+                ADVANCE_PTR;
+            }
+            while (CharUtil::isDigit(*ptr))
+            {
+                // Skip the field-width digits
+                ADVANCE_PTR;
+            }
+            if (*ptr == '.')
+            {
+                ADVANCE_PTR;
+                while (CharUtil::isDigit(*ptr))
+                {
+                    // Skip the precision digits after the '.'
+                    ADVANCE_PTR;
+                }
+            }
+            int isLong = 0;
+            if (*ptr == 'l' || *ptr == 'L')
+            {
+                // If we have a 'l' or 'L', then we want to skip it.
+                ADVANCE_PTR;
+                isLong = 1;
+                if (*ptr == 'l' || *ptr == 'L')
+                {
+                    // If we have another 'l' or 'L', then we want to skip it too.
+                    ADVANCE_PTR;
+                    isLong = 2;
+                }
+            }
+            const char typeChar = *ptr;
+            // Step past the conversion character with a plain increment: using
+            // ADVANCE_PTR here would return early when the specifier is the last
+            // thing in the format string, silently dropping that conversion.
+            ptr++;
+            String formatStr = UnownedStringSlice(formatStart, ptr);
+            switch (CharUtil::toLower(typeChar))
+            {
+            case 'd':
+            case 'x':
+            case 'i':
+            case 'u':
+            case 'o':
+            case 'c':
+                if (isLong == 2)
+                {
+                    StringUtil::appendFormat(
+                        builder,
+                        formatStr.getBuffer(),
+                        readValue<int64_t>(ptrToArgs, argIndex));
+                }
+                else
+                {
+                    StringUtil::appendFormat(
+                        builder,
+                        formatStr.getBuffer(),
+                        readValue<int>(ptrToArgs, argIndex));
+                }
+                break;
+            case 'e':
+            case 'f':
+            case 'g':
+                if (isLong != 0)
+                {
+                    StringUtil::appendFormat(
+                        builder,
+                        formatStr.getBuffer(),
+                        readValue<double>(ptrToArgs, argIndex));
+                }
+                else
+                {
+                    StringUtil::appendFormat(
+                        builder,
+                        formatStr.getBuffer(),
+                        readValue<float>(ptrToArgs, argIndex));
+                }
+                break;
+            case 'n':
+                break;
+            case '%':
+                // If we have a '%%' then we want to append a single '%'
+                builder.appendChar('%');
+                continue;
+            }
+        }
+        else
+        {
+            // Just append the character
+            builder.appendChar(*ptr);
+            ptr++;
+        }
+    }
+    return builder.produceString();
+}
+
+
 /* static */ UnownedStringSlice StringUtil::getSlice(ISlangBlob* blob)
 {
     if (blob)
diff --git a/source/core/slang-string-util.h b/source/core/slang-string-util.h
index 9e7b2d65a..4f4368ba3 100644
--- a/source/core/slang-string-util.h
+++ b/source/core/slang-string-util.h
@@ -135,6 +135,11 @@ struct StringUtil
     /// Create a string from the format string applying args (like sprintf)
     static String makeStringWithFormat(const char* format, ...);
 
+    /// Create a string from the format string and arguments in a buffer.
+    static String makeStringWithFormatFromArgArray(
+        const char* format,
+        ArrayView<const void*> ptrToArgs);
+
     /// Given a string held in a blob, returns as a String
     /// Returns an empty string if blob is nullptr
     static String getString(ISlangBlob* blob);
diff --git a/source/core/slang-type-text-util.cpp b/source/core/slang-type-text-util.cpp
index 01ce0d1ac..2261ed614 100644
--- a/source/core/slang-type-text-util.cpp
+++ b/source/core/slang-type-text-util.cpp
@@ -87,6 +87,7 @@ static const TypeTextUtil::CompileTargetInfo s_compileTargetInfos[] = {
      "wgsl-spirv-asm,wgsl-spirv-assembly",
      "SPIR-V assembly via WebGPU shading language"},
     {SLANG_WGSL_SPIRV, "wgsl-spirv", "wgsl-spirv", "SPIR-V via WebGPU shading language"},
+    {SLANG_HOST_VM, "slang-vm", "slangvm,slang-vm", "Slang VM byte code"},
 };
 
 static const NamesDescriptionValue s_languageInfos[] = {
diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang
index 03321bfaf..44b9a8860 100644
--- a/source/slang/hlsl.meta.slang
+++ b/source/slang/hlsl.meta.slang
@@ -12155,6 +12155,7 @@ vector<T, N> powr(vector<T, N> x, vector<T, N> y)
 /// }
 /// ```
 [require(cpp_cuda_glsl_hlsl_spirv, printf)]
+[require(slangvm)]
 __intrinsic_op($(kIROp_Printf))
 void printf<each T>(NativeString format, expand each T args);
 
diff --git a/source/slang/slang-capabilities.capdef b/source/slang/slang-capabilities.capdef
index f4ae94978..8f6aa254c 100644
--- a/source/slang/slang-capabilities.capdef
+++ b/source/slang/slang-capabilities.capdef
@@ -115,6 +115,10 @@ def spirv : target;
 /// [Target]
 def wgsl : target + textualTarget;
 
+/// Represents the Slang VM bytecode target.
+/// [Target]
+def slangvm : target;
+
 // Capabilities that stand for target SPIR-V versions for the GLSL backend.
 // These are not compilation targets. We will convert `_spirv_*` to `glsl_spirv_*` during compilation.
 
@@ -428,6 +432,10 @@ def domain : stage; /// [Stage] def geometry : stage; +/// Dispatch shader stage +/// [Stage] +def dispatch : stage; + def _raygen : stage; alias _raygeneration = _raygen; def _intersection : stage; diff --git a/source/slang/slang-check-shader.cpp b/source/slang/slang-check-shader.cpp index a8c3aee15..4547281e1 100644 --- a/source/slang/slang-check-shader.cpp +++ b/source/slang/slang-check-shader.cpp @@ -448,6 +448,7 @@ void validateEntryPoint(EntryPoint* entryPoint, DiagnosticSink* sink) } bool canHaveVaryingInput = false; + bool shouldWarnOnNonUniformParam = true; switch (stage) { case Stage::Vertex: @@ -462,6 +463,9 @@ void validateEntryPoint(EntryPoint* entryPoint, DiagnosticSink* sink) case Stage::Domain: canHaveVaryingInput = true; break; + case Stage::Dispatch: + shouldWarnOnNonUniformParam = false; + break; default: break; } @@ -499,10 +503,13 @@ void validateEntryPoint(EntryPoint* entryPoint, DiagnosticSink* sink) // support varying input/output. We will automatically convert it to a 'uniform' parameter, // and diagnose a warning. 
addModifier(param, getCurrentASTBuilder()->create<HLSLUniformModifier>()); - sink->diagnose( - param, - Diagnostics::nonUniformEntryPointParameterTreatedAsUniform, - param->getName()); + if (shouldWarnOnNonUniformParam) + { + sink->diagnose( + param, + Diagnostics::nonUniformEntryPointParameterTreatedAsUniform, + param->getName()); + } } for (auto target : linkage->targets) diff --git a/source/slang/slang-compiler.cpp b/source/slang/slang-compiler.cpp index 8e9b8f430..15f22630c 100644 --- a/source/slang/slang-compiler.cpp +++ b/source/slang/slang-compiler.cpp @@ -468,6 +468,8 @@ Stage getStageFromAtom(CapabilityAtom atom) return Stage::Miss; case CapabilityAtom::_callable: return Stage::Callable; + case CapabilityAtom::dispatch: + return Stage::Dispatch; default: SLANG_UNEXPECTED("unknown stage atom"); UNREACHABLE_RETURN(Stage::Unknown); @@ -1766,6 +1768,8 @@ SlangResult emitSPIRVForEntryPointsDirectly( CodeGenContext* codeGenContext, ComPtr<IArtifact>& outArtifact); +SlangResult emitHostVMCode(CodeGenContext* codeGenContext, ComPtr<IArtifact>& outArtifact); + static CodeGenTarget _getIntermediateTarget(CodeGenTarget target) { switch (target) @@ -1835,7 +1839,9 @@ SlangResult CodeGenContext::_emitEntryPoints(ComPtr<IArtifact>& outArtifact) case CodeGenTarget::WGSLSPIRV: SLANG_RETURN_ON_FAIL(emitWithDownstreamForEntryPoints(outArtifact)); return SLANG_OK; - + case CodeGenTarget::HostVM: + SLANG_RETURN_ON_FAIL(emitHostVMCode(this, outArtifact)); + return SLANG_OK; default: break; } @@ -1887,6 +1893,7 @@ SlangResult CodeGenContext::emitEntryPoints(ComPtr<IArtifact>& outArtifact) case CodeGenTarget::HostExecutable: case CodeGenTarget::HostSharedLibrary: case CodeGenTarget::WGSLSPIRVAssembly: + case CodeGenTarget::HostVM: { SLANG_RETURN_ON_FAIL(_emitEntryPoints(outArtifact)); diff --git a/source/slang/slang-compiler.h b/source/slang/slang-compiler.h index bfae6e400..27c368738 100644 --- a/source/slang/slang-compiler.h +++ b/source/slang/slang-compiler.h @@ -92,6 +92,7 
@@ enum class CodeGenTarget : SlangCompileTargetIntegral WGSL = SLANG_WGSL, WGSLSPIRVAssembly = SLANG_WGSL_SPIRV_ASM, WGSLSPIRV = SLANG_WGSL_SPIRV, + HostVM = SLANG_HOST_VM, CountOf = SLANG_TARGET_COUNT_OF, }; @@ -1492,7 +1493,7 @@ public: { return SLANG_E_INVALID_ARG; } - + SLANG_AST_BUILDER_RAII(m_astBuilder); ComPtr<slang::IEntryPoint> entryPoint(findEntryPointByName(UnownedStringSlice(name))); if ((!entryPoint)) return SLANG_FAIL; @@ -1511,7 +1512,6 @@ public: { return SLANG_E_INVALID_ARG; } - ComPtr<slang::IEntryPoint> entryPoint( findAndCheckEntryPoint(UnownedStringSlice(name), stage, outDiagnostics)); if ((!entryPoint)) diff --git a/source/slang/slang-emit-slang.cpp b/source/slang/slang-emit-slang.cpp new file mode 100644 index 000000000..175f7ffc5 --- /dev/null +++ b/source/slang/slang-emit-slang.cpp @@ -0,0 +1,17 @@ +#include "slang-emit-slang.h" + +namespace Slang +{ + +SlangResult emitSlangDeclarationsForEntryPoints( + CodeGenContext* codeGenContext, + LinkedIR& linkedIR, + String& outSlangCode) +{ + SLANG_UNUSED(codeGenContext); + SLANG_UNUSED(linkedIR); + SLANG_UNUSED(outSlangCode); + return SLANG_OK; +} + +} // namespace Slang diff --git a/source/slang/slang-emit-slang.h b/source/slang/slang-emit-slang.h new file mode 100644 index 000000000..964e1e52e --- /dev/null +++ b/source/slang/slang-emit-slang.h @@ -0,0 +1,16 @@ +#ifndef SLANG_EMIT_SLANG_H +#define SLANG_EMIT_SLANG_H + +#include "slang-emit-base.h" +#include "slang-ir-link.h" +#include "slang-vm-bytecode.h" + +namespace Slang +{ +SlangResult emitSlangDeclarationsForEntryPoints( + CodeGenContext* codeGenContext, + LinkedIR& linkedIR, + String& outSlangDeclaration); +} + +#endif diff --git a/source/slang/slang-emit-vm.cpp b/source/slang/slang-emit-vm.cpp new file mode 100644 index 000000000..fc0b4432e --- /dev/null +++ b/source/slang/slang-emit-vm.cpp @@ -0,0 +1,1266 @@ +#include "slang-emit-vm.h" + +#include "slang-ir-call-graph.h" +#include "slang-ir-layout.h" +#include "slang-ir-util.h" + 
+using namespace slang; + +namespace Slang +{ +class ByteCodeEmitter +{ +public: + Dictionary<IRInst*, String> mapInstToName; + Dictionary<String, int> mapNameToUniqueId; + Dictionary<IRInst*, VMOperand> mapInstToOperand; + Dictionary<UnownedStringSlice, VMOperand> mapStringToOperand; + struct ConstKey + { + uint64_t value; + uint32_t size; + bool operator==(const ConstKey& other) const + { + return value == other.value && size == other.size; + } + bool operator!=(const ConstKey& other) const { return !(*this == other); } + HashCode getHashCode() const { return combineHash(value, size); } + }; + Dictionary<ConstKey, VMOperand> mapConstantIntToOperand; + Dictionary<IRFunc*, int> mapFuncToId; + + VMByteCodeBuilder& byteCodeBuilder; + CodeGenContext* codeGenContext; + + ByteCodeEmitter(VMByteCodeBuilder& builder, CodeGenContext* codeGenContext) + : byteCodeBuilder(builder), codeGenContext(codeGenContext) + { + } + + String getName(IRInst* inst) + { + String name; + if (mapInstToName.tryGetValue(inst, name)) + return name; + + if (auto nameDecor = inst->findDecoration<IRNameHintDecoration>()) + { + name = nameDecor->getName(); + } + else if (auto linkageDecor = inst->findDecoration<IRLinkageDecoration>()) + { + name = linkageDecor->getMangledName(); + } + else + { + name = getIROpInfo(inst->getOp()).name; + } + if (int* id = mapNameToUniqueId.tryGetValue(name)) + { + (*id)++; + name = name + "_" + String(*id); + } + else + { + mapNameToUniqueId[name] = 0; + } + mapInstToName[inst] = name; + return name; + } + + struct InstRelocationEntry + { + Index offsetToOperand; + IRBlock* block; + }; + + template<typename T> + static T alignUp(T value, T alignment) + { + return (value + alignment - 1) / alignment * alignment; + } + + VMOperand allocReg(VMByteCodeFunctionBuilder& funcBuilder, size_t size, size_t alignment) + { + VMOperand operand; + operand.sectionId = kSlangByteCodeSectionWorkingSet; + operand.offset = funcBuilder.workingSetSizeInBytes; + 
funcBuilder.workingSetSizeInBytes = + alignUp(funcBuilder.workingSetSizeInBytes, (uint32_t)alignment); + operand.offset = funcBuilder.workingSetSizeInBytes; + operand.size = size; + funcBuilder.workingSetSizeInBytes += (uint32_t)size; + return operand; + } + + VMOperand ensureWorkingsetMemory(VMByteCodeFunctionBuilder& funcBuilder, IRInst* inst) + { + VMOperand operand; + + if (mapInstToOperand.tryGetValue(inst, operand)) + return operand; + + IRSizeAndAlignment sizeAlignment = {}; + getNaturalSizeAndAlignment( + codeGenContext->getTargetProgram()->getOptionSet(), + inst->getDataType(), + &sizeAlignment); + operand = allocReg(funcBuilder, sizeAlignment.size, sizeAlignment.alignment); + mapInstToOperand[inst] = operand; + return operand; + } + + VMOperand addStringLiteral(UnownedStringSlice str) + { + if (auto operand = mapStringToOperand.tryGetValue(str)) + return *operand; + VMOperand operand; + operand.sectionId = kSlangByteCodeSectionStrings; + operand.offset = (uint32_t)byteCodeBuilder.stringOffsets.getCount(); + + byteCodeBuilder.stringOffsets.add((uint32_t)byteCodeBuilder.constantSection.getCount()); + byteCodeBuilder.constantSection.addRange((uint8_t*)str.begin(), str.getLength()); + byteCodeBuilder.constantSection.add(0); + operand.setType(OperandDataType::String); + operand.size = 0; + mapStringToOperand[str] = operand; + return operand; + } + + void alignConstSection(int alignment) + { + int rem = (int)byteCodeBuilder.constantSection.getCount() % alignment; + if (rem != 0) + { + int paddingSize = alignment - rem; + for (int i = 0; i < paddingSize; i++) + { + byteCodeBuilder.constantSection.add(0); + } + } + } + + template<typename IntType> + VMOperand addConstantValue(IntType value) + { + ConstKey key; + key.value = value; + key.size = (uint32_t)sizeof(IntType); + if (auto operand = mapConstantIntToOperand.tryGetValue(key)) + return *operand; + VMOperand operand; + operand.sectionId = kSlangByteCodeSectionConstants; + // align constantSection + 
alignConstSection((int)sizeof(IntType)); + operand.offset = (uint32_t)byteCodeBuilder.constantSection.getCount(); + byteCodeBuilder.constantSection.addRange((uint8_t*)&value, sizeof(value)); + mapConstantIntToOperand[key] = operand; + + operand.size = sizeof(IntType); + if (operand.size == 4) + operand.setType(OperandDataType::Int32); + else if (operand.size == 8) + operand.setType(OperandDataType::Int64); + else + operand.setType(OperandDataType::General); + return operand; + } + + VMOperand addConstantValue(IRConstant* inst) + { + VMOperand operand; + operand.sectionId = kSlangByteCodeSectionConstants; + + // Align constantSection. + IRSizeAndAlignment sizeAlignment; + getNaturalSizeAndAlignment( + codeGenContext->getTargetProgram()->getOptionSet(), + inst->getDataType(), + &sizeAlignment); + alignConstSection(sizeAlignment.alignment); + + operand.offset = (uint32_t)byteCodeBuilder.constantSection.getCount(); + operand.size = sizeAlignment.size; + + switch (inst->getOp()) + { + case kIROp_StringLit: + { + return addStringLiteral(static_cast<IRStringLit*>(inst)->getStringSlice()); + } + case kIROp_IntLit: + { + int64_t value = static_cast<IRIntLit*>(inst)->getValue(); + byteCodeBuilder.constantSection.addRange((uint8_t*)&value, sizeAlignment.size); + operand.setType(OperandDataType::General); + if (sizeAlignment.size != 64) + { + operand.setType(OperandDataType::Int32); + } + break; + } + case kIROp_FloatLit: + { + auto value = static_cast<IRFloatLit*>(inst)->getValue(); + if (inst->getDataType()->getOp() == kIROp_HalfType) + { + auto halfValue = FloatToHalf((float)value); + byteCodeBuilder.constantSection.addRange( + (uint8_t*)&halfValue, + sizeof(halfValue)); + } + else if (inst->getDataType()->getOp() == kIROp_FloatType) + { + float floatValue = (float)value; + byteCodeBuilder.constantSection.addRange( + (uint8_t*)&floatValue, + sizeof(floatValue)); + operand.setType(OperandDataType::Float32); + } + else + { + 
byteCodeBuilder.constantSection.addRange((uint8_t*)&value, sizeof(value)); + operand.setType(OperandDataType::Float64); + } + break; + } + case kIROp_PtrLit: + { + int64_t value = static_cast<IRIntLit*>(inst)->getValue(); + byteCodeBuilder.constantSection.addRange((uint8_t*)&value, sizeof(value)); + break; + } + case kIROp_VoidLit: + break; + } + return operand; + } + + VMOperand ensureInst(IRInst* inst) + { + VMOperand operand; + if (mapInstToOperand.tryGetValue(inst, operand)) + return operand; + + if (auto constantInst = as<IRConstant>(inst)) + { + operand = addConstantValue(constantInst); + mapInstToOperand[inst] = operand; + } + else + { + SLANG_UNEXPECTED("unsupported global inst for vm bytecode emit"); + } + return operand; + } + + void writeInst( + VMByteCodeFunctionBuilder& funcBuilder, + VMOp op, + uint32_t extOp, + ArrayView<VMOperand> operands) + { + VMInstHeader instHeader; + instHeader.opcode = op; + instHeader.opcodeExtension = extOp; + instHeader.operandCount = (uint16_t)operands.getCount(); + funcBuilder.instOffsets.add(funcBuilder.code.getCount()); + funcBuilder.code.addRange(reinterpret_cast<uint8_t*>(&instHeader), sizeof(instHeader)); + for (auto operand : operands) + { + funcBuilder.code.addRange(reinterpret_cast<uint8_t*>(&operand), sizeof(operand)); + } + } + + void writeInst(VMByteCodeFunctionBuilder& funcBuilder, VMOp op, uint32_t extOp) + { + writeInst(funcBuilder, op, extOp, ArrayView<VMOperand>()); + } + + void writeInst( + VMByteCodeFunctionBuilder& funcBuilder, + VMOp op, + uint32_t extOp, + VMOperand operand) + { + writeInst(funcBuilder, op, extOp, makeArrayViewSingle(operand)); + } + + void writeInst( + VMByteCodeFunctionBuilder& funcBuilder, + VMOp op, + uint32_t extOp, + VMOperand operand1, + VMOperand operand2) + { + writeInst(funcBuilder, op, extOp, makeArray(operand1, operand2).getView()); + } + + void writeInst( + VMByteCodeFunctionBuilder& funcBuilder, + VMOp op, + uint32_t extOp, + VMOperand operand1, + VMOperand operand2, + 
VMOperand operand3) + { + writeInst(funcBuilder, op, extOp, makeArray(operand1, operand2, operand3).getView()); + } + + uint32_t getExtCode(IRInst* type) + { + ArithmeticExtCode extCode = {}; + if (auto vecType = as<IRVectorType>(type)) + { + extCode.vectorSize = getIntVal(vecType->getElementCount()); + type = vecType->getElementType(); + } + else if (auto matType = as<IRMatrixType>(type)) + { + extCode.vectorSize = + getIntVal(matType->getRowCount()) * getIntVal(matType->getColumnCount()); + type = matType->getElementType(); + } + switch (type->getOp()) + { + case kIROp_IntType: + case kIROp_BoolType: + extCode.scalarType = kSlangByteCodeScalarTypeSignedInt; + extCode.scalarBitWidth = 2; + break; + case kIROp_Int8Type: + extCode.scalarType = kSlangByteCodeScalarTypeSignedInt; + extCode.scalarBitWidth = 0; + break; + case kIROp_Int16Type: + extCode.scalarType = kSlangByteCodeScalarTypeSignedInt; + extCode.scalarBitWidth = 1; + break; + case kIROp_Int64Type: + case kIROp_IntPtrType: + extCode.scalarType = kSlangByteCodeScalarTypeSignedInt; + extCode.scalarBitWidth = 3; + break; + case kIROp_UIntType: + extCode.scalarType = kSlangByteCodeScalarTypeUnsignedInt; + extCode.scalarBitWidth = 2; + break; + case kIROp_UInt8Type: + extCode.scalarType = kSlangByteCodeScalarTypeUnsignedInt; + extCode.scalarBitWidth = 0; + break; + case kIROp_UInt16Type: + extCode.scalarType = kSlangByteCodeScalarTypeUnsignedInt; + extCode.scalarBitWidth = 1; + break; + case kIROp_UInt64Type: + case kIROp_UIntPtrType: + case kIROp_PtrType: + case kIROp_OutType: + case kIROp_InOutType: + case kIROp_RefType: + case kIROp_NativePtrType: + extCode.scalarType = kSlangByteCodeScalarTypeUnsignedInt; + extCode.scalarBitWidth = 3; + break; + case kIROp_FloatType: + extCode.scalarType = kSlangByteCodeScalarTypeFloat; + extCode.scalarBitWidth = 2; + break; + case kIROp_HalfType: + extCode.scalarType = kSlangByteCodeScalarTypeFloat; + extCode.scalarBitWidth = 1; + break; + case kIROp_DoubleType: + 
extCode.scalarType = kSlangByteCodeScalarTypeFloat; + extCode.scalarBitWidth = 3; + break; + default: + SLANG_UNEXPECTED("Unsupported type for arithmetic operation"); + } + uint32_t result; + memcpy(&result, &extCode, sizeof(extCode)); + return result; + } + + VMInstHeader translateArithmeticOp(IRInst* inst) + { + VMInstHeader opInfo = {}; + + switch (inst->getOp()) + { + case kIROp_Add: + opInfo.opcode = VMOp::Add; + break; + case kIROp_Sub: + opInfo.opcode = VMOp::Sub; + break; + case kIROp_Mul: + opInfo.opcode = VMOp::Mul; + break; + case kIROp_Div: + opInfo.opcode = VMOp::Div; + break; + case kIROp_IRem: + case kIROp_FRem: + opInfo.opcode = VMOp::Rem; + break; + case kIROp_Neg: + opInfo.opcode = VMOp::Neg; + break; + case kIROp_And: + opInfo.opcode = VMOp::And; + break; + case kIROp_Or: + opInfo.opcode = VMOp::Or; + break; + case kIROp_Not: + opInfo.opcode = VMOp::Not; + break; + case kIROp_BitAnd: + opInfo.opcode = VMOp::BitAnd; + break; + case kIROp_BitOr: + opInfo.opcode = VMOp::BitOr; + break; + case kIROp_BitXor: + opInfo.opcode = VMOp::BitXor; + break; + case kIROp_BitNot: + opInfo.opcode = VMOp::BitNot; + break; + case kIROp_Lsh: + opInfo.opcode = VMOp::Shl; + break; + case kIROp_Rsh: + opInfo.opcode = VMOp::Shr; + break; + case kIROp_Less: + opInfo.opcode = VMOp::Less; + break; + case kIROp_Leq: + opInfo.opcode = VMOp::Leq; + break; + case kIROp_Greater: + opInfo.opcode = VMOp::Greater; + break; + case kIROp_Geq: + opInfo.opcode = VMOp::Geq; + break; + case kIROp_Eql: + opInfo.opcode = VMOp::Equal; + break; + case kIROp_Neq: + opInfo.opcode = VMOp::Neq; + break; + default: + SLANG_UNEXPECTED("Unsupported operation"); + break; + } + opInfo.opcodeExtension = getExtCode(inst->getOperand(0)->getDataType()); + return opInfo; + } + + void emitCast(VMByteCodeFunctionBuilder& funcBuilder, VMOp op, IRInst* inst) + { + auto extCode1 = getExtCode(inst->getDataType()); + auto extCode2 = getExtCode(inst->getOperand(0)->getDataType()); + auto extCode = extCode1 | 
(extCode2 << 16); + writeInst( + funcBuilder, + op, + extCode, + ensureWorkingsetMemory(funcBuilder, inst), + ensureInst(inst->getOperand(0))); + } + + void emitInst( + VMByteCodeFunctionBuilder& funcBuilder, + IRInst* inst, + List<InstRelocationEntry>& relocations) + { + switch (inst->getOp()) + { + case kIROp_undefined: + { + ensureWorkingsetMemory(funcBuilder, inst); + } + break; + case kIROp_Param: + { + auto operand = ensureWorkingsetMemory(funcBuilder, inst); + if (isFirstBlock(inst->getParent())) + { + funcBuilder.parameterOffsets.add(operand.offset); + IRSizeAndAlignment sizeAlignment = {}; + getNaturalSizeAndAlignment( + codeGenContext->getTargetProgram()->getOptionSet(), + inst->getDataType(), + &sizeAlignment); + funcBuilder.parameterSize = + operand.offset + (uint32_t)sizeAlignment.getStride(); + } + } + break; + case kIROp_Var: + { + IRBuilder builder(inst); + auto type = tryGetPointedToType(&builder, inst->getDataType()); + IRSizeAndAlignment sizeAlignment = {}; + getNaturalSizeAndAlignment( + codeGenContext->getTargetProgram()->getOptionSet(), + type, + &sizeAlignment); + auto varStorage = allocReg( + funcBuilder, + (size_t)sizeAlignment.size, + (size_t)sizeAlignment.alignment); + writeInst( + funcBuilder, + VMOp::GetWorkingSetPtr, + varStorage.offset, + ensureWorkingsetMemory(funcBuilder, inst)); + } + break; + case kIROp_Load: + { + IRSizeAndAlignment sizeAlignment = {}; + getNaturalSizeAndAlignment( + codeGenContext->getTargetProgram()->getOptionSet(), + inst->getDataType(), + &sizeAlignment); + writeInst( + funcBuilder, + VMOp::Load, + (uint32_t)sizeAlignment.getStride(), + ensureWorkingsetMemory(funcBuilder, inst), + ensureInst(inst->getOperand(0))); + } + break; + case kIROp_Store: + { + IRSizeAndAlignment sizeAlignment = {}; + getNaturalSizeAndAlignment( + codeGenContext->getTargetProgram()->getOptionSet(), + inst->getOperand(1)->getDataType(), + &sizeAlignment); + writeInst( + funcBuilder, + VMOp::Store, + (uint32_t)sizeAlignment.getStride(), 
+ ensureInst(inst->getOperand(0)), + ensureInst(inst->getOperand(1))); + } + break; + case kIROp_Add: + case kIROp_Sub: + case kIROp_Mul: + case kIROp_Div: + case kIROp_And: + case kIROp_FRem: + case kIROp_IRem: + case kIROp_Or: + case kIROp_BitAnd: + case kIROp_BitOr: + case kIROp_BitXor: + case kIROp_Lsh: + case kIROp_Rsh: + case kIROp_Less: + case kIROp_Leq: + case kIROp_Greater: + case kIROp_Geq: + case kIROp_Eql: + case kIROp_Neq: + { + auto opInfo = translateArithmeticOp(inst); + IRSizeAndAlignment sizeAlignment = {}; + getNaturalSizeAndAlignment( + codeGenContext->getTargetProgram()->getOptionSet(), + inst->getDataType(), + &sizeAlignment); + writeInst( + funcBuilder, + opInfo.opcode, + opInfo.opcodeExtension, + ensureWorkingsetMemory(funcBuilder, inst), + ensureInst(inst->getOperand(0)), + ensureInst(inst->getOperand(1))); + } + break; + case kIROp_Neg: + case kIROp_Not: + case kIROp_BitNot: + { + auto opInfo = translateArithmeticOp(inst); + IRSizeAndAlignment sizeAlignment = {}; + getNaturalSizeAndAlignment( + codeGenContext->getTargetProgram()->getOptionSet(), + inst->getDataType(), + &sizeAlignment); + writeInst( + funcBuilder, + opInfo.opcode, + opInfo.opcodeExtension, + ensureWorkingsetMemory(funcBuilder, inst), + ensureInst(inst->getOperand(0))); + } + break; + case kIROp_unconditionalBranch: + case kIROp_loop: + { + // Write phi arguments into param registers. 
+ auto branch = as<IRUnconditionalBranch>(inst); + auto params = branch->getTargetBlock()->getParams(); + List<IRInst*> paramList; + for (auto param : params) + { + paramList.add(param); + } + if (paramList.getCount() != (Index)branch->getArgCount()) + { + SLANG_UNEXPECTED("Invalid number of arguments for branch instruction"); + } + for (UInt i = 0; i < branch->getArgCount(); i++) + { + auto arg = branch->getArg(i); + auto param = paramList[i]; + auto paramReg = ensureWorkingsetMemory(funcBuilder, param); + IRSizeAndAlignment sizeAlignment = {}; + getNaturalSizeAndAlignment( + codeGenContext->getTargetProgram()->getOptionSet(), + param->getDataType(), + &sizeAlignment); + writeInst( + funcBuilder, + VMOp::Copy, + (uint32_t)sizeAlignment.getStride(), + paramReg, + ensureInst(arg)); + } + // Write jump inst. + VMOperand relocOperand = {}; + writeInst(funcBuilder, VMOp::Jump, 0, relocOperand); + InstRelocationEntry entry; + entry.block = (IRBlock*)inst->getOperand(0); + entry.offsetToOperand = funcBuilder.code.getCount() - sizeof(VMOperand); + relocations.add(entry); + } + break; + case kIROp_ifElse: + { + VMOperand relocOperand = {}; + writeInst( + funcBuilder, + VMOp::JumpIf, + 0, + ensureInst(inst->getOperand(0)), + relocOperand, + relocOperand); + InstRelocationEntry entry; + entry.block = (IRBlock*)inst->getOperand(1); + entry.offsetToOperand = funcBuilder.code.getCount() - sizeof(VMOperand) * 2; + relocations.add(entry); + entry.block = (IRBlock*)inst->getOperand(2); + entry.offsetToOperand = funcBuilder.code.getCount() - sizeof(VMOperand); + relocations.add(entry); + } + break; + case kIROp_Call: + { + auto callInst = as<IRCall>(inst); + auto callee = as<IRFunc>(callInst->getCallee()); + UnownedStringSlice def; + IRInst* intrinsicInst; + if (findTargetIntrinsicDefinition( + callee, + codeGenContext->getTargetCaps(), + def, + intrinsicInst)) + { + auto calleeOperand = addStringLiteral(def); + List<VMOperand> operands; + 
operands.add(ensureWorkingsetMemory(funcBuilder, inst)); + operands.add(calleeOperand); + for (UInt i = 0; i < callInst->getArgCount(); ++i) + { + operands.add(ensureInst(callInst->getArg(i))); + } + writeInst(funcBuilder, VMOp::CallExt, 0, operands.getArrayView()); + break; + } + List<VMOperand> operands; + int calleeId = -1; + mapFuncToId.tryGetValue(callee, calleeId); + SLANG_ASSERT(calleeId != -1); + VMOperand calleeOperand = {}; + calleeOperand.sectionId = kSlangByteCodeSectionFuncs; + calleeOperand.offset = calleeId; + calleeOperand.setType(OperandDataType::Int32); + operands.add(ensureWorkingsetMemory(funcBuilder, inst)); + operands.add(calleeOperand); + for (UInt i = 0; i < callInst->getArgCount(); ++i) + { + operands.add(ensureInst(callInst->getArg(i))); + } + IRSizeAndAlignment sizeAlignment = {}; + getNaturalSizeAndAlignment( + codeGenContext->getTargetProgram()->getOptionSet(), + inst->getDataType(), + &sizeAlignment); + writeInst( + funcBuilder, + VMOp::Call, + (uint32_t)sizeAlignment.getStride(), + operands.getArrayView()); + } + break; + case kIROp_MissingReturn: + case kIROp_Return: + { + auto returnInst = as<IRReturn>(inst); + if (returnInst && returnInst->getVal()->getOp() != kIROp_VoidLit) + { + IRSizeAndAlignment sizeAlignment = {}; + getNaturalSizeAndAlignment( + codeGenContext->getTargetProgram()->getOptionSet(), + returnInst->getVal()->getDataType(), + &sizeAlignment); + writeInst( + funcBuilder, + VMOp::Ret, + (uint32_t)sizeAlignment.getStride(), + ensureInst(returnInst->getOperand(0))); + } + else + { + writeInst(funcBuilder, VMOp::Ret, 0); + } + } + break; + case kIROp_GetElementPtr: + { + auto getElemInst = as<IRGetElementPtr>(inst); + auto base = getElemInst->getBase(); + auto index = getElemInst->getIndex(); + IRBuilder builder(inst); + auto elementType = tryGetPointedToType(&builder, getElemInst->getDataType()); + IRSizeAndAlignment sizeAlignment = {}; + getNaturalSizeAndAlignment( + codeGenContext->getTargetProgram()->getOptionSet(), 
+ elementType, + &sizeAlignment); + auto stride = sizeAlignment.getStride(); + auto baseOperand = ensureInst(base); + auto indexOperand = ensureInst(index); + writeInst( + funcBuilder, + VMOp::GetElementPtr, + (uint32_t)stride, + ensureWorkingsetMemory(funcBuilder, inst), + baseOperand, + indexOperand); + } + break; + case kIROp_FieldAddress: + { + auto fieldAddrInst = as<IRFieldAddress>(inst); + auto base = fieldAddrInst->getBase(); + auto fieldKey = (IRStructKey*)fieldAddrInst->getField(); + IRBuilder builder(base); + + auto structType = + as<IRStructType>(tryGetPointedToType(&builder, base->getDataType())); + IRIntegerValue offset = 0; + auto field = findStructField(structType, fieldKey); + getNaturalOffset( + codeGenContext->getTargetProgram()->getOptionSet(), + field, + &offset); + + writeInst( + funcBuilder, + VMOp::Add, + getExtCode(inst->getDataType()), + ensureWorkingsetMemory(funcBuilder, inst), + ensureInst(base), + addConstantValue((uint64_t)offset)); + } + break; + case kIROp_GetOffsetPtr: + { + auto getOffsetPtrInst = as<IRGetOffsetPtr>(inst); + auto base = getOffsetPtrInst->getBase(); + auto offset = getOffsetPtrInst->getOffset(); + IRSizeAndAlignment sizeAlignment = {}; + IRBuilder builder(inst); + auto elementType = tryGetPointedToType(&builder, getOffsetPtrInst->getDataType()); + getNaturalSizeAndAlignment( + codeGenContext->getTargetProgram()->getOptionSet(), + elementType, + &sizeAlignment); + writeInst( + funcBuilder, + VMOp::OffsetPtr, + (uint32_t)sizeAlignment.getStride(), + ensureWorkingsetMemory(funcBuilder, inst), + ensureInst(base), + ensureInst(offset)); + } + break; + case kIROp_FieldExtract: + { + auto fieldExtractInst = as<IRFieldExtract>(inst); + auto base = fieldExtractInst->getBase(); + auto fieldKey = (IRStructKey*)fieldExtractInst->getField(); + + auto structType = as<IRStructType>(base->getDataType()); + IRIntegerValue offset = 0; + auto field = findStructField(structType, fieldKey); + getNaturalOffset( + 
codeGenContext->getTargetProgram()->getOptionSet(), + field, + &offset); + + auto baseOperand = ensureInst(base); + baseOperand.offset += (uint32_t)offset; + mapInstToOperand[inst] = baseOperand; + } + break; + case kIROp_GetElement: + { + auto getElemInst = as<IRGetElement>(inst); + auto base = getElemInst->getBase(); + auto index = getElemInst->getIndex(); + auto elementType = getElemInst->getDataType(); + IRSizeAndAlignment sizeAlignment = {}; + getNaturalSizeAndAlignment( + codeGenContext->getTargetProgram()->getOptionSet(), + elementType, + &sizeAlignment); + auto stride = sizeAlignment.getStride(); + auto baseOperand = ensureInst(base); + if (as<IRIntLit>(index)) + { + baseOperand.offset += (uint32_t)(stride * getIntVal(index)); + mapInstToOperand[inst] = baseOperand; + break; + } + writeInst( + funcBuilder, + VMOp::GetElement, + (uint32_t)stride, + ensureWorkingsetMemory(funcBuilder, inst), + baseOperand, + ensureInst(index)); + } + break; + case kIROp_BitCast: + { + auto operand = ensureInst(inst->getOperand(0)); + mapInstToOperand[inst] = operand; + } + break; + case kIROp_IntCast: + case kIROp_CastIntToPtr: + case kIROp_CastPtrToInt: + case kIROp_CastIntToFloat: + case kIROp_CastFloatToInt: + case kIROp_FloatCast: + emitCast(funcBuilder, VMOp::Cast, inst); + break; + case kIROp_swizzle: + { + auto swizzleInst = as<IRSwizzle>(inst); + auto base = swizzleInst->getBase(); + auto baseOperand = ensureInst(base); + auto count = (uint32_t)swizzleInst->getElementCount(); + List<VMOperand> operands; + operands.add(ensureWorkingsetMemory(funcBuilder, inst)); + operands.add(baseOperand); + for (UInt i = 0; i < count; ++i) + { + auto index = (uint32_t)getIntVal(swizzleInst->getElementIndex(i)); + VMOperand operand; + operand.sectionId = kSlangByteCodeSectionImmediate; + operand.offset = index; + operands.add(operand); + } + writeInst( + funcBuilder, + VMOp::Swizzle, + getExtCode(inst->getDataType()), + operands.getArrayView()); + } + break; + case kIROp_MakeArray: + 
{ + auto result = ensureWorkingsetMemory(funcBuilder, inst); + auto arrayType = as<IRArrayTypeBase>(inst->getDataType()); + auto elementType = arrayType->getElementType(); + IRSizeAndAlignment sizeAlignment = {}; + getNaturalSizeAndAlignment( + codeGenContext->getTargetProgram()->getOptionSet(), + elementType, + &sizeAlignment); + auto stride = (uint32_t)sizeAlignment.getStride(); + for (UInt i = 0; i < inst->getOperandCount(); ++i) + { + VMOperand elementOperand = result; + elementOperand.offset += (uint32_t)(stride * i); + writeInst( + funcBuilder, + VMOp::Copy, + stride, + elementOperand, + ensureInst(inst->getOperand(i))); + } + } + break; + case kIROp_MakeArrayFromElement: + { + auto result = ensureWorkingsetMemory(funcBuilder, inst); + auto arrayType = as<IRArrayTypeBase>(inst->getDataType()); + auto elementType = arrayType->getElementType(); + IRSizeAndAlignment sizeAlignment = {}; + getNaturalSizeAndAlignment( + codeGenContext->getTargetProgram()->getOptionSet(), + elementType, + &sizeAlignment); + auto stride = (uint32_t)sizeAlignment.getStride(); + for (Index i = 0; i < getIntVal(arrayType->getElementCount()); ++i) + { + VMOperand elementOperand = result; + elementOperand.offset += (uint32_t)(stride * i); + writeInst( + funcBuilder, + VMOp::Copy, + stride, + elementOperand, + ensureInst(inst->getOperand(0))); + } + } + break; + case kIROp_MakeStruct: + { + auto result = ensureWorkingsetMemory(funcBuilder, inst); + auto structType = as<IRStructType>(inst->getDataType()); + List<IRStructField*> fields; + for (auto field : structType->getFields()) + { + fields.add(field); + } + for (UInt i = 0; i < inst->getOperandCount(); ++i) + { + auto field = fields[i]; + IRIntegerValue offset = 0; + getNaturalOffset( + codeGenContext->getTargetProgram()->getOptionSet(), + field, + &offset); + IRSizeAndAlignment sizeAlignment = {}; + getNaturalSizeAndAlignment( + codeGenContext->getTargetProgram()->getOptionSet(), + field->getFieldType(), + &sizeAlignment); + VMOperand 
elementOperand = result; + elementOperand.offset += (uint32_t)offset; + writeInst( + funcBuilder, + VMOp::Copy, + (uint32_t)sizeAlignment.getStride(), + elementOperand, + ensureInst(inst->getOperand(i))); + } + } + break; + case kIROp_MakeVector: + case kIROp_MakeMatrix: + { + auto result = ensureWorkingsetMemory(funcBuilder, inst); + for (UInt i = 0; i < inst->getOperandCount(); ++i) + { + VMOperand elementOperand = result; + IRSizeAndAlignment sizeAlignment = {}; + getNaturalSizeAndAlignment( + codeGenContext->getTargetProgram()->getOptionSet(), + inst->getOperand(i)->getDataType(), + &sizeAlignment); + writeInst( + funcBuilder, + VMOp::Copy, + (uint32_t)sizeAlignment.getStride(), + elementOperand, + ensureInst(inst->getOperand(i))); + result.offset += (uint32_t)sizeAlignment.getStride(); + } + } + break; + case kIROp_MakeVectorFromScalar: + { + auto result = ensureWorkingsetMemory(funcBuilder, inst); + auto vectorType = as<IRVectorType>(inst->getDataType()); + IRSizeAndAlignment sizeAlignment = {}; + getNaturalSizeAndAlignment( + codeGenContext->getTargetProgram()->getOptionSet(), + vectorType->getElementType(), + &sizeAlignment); + auto stride = (uint32_t)sizeAlignment.getStride(); + for (Index i = 0; i < getIntVal(vectorType->getElementCount()); ++i) + { + VMOperand elementOperand = result; + elementOperand.offset += (uint32_t)(stride * i); + writeInst( + funcBuilder, + VMOp::Copy, + stride, + elementOperand, + ensureInst(inst->getOperand(0))); + } + } + break; + case kIROp_MakeMatrixFromScalar: + { + auto result = ensureWorkingsetMemory(funcBuilder, inst); + auto matrixType = as<IRMatrixType>(inst->getDataType()); + IRSizeAndAlignment sizeAlignment = {}; + getNaturalSizeAndAlignment( + codeGenContext->getTargetProgram()->getOptionSet(), + matrixType->getElementType(), + &sizeAlignment); + auto stride = (uint32_t)sizeAlignment.getStride(); + for (Index i = 0; i < getIntVal(matrixType->getRowCount()); ++i) + { + for (Index j = 0; j < 
getIntVal(matrixType->getColumnCount()); ++j) + { + writeInst( + funcBuilder, + VMOp::Copy, + stride, + result, + ensureInst(inst->getOperand(0))); + result.offset += stride; + } + } + } + break; + case kIROp_Printf: + { + List<VMOperand> operands; + operands.add(ensureInst(inst->getOperand(0))); + auto tuple = inst->getOperand(1); + if (auto makeTuple = as<IRMakeStruct>(tuple)) + { + for (UInt i = 0; i < makeTuple->getOperandCount(); i++) + { + operands.add(ensureInst(makeTuple->getOperand(i))); + } + } + else + { + // If not a tuple, it should be a single value. + operands.add(ensureInst(tuple)); + } + writeInst(funcBuilder, VMOp::Print, 0, operands.getArrayView()); + } + break; + default: + SLANG_UNIMPLEMENTED_X("VM bytecode gen for inst."); + } + } + + void emitFunction(IRFunc* func) + { + VMByteCodeFunctionBuilder funcBuilder; + funcBuilder.name = addStringLiteral(getName(func).getUnownedSlice()); + + IRSizeAndAlignment sizeAlignment = {}; + getNaturalSizeAndAlignment( + codeGenContext->getTargetProgram()->getOptionSet(), + func->getResultType(), + &sizeAlignment); + funcBuilder.resultSize = (uint32_t)sizeAlignment.getStride(); + + Dictionary<IRBlock*, Index> mapBlockToByteOffset; + List<InstRelocationEntry> relocations; + + for (auto block : func->getBlocks()) + { + mapBlockToByteOffset[block] = funcBuilder.code.getCount(); + + for (auto inst : block->getChildren()) + { + funcBuilder.instOffsets.add(funcBuilder.code.getCount()); + emitInst(funcBuilder, inst, relocations); + } + } + + // Apply relocations for jump targets. 
+ for (auto reloc : relocations) + { + Index offset = mapBlockToByteOffset.getValue(reloc.block); + uint8_t* codePtr = (funcBuilder.code.getBuffer() + reloc.offsetToOperand); + VMOperand* operand = (VMOperand*)codePtr; + operand->sectionId = kSlangByteCodeSectionInsts; + operand->offset = (uint32_t)offset; + } + funcBuilder.workingSetSizeInBytes = + alignUp(funcBuilder.workingSetSizeInBytes, (uint32_t)sizeof(uint64_t)); + + byteCodeBuilder.functions.add(funcBuilder); + } + + void emitEntryPoints(LinkedIR& linkedIR) + { + Dictionary<IRInst*, HashSet<IRFunc*>> referencingEntryPoints; + buildEntryPointReferenceGraph(referencingEntryPoints, linkedIR.module); + OrderedHashSet<IRFunc*> entryPointSet; + for (auto entryPoint : linkedIR.entryPoints) + { + auto entryPointDecor = entryPoint->findDecoration<IREntryPointDecoration>(); + if (!entryPointDecor) + continue; + if (entryPointDecor->getProfile().getStage() != Stage::Dispatch) + continue; + entryPointSet.add(entryPoint); + } + + List<IRFunc*> functionsToEmit; + + // Emit all entrypoints first. + for (auto entryPoint : entryPointSet) + { + // Emit the function for the entry point. + functionsToEmit.add(entryPoint); + } + + // Emit remaining funcitons, if they are called by entry points. + for (auto globalInst : linkedIR.module->getGlobalInsts()) + { + auto func = as<IRFunc>(globalInst); + + if (!func) + continue; + + // Skip if already emitted as an entry point. + if (entryPointSet.contains(func)) + continue; + + HashSet<IRFunc*>* entryPointRefs = referencingEntryPoints.tryGetValue(func); + if (!entryPointRefs) + continue; + + // If the function is referenced by any entry point, emit it. + bool referencedByHostEntryPoint = false; + for (auto entryPoint : *entryPointRefs) + { + if (entryPointSet.contains(entryPoint)) + { + referencedByHostEntryPoint = true; + break; + } + } + if (referencedByHostEntryPoint) + { + functionsToEmit.add(func); + } + } + + // Emit all functions. 
+ for (Index i = 0; i < functionsToEmit.getCount(); i++) + { + mapFuncToId[functionsToEmit[i]] = (int)i; + } + for (auto func : functionsToEmit) + { + emitFunction(func); + } + } +}; + +SlangResult emitVMByteCodeForEntryPoints( + CodeGenContext* codeGenContext, + LinkedIR& linkedIR, + VMByteCodeBuilder& byteCode) +{ + ByteCodeEmitter emitter(byteCode, codeGenContext); + emitter.emitEntryPoints(linkedIR); + return SLANG_OK; +} + +SlangResult VMByteCodeBuilder::serialize(slang::IBlob** outBlob) +{ + OwnedMemoryStream ms(FileAccess::Write); + ms.write(&kSlangByteCodeFourCC, sizeof(uint32_t)); + ms.write(&kSlangByteCodeVersion, sizeof(uint32_t)); + + // Write functions section. + ms.write(&kSlangByteCodeFunctionsFourCC, sizeof(uint32_t)); + uint32_t functionChunkSizeStart = (uint32_t)ms.getPosition(); + uint32_t zero = 0; + ms.write(&zero, sizeof(uint32_t)); // Reserve space for function chunk size. + + uint32_t functionCount = (uint32_t)functions.getCount(); + ms.write(&functionCount, sizeof(uint32_t)); + // Reserve space for function offsets. 
+ auto functionOffsetStart = ms.getPosition(); + for (uint32_t i = 0; i < functionCount; ++i) + { + ms.write(&zero, sizeof(uint32_t)); + } + List<uint32_t> functionOffsets; + for (uint32_t i = 0; i < functionCount; ++i) + { + functionOffsets.add((uint32_t)ms.getPosition()); + + auto& function = functions[i]; + VMFuncHeader funcHeader; + funcHeader.name = function.name; + funcHeader.codeSize = (uint32_t)function.code.getCount(); + funcHeader.parameterCount = (uint32_t)function.parameterOffsets.getCount(); + funcHeader.workingSetSizeInBytes = function.workingSetSizeInBytes; + funcHeader.returnValueSizeInBytes = function.resultSize; + funcHeader.parameterSizeInBytes = function.parameterSize; + ms.write(&funcHeader, sizeof(funcHeader)); + ms.write( + function.parameterOffsets.getBuffer(), + sizeof(uint32_t) * function.parameterOffsets.getCount()); + + ms.write(function.code.begin(), funcHeader.codeSize); + } + uint32_t functionChunkSize = + (uint32_t)(ms.getPosition() - functionChunkSizeStart - sizeof(uint32_t)); + + // Write kernel Blob section. + ms.write(&kSlangByteCodeKernelBlobFourCC, sizeof(uint32_t)); + uint32_t kernelBlobSize = (uint32_t)kernelBlob->getBufferSize(); + ms.write(&kernelBlobSize, sizeof(uint32_t)); + ms.write(kernelBlob->getBufferPointer(), kernelBlobSize); + + // Write constant section. + ms.write(&kSlangByteCodeConstantsFourCC, sizeof(uint32_t)); + uint32_t constanBlobSize = (uint32_t)constantSection.getCount(); + ms.write(&constanBlobSize, sizeof(uint32_t)); + uint32_t stringCount = (uint32_t)stringOffsets.getCount(); + ms.write(&stringCount, sizeof(uint32_t)); + ms.write(stringOffsets.getBuffer(), sizeof(uint32_t) * stringCount); + ms.write(constantSection.begin(), constanBlobSize); + + auto blob = RawBlob::create(ms.getContents().getBuffer(), ms.getContents().getCount()); + + // Patch in the function chunk size. 
+ uint32_t* functionChunkSizePtr = + (uint32_t*)((uint8_t*)blob->getBufferPointer() + functionChunkSizeStart); + *functionChunkSizePtr = functionChunkSize; + + // Patch in the function offsets. + auto funcOffsetTable = (uint32_t*)((uint8_t*)blob->getBufferPointer() + functionOffsetStart); + for (uint32_t i = 0; i < functionCount; ++i) + { + funcOffsetTable[i] = functionOffsets[i]; + } + + *outBlob = blob.detach(); + return SLANG_OK; +} + +} // namespace Slang diff --git a/source/slang/slang-emit-vm.h b/source/slang/slang-emit-vm.h new file mode 100644 index 000000000..ee412c752 --- /dev/null +++ b/source/slang/slang-emit-vm.h @@ -0,0 +1,38 @@ +#ifndef SLANG_EMIT_VM_H +#define SLANG_EMIT_VM_H + +#include "slang-emit-base.h" +#include "slang-ir-link.h" +#include "slang-vm-bytecode.h" + +namespace Slang +{ + +struct VMByteCodeFunctionBuilder +{ + VMOperand name = {}; + uint32_t workingSetSizeInBytes = 0; + List<uint8_t> code; + List<Index> instOffsets; + List<uint32_t> parameterOffsets; + uint32_t resultSize = 0; + uint32_t parameterSize = 0; +}; + +struct VMByteCodeBuilder +{ + List<VMByteCodeFunctionBuilder> functions; + ComPtr<slang::IBlob> kernelBlob; + + List<uint8_t> constantSection; + List<uint32_t> stringOffsets; + SlangResult serialize(slang::IBlob** outBlob); +}; + +SlangResult emitVMByteCodeForEntryPoints( + CodeGenContext* codeGenContext, + LinkedIR& linkedIR, + VMByteCodeBuilder& byteCode); +} // namespace Slang + +#endif diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp index 9cfa7dcae..aa7387e22 100644 --- a/source/slang/slang-emit.cpp +++ b/source/slang/slang-emit.cpp @@ -15,8 +15,10 @@ #include "slang-emit-glsl.h" #include "slang-emit-hlsl.h" #include "slang-emit-metal.h" +#include "slang-emit-slang.h" #include "slang-emit-source-writer.h" #include "slang-emit-torch.h" +#include "slang-emit-vm.h" #include "slang-emit-wgsl.h" #include "slang-ir-any-value-inference.h" #include "slang-ir-autodiff.h" @@ -116,6 +118,7 @@ #include 
"slang-syntax.h" #include "slang-type-layout.h" #include "slang-visitor.h" +#include "slang-vm-bytecode.h" #include <assert.h> @@ -825,6 +828,7 @@ Result linkAndOptimizeIR( switch (target) { case CodeGenTarget::HostCPPSource: + case CodeGenTarget::HostVM: break; case CodeGenTarget::CUDASource: collectOptiXEntryPointUniformParams(irModule); @@ -859,6 +863,7 @@ Result linkAndOptimizeIR( case CodeGenTarget::HostCPPSource: case CodeGenTarget::CPPSource: case CodeGenTarget::CUDASource: + case CodeGenTarget::HostVM: break; } @@ -1154,6 +1159,15 @@ Result linkAndOptimizeIR( cleanupGenerics(targetProgram, irModule, sink); dumpIRIfEnabled(codeGenContext, irModule, "AFTER-LOWER-GENERICS"); + // Don't need to run any further target-dependent passes if we are generating code + // for host vm. + if (target == CodeGenTarget::HostVM) + { + performForceInlining(irModule); + simplifyIR(targetProgram, irModule, defaultIRSimplificationOptions, sink); + return SLANG_OK; + } + // After dynamic dispatch logic is resolved into ordinary function calls, // we can now run our stage specialization logic. 
if (requiredLoweringPassSet.specializeStageSwitch) @@ -2329,4 +2343,47 @@ SlangResult emitSPIRVForEntryPointsDirectly( return SLANG_OK; } +SlangResult emitHostVMCode(CodeGenContext* codeGenContext, ComPtr<IArtifact>& outArtifact) +{ + LinkedIR linkedIR; + LinkingAndOptimizationOptions linkingAndOptimizationOptions; + SLANG_RETURN_ON_FAIL( + linkAndOptimizeIR(codeGenContext, linkingAndOptimizationOptions, linkedIR)); + + VMByteCodeBuilder byteCode; + SLANG_RETURN_ON_FAIL(emitVMByteCodeForEntryPoints(codeGenContext, linkedIR, byteCode)); + + String slangDeclaration; + SLANG_RETURN_ON_FAIL( + emitSlangDeclarationsForEntryPoints(codeGenContext, linkedIR, slangDeclaration)); + + slang::SessionDesc sessionDesc = {}; + ComPtr<slang::ISession> slangSession; + SLANG_RETURN_ON_FAIL( + codeGenContext->getSession()->createSession(sessionDesc, slangSession.writeRef())); + auto linkage = static_cast<Linkage*>(slangSession.get()); + + ComPtr<ISlangBlob> diagnostics; + auto module = slangSession->loadModuleFromSource( + "kernel", + "kernel.slang", + StringBlob::create(slangDeclaration), + diagnostics.writeRef()); + if (!module) + return SLANG_FAIL; + RefPtr<Module> newModule = new Module(linkage); + newModule->setModuleDecl(static_cast<Module*>(module)->getModuleDecl()); + newModule->setIRModule(linkedIR.module); + newModule->setName("kernels"); + SLANG_RETURN_ON_FAIL(newModule->serialize(byteCode.kernelBlob.writeRef())); + + ComPtr<slang::IBlob> byteCodeBlob; + SLANG_RETURN_ON_FAIL(byteCode.serialize(byteCodeBlob.writeRef())); + + outArtifact = ArtifactUtil::createArtifactForCompileTarget(SLANG_HOST_VM); + outArtifact->addRepresentationUnknown(byteCodeBlob); + + return SLANG_OK; +} + } // namespace Slang diff --git a/source/slang/slang-ir-dce.cpp b/source/slang/slang-ir-dce.cpp index 5e1f2eb21..0b89baf64 100644 --- a/source/slang/slang-ir-dce.cpp +++ b/source/slang/slang-ir-dce.cpp @@ -283,16 +283,6 @@ struct DeadCodeEliminationContext } }; -bool isFirstBlock(IRInst* inst) -{ - 
auto block = as<IRBlock>(inst); - if (!block) - return false; - if (!block->getParent()) - return false; - return block->getParent()->getFirstBlock() == block; -} - bool isPtrUsed(IRInst* ptrInst) { for (auto use = ptrInst->firstUse; use; use = use->nextUse) diff --git a/source/slang/slang-ir-util.cpp b/source/slang/slang-ir-util.cpp index 4919850eb..36aec22b5 100644 --- a/source/slang/slang-ir-util.cpp +++ b/source/slang/slang-ir-util.cpp @@ -2232,4 +2232,14 @@ UnownedStringSlice getMangledName(IRInst* inst) return UnownedStringSlice(); } +bool isFirstBlock(IRInst* inst) +{ + auto block = as<IRBlock>(inst); + if (!block) + return false; + if (!block->getParent()) + return false; + return block->getParent()->getFirstBlock() == block; +} + } // namespace Slang diff --git a/source/slang/slang-ir-util.h b/source/slang/slang-ir-util.h index 0a8bc9b1d..b111f8abf 100644 --- a/source/slang/slang-ir-util.h +++ b/source/slang/slang-ir-util.h @@ -387,6 +387,7 @@ void legalizeDefUse(IRGlobalValueWithCode* func); UnownedStringSlice getMangledName(IRInst* inst); +bool isFirstBlock(IRInst* inst); } // namespace Slang #endif diff --git a/source/slang/slang-options.cpp b/source/slang/slang-options.cpp index 3c0bcf8db..d73bc4307 100644 --- a/source/slang/slang-options.cpp +++ b/source/slang/slang-options.cpp @@ -3615,6 +3615,7 @@ SlangResult OptionsParser::_parse(int argc, char const* const* argv) case CodeGenTarget::MetalLibAssembly: case CodeGenTarget::Metal: case CodeGenTarget::WGSL: + case CodeGenTarget::HostVM: rawOutput.isWholeProgram = true; break; case CodeGenTarget::SPIRV: diff --git a/source/slang/slang-profile-defs.h b/source/slang/slang-profile-defs.h index 29e2bbf2e..4781d5bc3 100644 --- a/source/slang/slang-profile-defs.h +++ b/source/slang/slang-profile-defs.h @@ -73,7 +73,7 @@ PROFILE_STAGE(Callable, callable, SLANG_STAGE_CALLABLE) PROFILE_STAGE(Mesh, mesh, SLANG_STAGE_MESH) PROFILE_STAGE(Amplification, amplification, SLANG_STAGE_AMPLIFICATION) - 
+PROFILE_STAGE(Dispatch, dispatch, SLANG_STAGE_DISPATCH) // Note: HLSL and Direct3D convention erroneously uses the term "Pixel Shader" // for the thing that shades *fragments*. Slang strives to treat the more correct diff --git a/source/slang/slang-type-layout.cpp b/source/slang/slang-type-layout.cpp index 63bcd1ba2..68df665d4 100644 --- a/source/slang/slang-type-layout.cpp +++ b/source/slang/slang-type-layout.cpp @@ -2254,6 +2254,7 @@ LayoutRulesFamilyImpl* getDefaultLayoutRulesFamilyForTarget(TargetRequest* targe case CodeGenTarget::ShaderSharedLibrary: case CodeGenTarget::CPPSource: case CodeGenTarget::CSource: + case CodeGenTarget::HostVM: { // For now lets use some fairly simple CPU binding rules diff --git a/source/slang/slang-vm-bytecode.cpp b/source/slang/slang-vm-bytecode.cpp new file mode 100644 index 000000000..1eafa4e57 --- /dev/null +++ b/source/slang/slang-vm-bytecode.cpp @@ -0,0 +1,424 @@ +#include "slang-vm-bytecode.h" + +#include "core/slang-blob.h" +#include "core/slang-stream.h" +#include "core/slang-string-escape-util.h" + +using namespace slang; + +namespace Slang +{ +static SlangResult consumeFourCC(MemoryStreamBase& stream, uint32_t expected) +{ + uint32_t fourCC = 0; + size_t bytesRead = 0; + SLANG_RETURN_ON_FAIL(stream.read(&fourCC, sizeof(fourCC), bytesRead)); + if (fourCC != expected) + { + return SLANG_FAIL; + } + return SLANG_OK; +} + +template<typename T> +static SlangResult readValue(MemoryStreamBase& stream, T& value) +{ + size_t bytesRead = 0; + SLANG_RETURN_ON_FAIL(stream.read(&value, sizeof(T), bytesRead)); + if (bytesRead != sizeof(T)) + { + return SLANG_FAIL; // Not enough data + } + return SLANG_OK; +} + +static SlangResult readUInt32(MemoryStreamBase& stream, uint32_t& value) +{ + return readValue(stream, value); +} + +SlangResult initVMModule(uint8_t* code, uint32_t codeSize, VMModuleView* moduleView) +{ + MemoryStreamBase stream(FileAccess::Read, code, codeSize); + moduleView->code = code; + + // Check the FourCC + 
SLANG_RETURN_ON_FAIL(consumeFourCC(stream, kSlangByteCodeFourCC)); + + // Check the version + uint32_t version; + size_t bytesRead = 0; + SLANG_RETURN_ON_FAIL(stream.read(&version, sizeof(version), bytesRead)); + if (version > kSlangByteCodeVersion) + { + return SLANG_FAIL; // Unsupported version + } + + // Read the function section + SLANG_RETURN_ON_FAIL(consumeFourCC(stream, kSlangByteCodeFunctionsFourCC)); + uint32_t functionSectionSize = 0; + SLANG_RETURN_ON_FAIL(readUInt32(stream, functionSectionSize)); + auto funcDataStart = stream.getPosition(); + if (functionSectionSize < sizeof(uint32_t)) // At least the function count + { + return SLANG_FAIL; // Invalid section size + } + + SLANG_RETURN_ON_FAIL(readUInt32(stream, moduleView->functionCount)); + moduleView->functionOffsets = reinterpret_cast<uint32_t*>(code + stream.getPosition()); + + stream.seek(SeekOrigin::Start, funcDataStart + functionSectionSize); + + // Read the kernel blob section + SLANG_RETURN_ON_FAIL(consumeFourCC(stream, kSlangByteCodeKernelBlobFourCC)); + SLANG_RETURN_ON_FAIL(readUInt32(stream, moduleView->kernelBlobSize)); + if (moduleView->kernelBlobSize > codeSize - stream.getPosition()) + { + return SLANG_FAIL; // Invalid kernel blob size + } + moduleView->kernelBlob = code + stream.getPosition(); + stream.seek(SeekOrigin::Current, moduleView->kernelBlobSize); + + // Read the constants section + SLANG_RETURN_ON_FAIL(consumeFourCC(stream, kSlangByteCodeConstantsFourCC)); + SLANG_RETURN_ON_FAIL(readUInt32(stream, moduleView->constantBlobSize)); + if (moduleView->constantBlobSize < sizeof(uint32_t)) // At least the constant count + { + return SLANG_FAIL; // Invalid section size + } + SLANG_RETURN_ON_FAIL(readUInt32(stream, moduleView->stringCount)); + moduleView->stringOffsets = reinterpret_cast<uint32_t*>(code + stream.getPosition()); + stream.seek(SeekOrigin::Current, moduleView->stringCount * sizeof(uint32_t)); + moduleView->constants = code + stream.getPosition(); + + for (uint32_t i = 0; 
i < moduleView->functionCount; i++) + { + auto functionStart = code + moduleView->functionOffsets[i]; + auto header = (VMFuncHeader*)(functionStart); + VMFunctionView functionView; + functionView.moduleView = moduleView; + functionView.header = (VMFuncHeader*)(functionStart); + functionView.paramOffsets = (uint32_t*)(functionStart + sizeof(VMFuncHeader)); + functionView.name = (const char*)moduleView->constants + + moduleView->stringOffsets[functionView.header->name.offset]; + functionView.functionCode = + (uint8_t*)functionView.paramOffsets + sizeof(uint32_t) * header->parameterCount; + functionView.functionCodeEnd = functionView.functionCode + functionView.header->codeSize; + moduleView->functionViews.add(functionView); + } + return SLANG_OK; +} + +StringBuilder& operator<<(StringBuilder& sb, VMOp op) +{ + switch (op) + { + case VMOp::Add: + sb << "add"; + break; + case VMOp::Sub: + sb << "sub"; + break; + case VMOp::Mul: + sb << "mul"; + break; + case VMOp::Div: + sb << "div"; + break; + case VMOp::Rem: + sb << "rem"; + break; + case VMOp::And: + sb << "and"; + break; + case VMOp::Or: + sb << "or"; + break; + case VMOp::BitXor: + sb << "bitxor"; + break; + case VMOp::BitNot: + sb << "bitnot"; + break; + case VMOp::Shl: + sb << "shl"; + break; + case VMOp::Shr: + sb << "shr"; + break; + case VMOp::Equal: + sb << "equal"; + break; + case VMOp::Neq: + sb << "neq"; + break; + case VMOp::Less: + sb << "less"; + break; + case VMOp::Leq: + sb << "leq"; + break; + case VMOp::Greater: + sb << "greater"; + break; + case VMOp::Geq: + sb << "geq"; + break; + case VMOp::Nop: + sb << "nop"; + break; + case VMOp::Neg: + sb << "neg"; + break; + case VMOp::Not: + sb << "not"; + break; + case VMOp::Jump: + sb << "jump"; + break; + case VMOp::JumpIf: + sb << "jumpif"; + break; + case VMOp::Dispatch: + sb << "dispatch"; + break; + case VMOp::Load: + sb << "load"; + break; + case VMOp::Store: + sb << "store"; + break; + case VMOp::Copy: + sb << "copy"; + break; + case 
VMOp::GetWorkingSetPtr: + sb << "get_working_set_ptr"; + break; + case VMOp::GetElementPtr: + sb << "get_element_ptr"; + break; + case VMOp::OffsetPtr: + sb << "offset_ptr"; + break; + case VMOp::GetElement: + sb << "get_element"; + break; + case VMOp::Cast: + sb << "cast"; + break; + case VMOp::CallExt: + sb << "call_ext"; + break; + case VMOp::Call: + sb << "call"; + break; + case VMOp::Swizzle: + sb << "swizzle"; + break; + case VMOp::Ret: + sb << "ret"; + break; + case VMOp::Print: + sb << "print"; + break; + default: + sb << "unknown_op(" << static_cast<uint32_t>(op) << ")"; + break; + } + return sb; +} + +StringBuilder& operator<<(StringBuilder& sb, ArithmeticExtCode extCode) +{ + switch (extCode.scalarType) + { + case kSlangByteCodeScalarTypeSignedInt: + sb << "i"; + break; + case kSlangByteCodeScalarTypeUnsignedInt: + sb << "u"; + break; + case kSlangByteCodeScalarTypeFloat: + sb << "f"; + break; + default: + sb << "x"; + break; + } + sb << (8 << extCode.scalarBitWidth); + if (extCode.vectorSize > 1) + { + sb << "v" << extCode.vectorSize; + } + return sb; +} + +void printVMInst(StringBuilder& sb, VMModuleView* moduleView, VMInstHeader* inst) +{ + auto lenBeforeOpCode = sb.getLength(); + sb << inst->opcode; + if (inst->opcodeExtension != 0) + { + switch (inst->opcode) + { + case VMOp::Add: + case VMOp::Sub: + case VMOp::Mul: + case VMOp::Div: + case VMOp::Rem: + case VMOp::And: + case VMOp::Or: + case VMOp::BitXor: + case VMOp::BitNot: + case VMOp::BitAnd: + case VMOp::BitOr: + case VMOp::Neg: + case VMOp::Not: + case VMOp::Shl: + case VMOp::Shr: + case VMOp::Equal: + case VMOp::Neq: + case VMOp::Less: + case VMOp::Leq: + case VMOp::Greater: + case VMOp::Geq: + { + ArithmeticExtCode extCode; + memcpy(&extCode, &inst->opcodeExtension, sizeof(extCode)); + sb << "." << extCode; + } + break; + case VMOp::Cast: + { + ArithmeticExtCode extCode; + memcpy(&extCode, &inst->opcodeExtension, sizeof(extCode)); + sb << "." 
<< extCode; + uint32_t fromCode = inst->opcodeExtension >> 16; + memcpy(&extCode, &fromCode, sizeof(extCode)); + sb << "." << extCode; + } + break; + default: + sb << "." << inst->opcodeExtension; + break; + } + } + auto opCodeLength = (int)(sb.getLength() - lenBeforeOpCode); + static const int kOpCodeColumnWidth = 20; + if (opCodeLength < kOpCodeColumnWidth) + { + for (int i = 0; i < kOpCodeColumnWidth - opCodeLength; i++) + { + sb << " "; + } + } + else + { + sb << " "; + } + for (uint32_t i = 0; i < inst->operandCount; i++) + { + if (i > 0) + sb << ", "; + auto operand = inst->getOperand(i); + switch (operand.sectionId) + { + case kSlangByteCodeSectionConstants: + switch (operand.getType()) + { + case OperandDataType::Int32: + { + int32_t val; + moduleView->getConstant<int32_t>(operand, val); + sb << "i32(" << val << ")"; + continue; + } + case OperandDataType::Int64: + { + int64_t val; + moduleView->getConstant<int64_t>(operand, val); + sb << "i64(" << val << ")"; + continue; + } + case OperandDataType::Float32: + { + float val; + moduleView->getConstant<float>(operand, val); + sb << "f32(" << val << ")"; + continue; + } + case OperandDataType::Float64: + { + double val; + moduleView->getConstant<double>(operand, val); + sb << "f32(" << val << ")"; + continue; + } + } + sb << "const:"; + break; + case kSlangByteCodeSectionInsts: + sb << "inst:"; + break; + case kSlangByteCodeSectionWorkingSet: + sb << "ws:"; + break; + case kSlangByteCodeSectionImmediate: + sb << "!"; + break; + case kSlangByteCodeSectionFuncs: + sb << moduleView->getFunction(operand.offset).name; + continue; + case kSlangByteCodeSectionStrings: + sb << "str:"; + if (operand.offset < moduleView->stringCount) + { + auto str = StringEscapeUtil::escapeString(UnownedStringSlice( + ((char*)moduleView->constants + moduleView->stringOffsets[operand.offset]))); + sb << str; + } + else + { + sb << "<invalid string index>"; + } + continue; + default: + sb << "section(" << operand.sectionId << ")@"; + 
break; + } + sb << String(inst->getOperand(i).offset, 16); + } +} + +StringBuilder& operator<<(StringBuilder& sb, VMModuleView& module) +{ + static const int addrColumnSize = 6; + for (uint32_t i = 0; i < module.functionCount; i++) + { + auto f = module.getFunction(i); + sb << "func " << f.name << ":\n"; + for (auto inst : f) + { + sb << " "; + auto loc = ((uint8_t*)inst - f.functionCode); + auto pos = sb.getLength(); + sb << String((uint32_t)loc, 16) << ": "; + auto addrLength = (int)(sb.getLength() - pos); + for (int j = 0; j < addrColumnSize - addrLength; j++) + { + sb << " "; + } + printVMInst(sb, &module, inst); + sb << "\n"; + } + } + return sb; +} + +VMFunctionView VMModuleView::getFunction(Index index) const +{ + if (index >= functionCount) + return {}; + return functionViews[index]; +} +} // namespace Slang diff --git a/source/slang/slang-vm-bytecode.h b/source/slang/slang-vm-bytecode.h new file mode 100644 index 000000000..5d030ee3a --- /dev/null +++ b/source/slang/slang-vm-bytecode.h @@ -0,0 +1,259 @@ +#ifndef SLANG_VM_BYTE_CODE_H +#define SLANG_VM_BYTE_CODE_H + +#include "core/slang-basic.h" +#include "core/slang-riff.h" +#include "slang-com-ptr.h" + +namespace Slang +{ + +/* +Slang ByteCode Module File Format + +# Header + - (4 bytes) FourCC: 'S', 'V', 'M', 'C' (kSlangByteCodeFourCC) + - (4 bytes uint) Version: 100 + +# Function Section + - (4 bytes) FourCC: 'S', 'V', 'F', 'N' (kSlangByteCodeFunctionsFourCC) + - (4 bytes uint) Function Count: number of functions in the module + - (uint array) Function Offsets: + array of "Function Count" 32-bit uints, storing byte offsets from + start of file for each function. + +## Function i: + - (32 bytes `VMFuncHeader`) Function metadata, describing the name and other + info needed for execution. VMFuncHeader::name is a `VMOperand` whose + sectionId = kSlangByteCodeSectionConstants, pointing to the constant section for the + function name. 
+ - (uint32 * parameterCount array): array of "parameterCount" 32-bit uints, storing byte offsets + from start of the function's working set for each parameter. + - (byte array) Code: array of `header.codeSize` bytes, containing the instruction stream for the + function. Each instruction starts with a `VMInstHeader`, followed by + `VMInstHeader::operandCount` `VMOperand` structs. + +# Kernel Blob Section: binary data for the kernel blob + - (4 bytes) FourCC: 'S', 'V', 'K', 'N' (kSlangByteCodeKernelBlobFourCC) + - (4 bytes uint) Kernel Blob Size: size of the kernel blob in bytes. + - (byte array) Kernel Blob: array of "Kernel Blob Size" bytes, containing the kernel blob data. + +# Constants Section + - (4 bytes) FourCC: 'S', 'V', 'C', 'S' (kSlangByteCodeConstantsFourCC) + - (4 bytes uint) Constant Count: number of constants in the module. + - (4 bytes uint) String Count: number of string literals in the constant section. + - (uint array) String offsets: array of "String Count" 32-bit uints, storing byte offsets from + start of the constant array blob (next item) for each string literal. + - (uint array) Constants: + array of "Constant Count" 32-bit uints, storing byte offsets from + start of file for each constant. 
+*/ + +static const int kSlangByteCodeVersion = 100; + +static const uint32_t kSlangByteCodeFourCC = SLANG_FOUR_CC('S', 'V', 'M', 'C'); +static const uint32_t kSlangByteCodeFunctionsFourCC = SLANG_FOUR_CC('S', 'V', 'F', 'N'); +static const uint32_t kSlangByteCodeKernelBlobFourCC = SLANG_FOUR_CC('S', 'V', 'K', 'N'); +static const uint32_t kSlangByteCodeConstantsFourCC = SLANG_FOUR_CC('S', 'V', 'C', 'S'); + +static const int kSlangByteCodeSectionWorkingSet = 0; +static const int kSlangByteCodeSectionConstants = 1; +static const int kSlangByteCodeSectionInsts = 2; +static const int kSlangByteCodeSectionImmediate = 3; +static const int kSlangByteCodeSectionFuncs = 4; +static const int kSlangByteCodeSectionStrings = 5; + + +enum class VMOp : uint32_t +{ + Nop, + Add, + Sub, + Mul, + Div, + Rem, + Neg, + And, + Or, + Not, + BitAnd, + BitOr, + BitNot, + BitXor, + Shl, + Shr, + Ret, + Less, + Leq, + Greater, + Geq, + Equal, + Neq, + Jump, + JumpIf, + Dispatch, + Load, + Store, + Copy, + GetWorkingSetPtr, + GetElementPtr, + OffsetPtr, + GetElement, + Swizzle, + Cast, + CallExt, + Call, + Print, +}; + +// Represents an operand in the VM bytecode. +// It consists of a section ID and a byte offset within that section. +struct VMOperand +{ + uint32_t sectionId; + uint32_t padding; // Padding to ensure section takes 8 bytes. sectionId will be replaced + // with actual pointers before execution. + uint32_t type : 8; // type of the operand data. + uint32_t size : 24; + uint32_t offset; + slang::OperandDataType getType() const { return slang::OperandDataType(type); } + void setType(slang::OperandDataType newType) { type = uint32_t(newType); } +}; + +struct VMInstHeader +{ + VMOp opcode; + uint32_t padding; // 32-bit padding, to ensure space are reserved to store function pointers for + // the opcode. 
+ uint32_t opcodeExtension; + uint32_t operandCount; + VMOperand& getOperand(Index index) const { return *((VMOperand*)(this + 1) + index); } +}; + +struct VMFuncHeader +{ + VMOperand name; // Name of the function as a VMOperand, pointing to the constant section. + uint32_t workingSetSizeInBytes; // Size of the working set in bytes. + uint32_t codeSize; // Size of the code in bytes. + uint32_t parameterCount; // Number of parameters for the function. + uint32_t returnValueSizeInBytes; // Size of the return value in bytes. + uint32_t parameterSizeInBytes; // Size of the parameters in bytes. + uint32_t getParameterOffset(Index index) const { return *((uint32_t*)(this + 1) + index); } + uint8_t* getCode() const { return (uint8_t*)(this + 1) + parameterCount * sizeof(uint32_t); } +}; + +static const int kSlangByteCodeScalarTypeSignedInt = 0; +static const int kSlangByteCodeScalarTypeUnsignedInt = 1; +static const int kSlangByteCodeScalarTypeFloat = 2; + +struct ArithmeticExtCode +{ + uint32_t scalarType : 2; // 0: signed int, 1: unsigned int, 2: floating-point + uint32_t scalarBitWidth : 2; // 0: 8, 1: 16, 2: 32, 3: 64 + uint32_t vectorSize : 12; // number of elements in the vector. + uint32_t unused : 16; +}; + +template<typename TOperand, typename TInstHeader> +struct VMInstIterator +{ + uint8_t* codePtr; // Pointer to the current instruction. + + void moveNext() + { + // Read the instruction header + TInstHeader header; + memcpy(&header, codePtr, sizeof(header)); + codePtr += sizeof(header); + + // Calculate the size of operand list. + auto operandListSize = header.operandCount * sizeof(TOperand); + + // Advance the code pointer by the size of the header and operand list. 
+ codePtr += operandListSize; + } + + VMInstIterator& operator++() + { + moveNext(); + return *this; + } + VMInstIterator operator++(int) + { + VMInstIterator rs = *this; + moveNext(); // advance *this; rs keeps the pre-increment position that postfix ++ must return + return rs; + } + + bool operator!=(const VMInstIterator& iter) const { return codePtr != iter.codePtr; } + bool operator==(const VMInstIterator& iter) const { return codePtr == iter.codePtr; } + TInstHeader* operator*() const { return reinterpret_cast<TInstHeader*>(codePtr); } +}; + +struct VMModuleView; + +struct VMFunctionView +{ + const char* name = nullptr; + VMFuncHeader* header; // Function header containing metadata. + uint32_t* paramOffsets; + uint8_t* functionCode; // Pointer to the function code. + uint8_t* functionCodeEnd; // Pointer to the end of the function code. + VMModuleView* moduleView; // Pointer to start of the module. + VMInstIterator<VMOperand, VMInstHeader> begin() const + { + VMInstIterator<VMOperand, VMInstHeader> iter; + iter.codePtr = functionCode; + return iter; + } + + VMInstIterator<VMOperand, VMInstHeader> end() const + { + VMInstIterator<VMOperand, VMInstHeader> iter; + iter.codePtr = functionCodeEnd; + return iter; + } +}; + +struct VMModuleView +{ + uint8_t* code; + uint32_t functionCount; + uint32_t* functionOffsets; + uint8_t* constants; + uint32_t constantBlobSize; + uint32_t stringCount; + uint32_t* stringOffsets; // Offsets to string literals in the constant section.
+ uint8_t* kernelBlob; + uint32_t kernelBlobSize; + + List<VMFunctionView> functionViews; + + VMFunctionView getFunction(Index index) const; + + template<typename T> + SlangResult getConstant(VMOperand operand, T& outValue) const + { + if (operand.sectionId != kSlangByteCodeSectionConstants) + { + return SLANG_FAIL; // Invalid section + } + if (sizeof(T) > constantBlobSize || operand.offset > constantBlobSize - sizeof(T)) // subtraction form cannot wrap where size_t is 32-bit + { + return SLANG_FAIL; // Out of bounds + } + memcpy(&outValue, constants + operand.offset, sizeof(T)); + return SLANG_OK; + } +}; + +SlangResult initVMModule(uint8_t* code, uint32_t codeSize, VMModuleView* moduleView); + +StringBuilder& operator<<(StringBuilder& sb, VMOp op); +StringBuilder& operator<<(StringBuilder& sb, VMModuleView& module); +void printVMInst(StringBuilder& sb, VMModuleView* moduleView, VMInstHeader* inst); + +} // namespace Slang + + +#endif diff --git a/source/slang/slang-vm-inst-impl.cpp b/source/slang/slang-vm-inst-impl.cpp new file mode 100644 index 000000000..304ea09a8 --- /dev/null +++ b/source/slang/slang-vm-inst-impl.cpp @@ -0,0 +1,1066 @@ +#include "slang-vm-inst-impl.h" + +#include "slang-vm.h" + +using namespace slang; + +namespace Slang +{ +ByteCodeInterpreter* convert(IByteCodeRunner* runner) +{ + return static_cast<ByteCodeInterpreter*>(runner); +} + +#define SIMPLE_BINARY_SCALAR_FUNC(name, op) \ + struct name##ScalarFunc \ + { \ + template<typename TR, typename T1, typename T2> \ + static void run(TR* dst, const T1* src1, const T2* src2) \ + { \ + *dst = (*src1)op(*src2); \ + } \ + } + +SIMPLE_BINARY_SCALAR_FUNC(Add, +); +SIMPLE_BINARY_SCALAR_FUNC(Sub, -); +SIMPLE_BINARY_SCALAR_FUNC(Mul, *); +SIMPLE_BINARY_SCALAR_FUNC(Div, /); +SIMPLE_BINARY_SCALAR_FUNC(And, &&); +SIMPLE_BINARY_SCALAR_FUNC(Or, ||); +SIMPLE_BINARY_SCALAR_FUNC(BitAnd, &); +SIMPLE_BINARY_SCALAR_FUNC(BitOr, |); +SIMPLE_BINARY_SCALAR_FUNC(BitXor, ^); +SIMPLE_BINARY_SCALAR_FUNC(Shl, <<); +SIMPLE_BINARY_SCALAR_FUNC(Shr, >>); +SIMPLE_BINARY_SCALAR_FUNC(Less, <);
+SIMPLE_BINARY_SCALAR_FUNC(Leq, <=); +SIMPLE_BINARY_SCALAR_FUNC(Greater, >); +SIMPLE_BINARY_SCALAR_FUNC(Geq, >=); +SIMPLE_BINARY_SCALAR_FUNC(Equal, ==); +SIMPLE_BINARY_SCALAR_FUNC(Neq, !=); + +template<typename TR, typename T1, typename T2> +void scalarMod(TR* dst, const T1* src1, const T2* src2) +{ + *dst = *src1 % *src2; +} + +template<> +void scalarMod<float, float, float>(float* dst, const float* src1, const float* src2) +{ + *dst = fmodf(*src1, *src2); +} + +template<> +void scalarMod<double, double, double>(double* dst, const double* src1, const double* src2) +{ + *dst = fmod(*src1, *src2); +} + +struct ModScalarFunc +{ + template<typename TR, typename T1, typename T2> + static void run(TR* dst, const T1* src1, const T2* src2) + { + scalarMod<TR, T1, T2>(dst, src1, src2); + } +}; + +#define SIMPLE_UNARY_SCALAR_FUNC(name, op) \ + struct name##ScalarFunc \ + { \ + template<typename TR, typename T1> \ + static void run(TR* dst, const T1* src1) \ + { \ + *dst = op(*src1); \ + } \ + } +SIMPLE_UNARY_SCALAR_FUNC(Neg, -); +SIMPLE_UNARY_SCALAR_FUNC(Not, !); +SIMPLE_UNARY_SCALAR_FUNC(BitNot, ~); + +template<typename ScalarFunc, typename TR, typename T1, typename T2, int elementCount> +struct BinaryVectorFunc +{ + static void run(IByteCodeRunner* context, VMExecInstHeader* inst, void* userData) + { + SLANG_UNUSED(context); + SLANG_UNUSED(userData); + TR* dst = (TR*)inst->getOperand(0).getPtr(); + T1* src1 = (T1*)inst->getOperand(1).getPtr(); + T2* src2 = (T2*)inst->getOperand(2).getPtr(); + for (int i = 0; i < elementCount; ++i) + { + ScalarFunc::template run<TR, T1, T2>(&dst[i], &src1[i], &src2[i]); + } + } +}; + +template<typename ScalarFunc, typename TR, typename T1, typename T2> +struct GeneralBinaryVectorFunc +{ + static void run(IByteCodeRunner* context, VMExecInstHeader* inst, void* userData) + { + SLANG_UNUSED(context); + SLANG_UNUSED(userData); + TR* dst = (TR*)inst->getOperand(0).getPtr(); + T1* src1 = (T1*)inst->getOperand(1).getPtr(); + T2* src2 = 
(T2*)inst->getOperand(2).getPtr(); + ArithmeticExtCode arithExtCode; + memcpy(&arithExtCode, &inst->opcodeExtension, sizeof(arithExtCode)); + for (uint32_t i = 0; i < arithExtCode.vectorSize; ++i) + { + ScalarFunc::template run<TR, T1, T2>(&dst[i], &src1[i], &src2[i]); + } + } +}; + +template<typename Func, typename TR, typename T1 = TR, typename T2 = TR> +VMExtFunction binaryArithmeticInstHandler(int elementCount) +{ + switch (elementCount) + { + case 0: + case 1: + return BinaryVectorFunc<Func, TR, T1, T2, 1>::run; + case 2: + return BinaryVectorFunc<Func, TR, T1, T2, 2>::run; + case 3: + return BinaryVectorFunc<Func, TR, T1, T2, 3>::run; + case 4: + return BinaryVectorFunc<Func, TR, T1, T2, 4>::run; + case 6: + return BinaryVectorFunc<Func, TR, T1, T2, 6>::run; + case 8: + return BinaryVectorFunc<Func, TR, T1, T2, 8>::run; + case 9: + return BinaryVectorFunc<Func, TR, T1, T2, 9>::run; + case 10: + return BinaryVectorFunc<Func, TR, T1, T2, 10>::run; + case 12: + return BinaryVectorFunc<Func, TR, T1, T2, 12>::run; + case 16: + return BinaryVectorFunc<Func, TR, T1, T2, 16>::run; + default: + return GeneralBinaryVectorFunc<Func, TR, T1, T2>::run; + } +} + +template<typename Func> +VMExtFunction binaryArithmeticInstHandler(uint32_t extCode) +{ + ArithmeticExtCode arithExtCode; + memcpy(&arithExtCode, &extCode, sizeof(arithExtCode)); + switch (arithExtCode.scalarType) + { + case kSlangByteCodeScalarTypeSignedInt: + switch (arithExtCode.scalarBitWidth) + { + case 0: + return binaryArithmeticInstHandler<Func, int8_t>(arithExtCode.vectorSize); + case 1: + return binaryArithmeticInstHandler<Func, int16_t>(arithExtCode.vectorSize); + case 2: + return binaryArithmeticInstHandler<Func, int32_t>(arithExtCode.vectorSize); + case 3: + return binaryArithmeticInstHandler<Func, int64_t>(arithExtCode.vectorSize); + } + case kSlangByteCodeScalarTypeUnsignedInt: + switch (arithExtCode.scalarBitWidth) + { + case 0: + return binaryArithmeticInstHandler<Func, 
uint8_t>(arithExtCode.vectorSize); + case 1: + return binaryArithmeticInstHandler<Func, uint16_t>(arithExtCode.vectorSize); + case 2: + return binaryArithmeticInstHandler<Func, uint32_t>(arithExtCode.vectorSize); + case 3: + return binaryArithmeticInstHandler<Func, uint64_t>(arithExtCode.vectorSize); + } + case kSlangByteCodeScalarTypeFloat: + switch (arithExtCode.scalarBitWidth) + { + case 2: + return binaryArithmeticInstHandler<Func, float>(arithExtCode.vectorSize); + case 3: + return binaryArithmeticInstHandler<Func, double>(arithExtCode.vectorSize); + default: + return nullptr; // Unsupported scalar bit width + } + } + return nullptr; +} + +template<typename Func> +VMExtFunction binaryArithmeticLogicalInstHandler(uint32_t extCode) +{ + ArithmeticExtCode arithExtCode; + memcpy(&arithExtCode, &extCode, sizeof(arithExtCode)); + switch (arithExtCode.scalarBitWidth) + { + case 0: + return binaryArithmeticInstHandler<Func, uint8_t>(arithExtCode.vectorSize); + case 1: + return binaryArithmeticInstHandler<Func, uint16_t>(arithExtCode.vectorSize); + case 2: + return binaryArithmeticInstHandler<Func, uint32_t>(arithExtCode.vectorSize); + case 3: + return binaryArithmeticInstHandler<Func, uint64_t>(arithExtCode.vectorSize); + } + return nullptr; +} + +template<typename Func> +VMExtFunction binaryArithmeticIntInstHandler(uint32_t extCode) +{ + ArithmeticExtCode arithExtCode; + memcpy(&arithExtCode, &extCode, sizeof(arithExtCode)); + switch (arithExtCode.scalarType) + { + case kSlangByteCodeScalarTypeSignedInt: + switch (arithExtCode.scalarBitWidth) + { + case 0: + return binaryArithmeticInstHandler<Func, int8_t>(arithExtCode.vectorSize); + case 1: + return binaryArithmeticInstHandler<Func, int16_t>(arithExtCode.vectorSize); + case 2: + return binaryArithmeticInstHandler<Func, int32_t>(arithExtCode.vectorSize); + case 3: + return binaryArithmeticInstHandler<Func, int64_t>(arithExtCode.vectorSize); + } + case kSlangByteCodeScalarTypeUnsignedInt: + switch 
(arithExtCode.scalarBitWidth) + { + case 0: + return binaryArithmeticInstHandler<Func, uint8_t>(arithExtCode.vectorSize); + case 1: + return binaryArithmeticInstHandler<Func, uint16_t>(arithExtCode.vectorSize); + case 2: + return binaryArithmeticInstHandler<Func, uint32_t>(arithExtCode.vectorSize); + case 3: + return binaryArithmeticInstHandler<Func, uint64_t>(arithExtCode.vectorSize); + } + } + return nullptr; +} + +template<typename Func> +VMExtFunction binaryArithmeticCompareInstHandler(uint32_t extCode) +{ + ArithmeticExtCode arithExtCode; + memcpy(&arithExtCode, &extCode, sizeof(arithExtCode)); + switch (arithExtCode.scalarType) + { + case kSlangByteCodeScalarTypeSignedInt: + switch (arithExtCode.scalarBitWidth) + { + case 0: + return binaryArithmeticInstHandler<Func, uint32_t, int8_t, int8_t>( + arithExtCode.vectorSize); + case 1: + return binaryArithmeticInstHandler<Func, uint32_t, int16_t, int16_t>( + arithExtCode.vectorSize); + case 2: + return binaryArithmeticInstHandler<Func, uint32_t, int32_t, int32_t>( + arithExtCode.vectorSize); + case 3: + return binaryArithmeticInstHandler<Func, uint32_t, int64_t, int64_t>( + arithExtCode.vectorSize); + } + case kSlangByteCodeScalarTypeUnsignedInt: + switch (arithExtCode.scalarBitWidth) + { + case 0: + return binaryArithmeticInstHandler<Func, uint32_t, uint8_t, uint8_t>( + arithExtCode.vectorSize); + case 1: + return binaryArithmeticInstHandler<Func, uint32_t, uint16_t, uint16_t>( + arithExtCode.vectorSize); + case 2: + return binaryArithmeticInstHandler<Func, uint32_t, uint32_t, uint32_t>( + arithExtCode.vectorSize); + case 3: + return binaryArithmeticInstHandler<Func, uint32_t, uint64_t, uint64_t>( + arithExtCode.vectorSize); + } + case kSlangByteCodeScalarTypeFloat: + switch (arithExtCode.scalarBitWidth) + { + case 2: + return binaryArithmeticInstHandler<Func, uint32_t, float, float>( + arithExtCode.vectorSize); + case 3: + return binaryArithmeticInstHandler<Func, uint32_t, double, double>( + 
arithExtCode.vectorSize); + default: + return nullptr; // Unsupported scalar bit width + } + } + return nullptr; +} + +//////// +template<typename ScalarFunc, typename TR, typename T1, int elementCount> +struct UnaryVectorFunc +{ + static void run(IByteCodeRunner* context, VMExecInstHeader* inst, void* userData) + { + SLANG_UNUSED(context); + SLANG_UNUSED(userData); + TR* dst = (TR*)inst->getOperand(0).getPtr(); + T1* src1 = (T1*)inst->getOperand(1).getPtr(); + for (int i = 0; i < elementCount; ++i) + { + ScalarFunc::template run<TR, T1>(&dst[i], &src1[i]); + } + } +}; + +template<typename ScalarFunc, typename TR, typename T1> +struct GeneralUnaryVectorFunc +{ + static void run(IByteCodeRunner* context, VMExecInstHeader* inst, void* userData) + { + SLANG_UNUSED(context); + SLANG_UNUSED(userData); + TR* dst = (TR*)inst->getOperand(0).getPtr(); + T1* src1 = (T1*)inst->getOperand(1).getPtr(); + ArithmeticExtCode arithExtCode; + memcpy(&arithExtCode, &inst->opcodeExtension, sizeof(arithExtCode)); + for (uint32_t i = 0; i < arithExtCode.vectorSize; ++i) + { + ScalarFunc::template run<TR, T1>(&dst[i], &src1[i]); + } + } +}; + +template<typename Func, typename TR, typename T1 = TR> +VMExtFunction unaryArithmeticInstHandler(int elementCount) +{ + switch (elementCount) + { + case 0: + case 1: + return UnaryVectorFunc<Func, TR, T1, 1>::run; + case 2: + return UnaryVectorFunc<Func, TR, T1, 2>::run; + case 3: + return UnaryVectorFunc<Func, TR, T1, 3>::run; + case 4: + return UnaryVectorFunc<Func, TR, T1, 4>::run; + case 6: + return UnaryVectorFunc<Func, TR, T1, 6>::run; + case 8: + return UnaryVectorFunc<Func, TR, T1, 8>::run; + case 9: + return UnaryVectorFunc<Func, TR, T1, 9>::run; + case 10: + return UnaryVectorFunc<Func, TR, T1, 10>::run; + case 12: + return UnaryVectorFunc<Func, TR, T1, 12>::run; + case 16: + return UnaryVectorFunc<Func, TR, T1, 16>::run; + default: + return GeneralUnaryVectorFunc<Func, TR, T1>::run; + } +} + +template<typename Func> +VMExtFunction 
unaryArithmeticLogicalInstHandler(uint32_t extCode) +{ + ArithmeticExtCode arithExtCode; + memcpy(&arithExtCode, &extCode, sizeof(arithExtCode)); + switch (arithExtCode.scalarBitWidth) + { + case 0: + return unaryArithmeticInstHandler<Func, uint8_t>(arithExtCode.vectorSize); + case 1: + return unaryArithmeticInstHandler<Func, uint16_t>(arithExtCode.vectorSize); + case 2: + return unaryArithmeticInstHandler<Func, uint32_t>(arithExtCode.vectorSize); + case 3: + return unaryArithmeticInstHandler<Func, uint64_t>(arithExtCode.vectorSize); + } + return nullptr; +} + +template<typename Func> +VMExtFunction unaryArithmeticIntInstHandler(uint32_t extCode) +{ + ArithmeticExtCode arithExtCode; + memcpy(&arithExtCode, &extCode, sizeof(arithExtCode)); + switch (arithExtCode.scalarType) + { + case kSlangByteCodeScalarTypeSignedInt: + switch (arithExtCode.scalarBitWidth) + { + case 0: + return unaryArithmeticInstHandler<Func, int8_t>(arithExtCode.vectorSize); + case 1: + return unaryArithmeticInstHandler<Func, int16_t>(arithExtCode.vectorSize); + case 2: + return unaryArithmeticInstHandler<Func, int32_t>(arithExtCode.vectorSize); + case 3: + return unaryArithmeticInstHandler<Func, int64_t>(arithExtCode.vectorSize); + } + case kSlangByteCodeScalarTypeUnsignedInt: + switch (arithExtCode.scalarBitWidth) + { + case 0: + return unaryArithmeticInstHandler<Func, uint8_t>(arithExtCode.vectorSize); + case 1: + return unaryArithmeticInstHandler<Func, uint16_t>(arithExtCode.vectorSize); + case 2: + return unaryArithmeticInstHandler<Func, uint32_t>(arithExtCode.vectorSize); + case 3: + return unaryArithmeticInstHandler<Func, uint64_t>(arithExtCode.vectorSize); + } + } + return nullptr; +} + +template<typename Func> +VMExtFunction negInstHandler(uint32_t extCode) +{ + ArithmeticExtCode arithExtCode; + memcpy(&arithExtCode, &extCode, sizeof(arithExtCode)); + switch (arithExtCode.scalarType) + { + case kSlangByteCodeScalarTypeSignedInt: + case kSlangByteCodeScalarTypeUnsignedInt: + switch 
(arithExtCode.scalarBitWidth) + { + case 0: + return unaryArithmeticInstHandler<Func, int8_t>(arithExtCode.vectorSize); + case 1: + return unaryArithmeticInstHandler<Func, int16_t>(arithExtCode.vectorSize); + case 2: + return unaryArithmeticInstHandler<Func, int32_t>(arithExtCode.vectorSize); + case 3: + return unaryArithmeticInstHandler<Func, int64_t>(arithExtCode.vectorSize); + } + case kSlangByteCodeScalarTypeFloat: + switch (arithExtCode.scalarBitWidth) + { + case 2: + return unaryArithmeticInstHandler<Func, float>(arithExtCode.vectorSize); + case 3: + return unaryArithmeticInstHandler<Func, double>(arithExtCode.vectorSize); + default: + return nullptr; // Unsupported scalar bit width + } + } + return nullptr; +} + +static void nopHandler(IByteCodeRunner*, VMExecInstHeader*, void*) {} + +void callHandler(IByteCodeRunner* inCtx, VMExecInstHeader* inst, void*) +{ + auto ctx = convert(inCtx); + auto funcId = inst->getOperand(1).offset; + auto& func = ctx->m_functions[funcId]; + auto funcHeader = func.m_header; + + // Alloc working set. + ctx->pushFrame(funcHeader->workingSetSizeInBytes); + + // Save current instruction pointer. + auto& stackFrame = ctx->m_stack.getLast(); + stackFrame.m_currentInst = inst; + stackFrame.m_currentFuncCode = ctx->m_currentFuncCode; + auto newWorkingSetPtr = (uint8_t*)ctx->m_currentWorkingSet; + auto callerWorkingSetPtr = + (uint8_t*)(ctx->m_workingSetBuffer.getBuffer() + stackFrame.m_workingSetOffset); + + // Set working set pointer to the caller's working set. + ctx->m_currentWorkingSet = callerWorkingSetPtr; + + // Copy arguments to the callee's working set. + for (uint32_t i = 0; i < funcHeader->parameterCount; ++i) + { + auto dst = newWorkingSetPtr + func.m_parameterOffsets[i]; + auto src = (uint8_t*)inst->getOperand(i + 2).getPtr(); + + // func.m_parameterOffsets should be initialized to contain parameterCount+1 elements, + // where the last element is the total size of the parameters. 
+ auto nextParamOffset = func.m_parameterOffsets[i + 1]; + memcpy(dst, src, nextParamOffset - func.m_parameterOffsets[i]); + } + ctx->m_currentWorkingSet = newWorkingSetPtr; + ctx->m_currentFuncCode = func.m_codeBuffer.getBuffer(); + ctx->m_currentInst = (VMExecInstHeader*)func.m_codeBuffer.getBuffer(); +} + +static void retHandler(IByteCodeRunner* inCtx, VMExecInstHeader* inst, void*) +{ + auto ctx = convert(inCtx); + if (inst->opcodeExtension != 0) + { + void* resultPtr = nullptr; + if (ctx->m_stack.getCount()) + { + auto callInst = ctx->m_stack.getLast().m_currentInst; + auto callerWorkingSetPtr = (uint8_t*)(ctx->m_workingSetBuffer.getBuffer() + + ctx->m_stack.getLast().m_workingSetOffset); + resultPtr = callerWorkingSetPtr + callInst->getOperand(0).offset; + } + else + { + // If there is no stack frame, we assume the result is stored in the return register. + ctx->m_returnRegister.setCount(inst->opcodeExtension); + resultPtr = ctx->m_returnRegister.getBuffer(); + ctx->m_returnValSize = inst->opcodeExtension; + } + memcpy(resultPtr, inst->getOperand(0).getPtr(), inst->opcodeExtension); + } + + // If we are returning from a main function, there is nothing to pop from the stack frame, + // and we should stop execution. + if (ctx->m_stack.getCount() == 0) + { + ctx->m_currentInst = nullptr; + return; + } + + // Pop the working set. 
+ ctx->popFrame(); +} + +static void jumpHandler(IByteCodeRunner* inCtx, VMExecInstHeader* inst, void*) +{ + auto ctx = convert(inCtx); + ctx->m_currentInst = (VMExecInstHeader*)inst->getOperand(0).getPtr(); +} + +static void jumpIfHandler(IByteCodeRunner* inCtx, VMExecInstHeader* inst, void*) +{ + auto ctx = convert(inCtx); + + auto cond = *(uint32_t*)inst->getOperand(0).getPtr(); + if (cond) + { + ctx->m_currentInst = (VMExecInstHeader*)inst->getOperand(1).getPtr(); + } + else + { + ctx->m_currentInst = (VMExecInstHeader*)inst->getOperand(2).getPtr(); + } +} + +static void getWorkingSetPtrHandler(IByteCodeRunner* inCtx, VMExecInstHeader* inst, void*) +{ + auto ctx = convert(inCtx); + auto dst = (void**)inst->getOperand(0).getPtr(); + auto ptr = (uint8_t*)ctx->m_currentWorkingSet + inst->opcodeExtension; + *dst = ptr; +} + +static void getElementPtrHandler(IByteCodeRunner* ctx, VMExecInstHeader* inst, void*) +{ + SLANG_UNUSED(ctx); + auto dst = (void**)inst->getOperand(0).getPtr(); + auto basePtr = *(uint8_t**)inst->getOperand(1).getPtr(); + auto elementIndex = *(uint32_t*)inst->getOperand(2).getPtr(); + *dst = (uint8_t*)basePtr + (size_t)elementIndex * inst->opcodeExtension; // widen before multiplying so index * stride cannot wrap at 32 bits +} + +static void getElementHandler(IByteCodeRunner* ctx, VMExecInstHeader* inst, void*) +{ + SLANG_UNUSED(ctx); + auto dst = (void*)inst->getOperand(0).getPtr(); + auto basePtr = (uint8_t*)inst->getOperand(1).getPtr(); + auto elementIndex = *(uint32_t*)inst->getOperand(2).getPtr(); + memcpy(dst, basePtr + (size_t)elementIndex * inst->opcodeExtension, inst->opcodeExtension); +} + +static void offsetPtrHandler(IByteCodeRunner* ctx, VMExecInstHeader* inst, void*) +{ + SLANG_UNUSED(ctx); + auto dst = (void**)inst->getOperand(0).getPtr(); + auto basePtr = *(uint8_t**)inst->getOperand(1).getPtr(); + auto offset = *(int32_t*)inst->getOperand(2).getPtr(); + *dst = basePtr + (int64_t)offset * (int64_t)inst->opcodeExtension; // signed 64-bit product: a negative offset must move the pointer backwards, not wrap to a huge unsigned value +} + +void loadHandler8(IByteCodeRunner* ctx, VMExecInstHeader* inst, void*) +{ + SLANG_UNUSED(ctx); + auto dst
= (uint8_t*)inst->getOperand(0).getPtr(); + auto src = *(uint8_t**)inst->getOperand(1).getPtr(); + *dst = *src; +} +void loadHandler16(IByteCodeRunner* ctx, VMExecInstHeader* inst, void*) +{ + SLANG_UNUSED(ctx); + auto dst = (uint16_t*)inst->getOperand(0).getPtr(); + auto src = *(uint16_t**)inst->getOperand(1).getPtr(); + *dst = *src; +} +void loadHandler32(IByteCodeRunner* ctx, VMExecInstHeader* inst, void*) +{ + SLANG_UNUSED(ctx); + auto dst = (uint32_t*)inst->getOperand(0).getPtr(); + auto src = *(uint32_t**)inst->getOperand(1).getPtr(); + *dst = *src; +} +void loadHandler64(IByteCodeRunner* ctx, VMExecInstHeader* inst, void*) +{ + SLANG_UNUSED(ctx); + auto dst = (uint64_t*)inst->getOperand(0).getPtr(); + auto src = *(uint64_t**)inst->getOperand(1).getPtr(); + *dst = *src; +} + +void generalLoadHandler(IByteCodeRunner* ctx, VMExecInstHeader* inst, void*) +{ + SLANG_UNUSED(ctx); + auto dst = (uint8_t*)inst->getOperand(0).getPtr(); + auto src = *(uint8_t**)inst->getOperand(1).getPtr(); + memcpy(dst, src, inst->opcodeExtension); +} + +VMExtFunction getLoadHandler(uint32_t extCode) +{ + switch (extCode) + { + case 1: + return loadHandler8; + case 2: + return loadHandler16; + case 4: + return loadHandler32; + case 8: + return loadHandler64; + default: + return generalLoadHandler; + } +} + +void storeHandler8(IByteCodeRunner* ctx, VMExecInstHeader* inst, void*) +{ + SLANG_UNUSED(ctx); + auto dst = *(uint8_t**)inst->getOperand(0).getPtr(); + auto src = (uint8_t*)inst->getOperand(1).getPtr(); + *dst = *src; +} + +void storeHandler16(IByteCodeRunner* ctx, VMExecInstHeader* inst, void*) +{ + SLANG_UNUSED(ctx); + auto dst = *(uint16_t**)inst->getOperand(0).getPtr(); + auto src = (uint16_t*)inst->getOperand(1).getPtr(); + *dst = *src; +} + +void storeHandler32(IByteCodeRunner* ctx, VMExecInstHeader* inst, void*) +{ + SLANG_UNUSED(ctx); + auto dst = *(uint32_t**)inst->getOperand(0).getPtr(); + auto src = (uint32_t*)inst->getOperand(1).getPtr(); + *dst = *src; +} + +void 
storeHandler64(IByteCodeRunner* ctx, VMExecInstHeader* inst, void*) +{ + SLANG_UNUSED(ctx); + auto dst = *(uint64_t**)inst->getOperand(0).getPtr(); + auto src = (uint64_t*)inst->getOperand(1).getPtr(); + *dst = *src; +} + +void generalStoreHandler(IByteCodeRunner* ctx, VMExecInstHeader* inst, void*) +{ + SLANG_UNUSED(ctx); + auto dst = *(uint8_t**)inst->getOperand(0).getPtr(); + auto src = (uint8_t*)inst->getOperand(1).getPtr(); + memcpy(dst, src, inst->opcodeExtension); +} + +VMExtFunction getStoreHandler(uint32_t extCode) +{ + switch (extCode) + { + case 1: + return storeHandler8; + case 2: + return storeHandler16; + case 4: + return storeHandler32; + case 8: + return storeHandler64; + default: + return generalStoreHandler; + } +} + +void copyHandler8(IByteCodeRunner* ctx, VMExecInstHeader* inst, void*) +{ + SLANG_UNUSED(ctx); + auto dst = (uint8_t*)inst->getOperand(0).getPtr(); + auto src = (uint8_t*)inst->getOperand(1).getPtr(); + *dst = *src; +} + +void copyHandler16(IByteCodeRunner* ctx, VMExecInstHeader* inst, void*) +{ + SLANG_UNUSED(ctx); + auto dst = (uint16_t*)inst->getOperand(0).getPtr(); + auto src = (uint16_t*)inst->getOperand(1).getPtr(); + *dst = *src; +} + +void copyHandler32(IByteCodeRunner* ctx, VMExecInstHeader* inst, void*) +{ + SLANG_UNUSED(ctx); + auto dst = (uint32_t*)inst->getOperand(0).getPtr(); + auto src = (uint32_t*)inst->getOperand(1).getPtr(); + *dst = *src; +} + +void copyHandler64(IByteCodeRunner* ctx, VMExecInstHeader* inst, void*) +{ + SLANG_UNUSED(ctx); + auto dst = (uint64_t*)inst->getOperand(0).getPtr(); + auto src = (uint64_t*)inst->getOperand(1).getPtr(); + *dst = *src; +} + +void generalCopyHandler(IByteCodeRunner* ctx, VMExecInstHeader* inst, void*) +{ + SLANG_UNUSED(ctx); + auto dst = (uint8_t*)inst->getOperand(0).getPtr(); + auto src = (uint8_t*)inst->getOperand(1).getPtr(); + memcpy(dst, src, inst->opcodeExtension); +} + +VMExtFunction getCopyHandler(uint32_t extCode) +{ + switch (extCode) + { + case 1: + return 
copyHandler8; + case 2: + return copyHandler16; + case 4: + return copyHandler32; + case 8: + return copyHandler64; + default: + return generalCopyHandler; + } +} + +template<typename T> +void swizzleHandler(IByteCodeRunner* ctx, VMExecInstHeader* inst, void* userData) +{ + SLANG_UNUSED(ctx); + SLANG_UNUSED(userData); + auto dst = (T*)inst->getOperand(0).getPtr(); + auto src = (T*)inst->getOperand(1).getPtr(); + for (uint32_t i = 2; i < inst->operandCount; ++i) + { + dst[i - 2] = src[inst->getOperand(i).offset]; + } +} + +VMExtFunction getSwizzleHandler(uint32_t extCode) +{ + ArithmeticExtCode arithExtCode; + memcpy(&arithExtCode, &extCode, sizeof(arithExtCode)); + switch (arithExtCode.scalarType) + { + case kSlangByteCodeScalarTypeSignedInt: + switch (arithExtCode.scalarBitWidth) + { + case 0: + return swizzleHandler<int8_t>; + case 1: + return swizzleHandler<int16_t>; + case 2: + return swizzleHandler<int32_t>; + case 3: + return swizzleHandler<int64_t>; + } + case kSlangByteCodeScalarTypeUnsignedInt: + switch (arithExtCode.scalarBitWidth) + { + case 0: + return swizzleHandler<uint8_t>; + case 1: + return swizzleHandler<uint16_t>; + case 2: + return swizzleHandler<uint32_t>; + case 3: + return swizzleHandler<uint64_t>; + } + case kSlangByteCodeScalarTypeFloat: + switch (arithExtCode.scalarBitWidth) + { + case 2: + return swizzleHandler<float>; + case 3: + return swizzleHandler<double>; + default: + return nullptr; // Unsupported scalar bit width + } + } + return nullptr; +} + +template<typename To, typename From, int vectorSize> +void castHandler(IByteCodeRunner* ctx, VMExecInstHeader* inst, void*) +{ + SLANG_UNUSED(ctx); + To* dst = (To*)inst->getOperand(0).getPtr(); + From* src = (From*)inst->getOperand(1).getPtr(); + for (int i = 0; i < vectorSize; ++i) + { + dst[i] = static_cast<To>(src[i]); + } +} + +template<typename From, int vectorSize> +VMExtFunction getCastHandler(uint32_t extCode) +{ + ArithmeticExtCode arithExtCode; + memcpy(&arithExtCode, &extCode, 
sizeof(arithExtCode)); + switch (arithExtCode.scalarType) + { + case kSlangByteCodeScalarTypeSignedInt: // signed destination: use signed types so negative values convert correctly + switch (arithExtCode.scalarBitWidth) + { + case 0: + return castHandler<int8_t, From, vectorSize>; + case 1: + return castHandler<int16_t, From, vectorSize>; + case 2: + return castHandler<int32_t, From, vectorSize>; + case 3: + return castHandler<int64_t, From, vectorSize>; + } + case kSlangByteCodeScalarTypeUnsignedInt: + switch (arithExtCode.scalarBitWidth) + { + case 0: + return castHandler<uint8_t, From, vectorSize>; + case 1: + return castHandler<uint16_t, From, vectorSize>; + case 2: + return castHandler<uint32_t, From, vectorSize>; + case 3: + return castHandler<uint64_t, From, vectorSize>; + } + case kSlangByteCodeScalarTypeFloat: + switch (arithExtCode.scalarBitWidth) + { + case 2: + return castHandler<float, From, vectorSize>; + case 3: + return castHandler<double, From, vectorSize>; + default: + return nullptr; // Unsupported scalar bit width + } + } + return nullptr; +} + +template<int vectorSize> +VMExtFunction getCastHandler(uint32_t extCode) +{ + uint32_t fromExtCode = extCode >> 16; + ArithmeticExtCode arithExtCode; + memcpy(&arithExtCode, &fromExtCode, sizeof(arithExtCode)); + switch (arithExtCode.scalarType) + { + case kSlangByteCodeScalarTypeSignedInt: // signed source: read as signed so widening sign-extends + switch (arithExtCode.scalarBitWidth) + { + case 0: + return getCastHandler<int8_t, vectorSize>(extCode); + case 1: + return getCastHandler<int16_t, vectorSize>(extCode); + case 2: + return getCastHandler<int32_t, vectorSize>(extCode); + case 3: + return getCastHandler<int64_t, vectorSize>(extCode); + } + case kSlangByteCodeScalarTypeUnsignedInt: + switch (arithExtCode.scalarBitWidth) + { + case 0: + return getCastHandler<uint8_t, vectorSize>(extCode); + case 1: + return getCastHandler<uint16_t, vectorSize>(extCode); + case 2: + return getCastHandler<uint32_t, vectorSize>(extCode); + case 3: + return getCastHandler<uint64_t, vectorSize>(extCode); + } + case
kSlangByteCodeScalarTypeFloat: + switch (arithExtCode.scalarBitWidth) + { + case 2: + return getCastHandler<float, vectorSize>(extCode); + case 3: + return getCastHandler<double, vectorSize>(extCode); + default: + return nullptr; // Unsupported scalar bit width + } + } + return nullptr; +} + +VMExtFunction getCastHandler(uint32_t extCode) +{ + uint32_t fromExtCode = extCode >> 16; + ArithmeticExtCode arithExtCode; + memcpy(&arithExtCode, &fromExtCode, sizeof(arithExtCode)); + switch (arithExtCode.vectorSize) + { + case 0: + case 1: + return getCastHandler<1>(extCode); + case 2: + return getCastHandler<2>(extCode); + case 3: + return getCastHandler<3>(extCode); + case 4: + return getCastHandler<4>(extCode); + case 6: + return getCastHandler<6>(extCode); + case 8: + return getCastHandler<8>(extCode); + case 9: + return getCastHandler<9>(extCode); + case 12: + return getCastHandler<12>(extCode); + case 16: + return getCastHandler<16>(extCode); + } + return nullptr; +} + +void printHandler(IByteCodeRunner* inCtx, VMExecInstHeader* inst, void* userData) +{ + auto ctx = convert(inCtx); + SLANG_UNUSED(userData); + const char* formatString = nullptr; + formatString = *(const char**)inst->getOperand(0).getPtr(); + + List<List<uint8_t>> args; + List<const void*> argPtrs; + for (uint32_t i = 1; i < inst->operandCount; ++i) + { + auto& arg = inst->getOperand(i); + List<uint8_t> data; + data.setCount(arg.size); + memcpy(data.getBuffer(), arg.getPtr(), arg.size); + args.add(data); + } + for (auto& arg : args) + { + argPtrs.add(arg.getBuffer()); + } + auto result = + StringUtil::makeStringWithFormatFromArgArray(formatString, argPtrs.getArrayView()); + ctx->m_printCallback(result.getBuffer(), ctx->m_printCallbackUserData); +} + + +VMExtFunction mapInstToFunction( + VMInstHeader* instHeader, + VMModuleView* module, + Dictionary<String, slang::VMExtFunction>& extInstHandlers) +{ + switch (instHeader->opcode) + { + case VMOp::Nop: + return nopHandler; + case VMOp::Add: + return 
binaryArithmeticInstHandler<AddScalarFunc>(instHeader->opcodeExtension); + case VMOp::Sub: + return binaryArithmeticInstHandler<SubScalarFunc>(instHeader->opcodeExtension); + case VMOp::Mul: + return binaryArithmeticInstHandler<MulScalarFunc>(instHeader->opcodeExtension); + case VMOp::Div: + return binaryArithmeticInstHandler<DivScalarFunc>(instHeader->opcodeExtension); + case VMOp::Rem: + return binaryArithmeticInstHandler<ModScalarFunc>(instHeader->opcodeExtension); + case VMOp::And: + return binaryArithmeticLogicalInstHandler<AndScalarFunc>(instHeader->opcodeExtension); + case VMOp::Or: + return binaryArithmeticLogicalInstHandler<OrScalarFunc>(instHeader->opcodeExtension); + case VMOp::BitAnd: + return binaryArithmeticLogicalInstHandler<BitAndScalarFunc>(instHeader->opcodeExtension); + case VMOp::BitOr: + return binaryArithmeticLogicalInstHandler<BitOrScalarFunc>(instHeader->opcodeExtension); + case VMOp::BitXor: + return binaryArithmeticLogicalInstHandler<BitXorScalarFunc>(instHeader->opcodeExtension); + case VMOp::Shl: + return binaryArithmeticIntInstHandler<ShlScalarFunc>(instHeader->opcodeExtension); + case VMOp::Shr: + return binaryArithmeticIntInstHandler<ShrScalarFunc>(instHeader->opcodeExtension); + case VMOp::Less: + return binaryArithmeticCompareInstHandler<LessScalarFunc>(instHeader->opcodeExtension); + case VMOp::Leq: + return binaryArithmeticCompareInstHandler<LeqScalarFunc>(instHeader->opcodeExtension); + case VMOp::Greater: + return binaryArithmeticCompareInstHandler<GreaterScalarFunc>(instHeader->opcodeExtension); + case VMOp::Geq: + return binaryArithmeticCompareInstHandler<GeqScalarFunc>(instHeader->opcodeExtension); + case VMOp::Equal: + return binaryArithmeticCompareInstHandler<EqualScalarFunc>(instHeader->opcodeExtension); + case VMOp::Neq: + return binaryArithmeticCompareInstHandler<NeqScalarFunc>(instHeader->opcodeExtension); + case VMOp::Neg: + return negInstHandler<NegScalarFunc>(instHeader->opcodeExtension); + case VMOp::Not: + return 
unaryArithmeticLogicalInstHandler<NotScalarFunc>(instHeader->opcodeExtension); + case VMOp::BitNot: + return unaryArithmeticIntInstHandler<BitNotScalarFunc>(instHeader->opcodeExtension); + case VMOp::Ret: + return retHandler; + case VMOp::Call: + return callHandler; + case VMOp::Jump: + return jumpHandler; + case VMOp::JumpIf: + return jumpIfHandler; + case VMOp::Load: + return getLoadHandler(instHeader->opcodeExtension); + case VMOp::Store: + return getStoreHandler(instHeader->opcodeExtension); + case VMOp::Copy: + return getCopyHandler(instHeader->opcodeExtension); + case VMOp::GetWorkingSetPtr: + return getWorkingSetPtrHandler; + case VMOp::GetElementPtr: + return getElementPtrHandler; + case VMOp::OffsetPtr: + return offsetPtrHandler; + case VMOp::GetElement: + return getElementHandler; + case VMOp::Swizzle: + return getSwizzleHandler(instHeader->opcodeExtension); + case VMOp::Cast: + return getCastHandler(instHeader->opcodeExtension); + case VMOp::CallExt: + { + if (instHeader->getOperand(0).offset >= module->stringCount) + return nullptr; + auto funcName = (const char*)module->constants + + module->stringOffsets[instHeader->getOperand(0).offset]; + VMExtFunction handler = nullptr; + if (!extInstHandlers.tryGetValue(funcName, handler)) + return nullptr; + return handler; + } + case VMOp::Print: + return printHandler; + } + return VMExtFunction(); +} + +} // namespace Slang diff --git a/source/slang/slang-vm-inst-impl.h b/source/slang/slang-vm-inst-impl.h new file mode 100644 index 000000000..4331a79ec --- /dev/null +++ b/source/slang/slang-vm-inst-impl.h @@ -0,0 +1,16 @@ +#ifndef SLANG_VM_INST_IMPL_H +#define SLANG_VM_INST_IMPL_H + +#include "slang-vm-bytecode.h" + +namespace Slang +{ + +slang::VMExtFunction mapInstToFunction( + VMInstHeader* instHeader, + VMModuleView* module, + Dictionary<String, slang::VMExtFunction>& extInstHandlers); + +} // namespace Slang + +#endif diff --git a/source/slang/slang-vm.cpp b/source/slang/slang-vm.cpp new file mode 100644 
index 000000000..05d53acb8 --- /dev/null +++ b/source/slang/slang-vm.cpp @@ -0,0 +1,254 @@ +#include "slang-vm.h" + +#include "core/slang-blob.h" +#include "slang-vm-inst-impl.h" + +namespace Slang +{ + +// Our VM insts need to be 8-byte aligned, so we can replace the opcode with function pointers and +// sectionId with data pointers. +static_assert(sizeof(VMOperand) % 8 == 0); +static_assert(sizeof(VMInstHeader) % 8 == 0); +static_assert(sizeof(VMOperand) == sizeof(VMExecOperand)); +static_assert(sizeof(VMInstHeader) == sizeof(VMExecInstHeader)); + +ISlangUnknown* ByteCodeInterpreter::getInterface(const Guid& guid) +{ + if (guid == ISlangUnknown::getTypeGuid() || guid == IByteCodeRunner::getTypeGuid()) + return static_cast<IByteCodeRunner*>(this); + + return nullptr; +} + +SlangResult ByteCodeInterpreter::prepareModuleForExecution() +{ + m_stringLits.clear(); + m_stringLits.setCount(m_moduleView.stringCount); + for (uint32_t i = 0; i < m_moduleView.stringCount; i++) + { + auto strOffset = m_moduleView.stringOffsets[i]; + const char* str = (const char*)m_moduleView.constants + strOffset; + m_stringLits[i] = str; + } + m_stringLitsPtr = m_stringLits.getBuffer(); + + m_functions.setCount(m_moduleView.functionCount); + for (uint32_t i = 0; i < m_moduleView.functionCount; i++) + { + auto func = m_moduleView.getFunction(i); + auto& exeFunc = m_functions[i]; + exeFunc.m_codeBuffer.setCount(func.header->codeSize / sizeof(uint64_t)); + exeFunc.m_header = func.header; + for (uint32_t j = 0; j < func.header->parameterCount; j++) + { + exeFunc.m_parameterOffsets.add(func.header->getParameterOffset(j)); + } + exeFunc.m_parameterOffsets.add(func.header->parameterSizeInBytes); + + // Copy the code into the executable function buffer + memcpy(exeFunc.m_codeBuffer.getBuffer(), func.functionCode, func.header->codeSize); + + // Replace the instruction headers with function pointers + for (auto inst : exeFunc) + { + VMInstHeader* instHeader = reinterpret_cast<VMInstHeader*>(inst); + 
auto handler = mapInstToFunction(instHeader, &m_moduleView, m_extInstHandlers); + if (!handler) + { + StringBuilder instStr; + printVMInst(instStr, &m_moduleView, instHeader); + reportError( + "Cannot find execution handler for instruction %s\n", + instStr.toString().getBuffer()); + return SLANG_FAIL; + } + inst->functionPtr = handler; + for (uint32_t operandIdx = 0; operandIdx < instHeader->operandCount; operandIdx++) + { + auto& operand = instHeader->getOperand(operandIdx); + auto& execOpernad = inst->getOperand(operandIdx); + switch (operand.sectionId) + { + case kSlangByteCodeSectionConstants: + execOpernad.section = &m_moduleView.constants; + break; + case kSlangByteCodeSectionInsts: + execOpernad.section = (uint8_t**)&m_currentFuncCode; + break; + case kSlangByteCodeSectionWorkingSet: + execOpernad.section = (uint8_t**)&m_currentWorkingSet; + break; + case kSlangByteCodeSectionStrings: + execOpernad.section = (uint8_t**)&m_stringLitsPtr; + execOpernad.offset *= sizeof(const char*); + break; + } + } + } + } + + return SLANG_OK; +} + +SLANG_NO_THROW SlangResult SLANG_MCALL ByteCodeInterpreter::loadModule(IBlob* moduleBlob) +{ + m_stack.reserve(128); + m_workingSetBuffer.reserve(1024 * 1024); // Reserve 1MB for working set + m_currentWorkingSet = m_workingSetBuffer.getBuffer(); + + m_errorBuilder.clear(); + m_code.addRange((uint8_t*)(moduleBlob->getBufferPointer()), moduleBlob->getBufferSize()); + SLANG_RETURN_ON_FAIL( + initVMModule(m_code.getBuffer(), (uint32_t)moduleBlob->getBufferSize(), &m_moduleView)); + SLANG_RETURN_ON_FAIL(prepareModuleForExecution()); + return SLANG_OK; +} + +SLANG_NO_THROW void SLANG_MCALL ByteCodeInterpreter::getErrorString(slang::IBlob** outBlob) +{ + *outBlob = StringBlob::moveCreate(m_errorBuilder.produceString()).detach(); + m_errorBuilder.clear(); +} + +SLANG_NO_THROW int SLANG_MCALL ByteCodeInterpreter::findFunctionByName(const char* name) +{ + for (uint32_t i = 0; i < m_moduleView.functionCount; i++) + { + auto func = 
m_moduleView.getFunction(i); + if (UnownedStringSlice(func.name) == name) + { + return (int)i; + } + } + return -1; // Function not found +} + +SLANG_NO_THROW SlangResult SLANG_MCALL +ByteCodeInterpreter::getFunctionInfo(uint32_t index, slang::ByteCodeFuncInfo* outInfo) +{ + if (index >= m_moduleView.functionCount) + { + return SLANG_FAIL; + } + auto func = m_moduleView.getFunction(index); + outInfo->parameterCount = func.header->parameterCount; + outInfo->returnValueSize = func.header->returnValueSizeInBytes; + return SLANG_OK; +} + +SLANG_NO_THROW SlangResult SLANG_MCALL +ByteCodeInterpreter::selectFunctionByIndex(uint32_t functionIndex) +{ + if (functionIndex >= m_moduleView.functionCount) + { + reportError( + "Function index %u out of range [0, %u)", + functionIndex, + m_moduleView.functionCount); + return SLANG_FAIL; + } + auto func = m_moduleView.getFunction(functionIndex); + m_currentFuncCode = m_functions[functionIndex].m_codeBuffer.getBuffer(); + m_currentInst = reinterpret_cast<VMExecInstHeader*>(m_currentFuncCode); + m_workingSetBuffer.setCount(func.header->workingSetSizeInBytes / sizeof(uint64_t)); + m_currentWorkingSet = m_workingSetBuffer.getBuffer(); + return SLANG_OK; +} + +SLANG_NO_THROW SlangResult SLANG_MCALL +ByteCodeInterpreter::execute(void* argumentData, size_t argumentSize) +{ + if (!m_currentInst) + { + reportError("No function selected for execution"); + return SLANG_FAIL; + } + if (!m_currentWorkingSet) + { + reportError("No working set allocated for execution"); + return SLANG_FAIL; + } + if ((uint8_t*)m_currentWorkingSet + argumentSize > + (uint8_t*)(m_workingSetBuffer.getBuffer() + m_workingSetBuffer.getCount())) + { + reportError("Argument size exceeds working set."); + return SLANG_FAIL; + } + // Copy the arguments into the working set + if (argumentData && argumentSize > 0) + { + memcpy(m_currentWorkingSet, argumentData, argumentSize); + } + m_returnValSize = 0; + while (m_currentInst) + { + auto nextInst = 
m_currentInst->getNextInst(); + auto currentInst = m_currentInst; + m_currentInst = nextInst; + currentInst->functionPtr(this, currentInst, m_extInstHandlerUserData); + } + return SLANG_OK; +} + +ByteCodeInterpreter::ByteCodeInterpreter() +{ + m_printCallback = defaultPrintCallback; + m_printCallbackUserData = this; +} + +SLANG_NO_THROW SlangResult SLANG_MCALL +ByteCodeInterpreter::setPrintCallback(slang::VMPrintFunc callback, void* userData) +{ + m_printCallback = callback; + m_printCallbackUserData = userData; + return SLANG_OK; +} + +void ByteCodeInterpreter::defaultPrintCallback(const char* str, void* userData) +{ + SLANG_UNUSED(userData); + printf("%s", str); +} + +ExecutableFunction::InstIterator ExecutableFunction::begin() +{ + ExecutableFunction::InstIterator iter; + iter.codePtr = (uint8_t*)m_codeBuffer.getBuffer(); + return iter; +} + +ExecutableFunction::InstIterator ExecutableFunction::end() +{ + ExecutableFunction::InstIterator iter; + iter.codePtr = (uint8_t*)(m_codeBuffer.getBuffer() + m_codeBuffer.getCount()); + return iter; +} + + +} // namespace Slang + + +SLANG_EXTERN_C SLANG_API SlangResult slang_createByteCodeRunner( + const slang::ByteCodeRunnerDesc* desc, + slang::IByteCodeRunner** outByteCodeRunner) +{ + SLANG_UNUSED(desc); + Slang::RefPtr<Slang::ByteCodeInterpreter> runner = new Slang::ByteCodeInterpreter(); + *outByteCodeRunner = static_cast<slang::IByteCodeRunner*>(runner.detach()); + return SLANG_OK; +} + +SLANG_EXTERN_C SLANG_API SlangResult +slang_disassembleByteCode(slang::IBlob* moduleBlob, slang::IBlob** outDisassemblyBlob) +{ + Slang::VMModuleView moduleView; + SLANG_RETURN_ON_FAIL(Slang::initVMModule( + (uint8_t*)moduleBlob->getBufferPointer(), + (uint32_t)moduleBlob->getBufferSize(), + &moduleView)); + Slang::StringBuilder sb; + sb << moduleView; + *outDisassemblyBlob = Slang::StringBlob::moveCreate(sb.produceString()).detach(); + return SLANG_OK; +} diff --git a/source/slang/slang-vm.h b/source/slang/slang-vm.h new file mode 
100644 index 000000000..7f85e398f --- /dev/null +++ b/source/slang/slang-vm.h @@ -0,0 +1,140 @@ +#ifndef SLANG_VM_H +#define SLANG_VM_H + +#include "core/slang-string-util.h" +#include "slang-vm-bytecode.h" + +using namespace slang; + +namespace Slang +{ + +struct ByteCodeExecutionContext +{ + void* currentWorkingSet; + uint32_t currentWorkingSetSizeInBytes; +}; + +class ByteCodeInterpreter; + +// Represents a relocated function code ready for execution. +// Relocated functions are VMInsts allocated in a 8-byte aligned buffer, and instruction headers +// Replaced with actual function pointers that can execute the instruction. +class ExecutableFunction +{ +public: + typedef VMInstIterator<VMExecOperand, VMExecInstHeader> InstIterator; + List<uint64_t> m_codeBuffer; + VMFuncHeader* m_header; + List<uint32_t> m_parameterOffsets; + + InstIterator begin(); + InstIterator end(); +}; + +struct StackFrame +{ + VMExecInstHeader* m_currentInst = nullptr; + void* m_currentFuncCode = nullptr; + size_t m_workingSetOffset = 0; +}; + +class ByteCodeInterpreter : public RefObject, public IByteCodeRunner +{ +public: + SLANG_REF_OBJECT_IUNKNOWN_ALL + ISlangUnknown* getInterface(const Guid& guid); + +public: + VMModuleView m_moduleView; + List<uint8_t> m_code; + StringBuilder m_errorBuilder; + List<ExecutableFunction> m_functions; + Dictionary<String, VMExtFunction> m_extInstHandlers; + SlangResult prepareModuleForExecution(); + void* m_extInstHandlerUserData = nullptr; + List<uint8_t> m_returnRegister; + List<uint64_t> m_workingSetBuffer; + List<StackFrame> m_stack; + List<const char*> m_stringLits; + const char** m_stringLitsPtr = nullptr; + + size_t m_returnValSize = 0; + + void pushFrame(uint32_t size) + { + StackFrame frame; + frame.m_workingSetOffset = + (uint32_t)((uint64_t*)m_currentWorkingSet - m_workingSetBuffer.getBuffer()); + m_stack.add(frame); + auto stackBufferCount = m_workingSetBuffer.getCount(); + m_workingSetBuffer.setCount(m_workingSetBuffer.getCount() + size / 
sizeof(uint64_t)); + m_currentWorkingSet = m_workingSetBuffer.getBuffer() + stackBufferCount; + } + void popFrame() + { + auto& stackFrame = m_stack.getLast(); + auto lastWorkingSetBufferCount = + (uint32_t)((uint64_t*)m_currentWorkingSet - m_workingSetBuffer.getBuffer()); + m_workingSetBuffer.setCount(lastWorkingSetBufferCount); + m_currentInst = stackFrame.m_currentInst->getNextInst(); + m_currentFuncCode = stackFrame.m_currentFuncCode; + m_currentWorkingSet = m_workingSetBuffer.getBuffer() + stackFrame.m_workingSetOffset; + m_stack.removeLast(); + } + + VMExecInstHeader* m_currentInst = nullptr; + void* m_currentFuncCode = nullptr; + void* m_currentWorkingSet = nullptr; + + VMPrintFunc m_printCallback = nullptr; + void* m_printCallbackUserData = nullptr; + + template<typename... Args> + void reportError(const char* format, Args... args) + { + m_errorBuilder.append(StringUtil::makeStringWithFormat(format, args...)); + m_errorBuilder.append("\n"); + } + + static void defaultPrintCallback(const char* message, void* userData); + ByteCodeInterpreter(); + +public: + virtual SLANG_NO_THROW SlangResult SLANG_MCALL loadModule(IBlob* moduleBlob) override; + virtual SLANG_NO_THROW SlangResult SLANG_MCALL + selectFunctionByIndex(uint32_t functionIndex) override; + virtual SLANG_NO_THROW int SLANG_MCALL findFunctionByName(const char* name) override; + + virtual SLANG_NO_THROW SlangResult SLANG_MCALL + getFunctionInfo(uint32_t index, ByteCodeFuncInfo* outInfo) override; + virtual SLANG_NO_THROW void* SLANG_MCALL getCurrentWorkingSet() override + { + return m_currentWorkingSet; + } + virtual SLANG_NO_THROW SlangResult SLANG_MCALL + execute(void* argumentData, size_t argumentSize) override; + virtual SLANG_NO_THROW void SLANG_MCALL getErrorString(slang::IBlob** outBlob) override; + virtual SLANG_NO_THROW void* SLANG_MCALL getReturnValue(size_t* outValueSize) override + { + *outValueSize = m_returnValSize; + return m_returnRegister.getBuffer(); + } + virtual SLANG_NO_THROW void 
SLANG_MCALL setExtInstHandlerUserData(void* userData) override + { + m_extInstHandlerUserData = userData; + } + virtual SLANG_NO_THROW SlangResult SLANG_MCALL + registerExtCall(const char* name, VMExtFunction functionPtr) override + { + m_extInstHandlers[name] = functionPtr; + return SLANG_OK; + } + + virtual SLANG_NO_THROW SlangResult SLANG_MCALL + setPrintCallback(VMPrintFunc callback, void* userData) override; +}; + +} // namespace Slang + +#endif diff --git a/source/slang/slang.cpp b/source/slang/slang.cpp index df07a7637..613fb7090 100644 --- a/source/slang/slang.cpp +++ b/source/slang/slang.cpp @@ -5099,6 +5099,8 @@ RefPtr<EntryPoint> Module::findAndCheckEntryPoint( if (auto existingEntryPoint = findEntryPointByName(name)) return existingEntryPoint; + SLANG_AST_BUILDER_RAII(m_astBuilder); + // If the function hasn't been marked as [shader], then it won't be discovered // by findEntryPointByName. We need to route this to the `findAndValidateEntryPoint` // function. To do that we need to setup a FrontEndCompileRequest and a |
