diff options
43 files changed, 4658 insertions, 20 deletions
diff --git a/CMakeLists.txt b/CMakeLists.txt index 081a40fb4..8fa900d83 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -118,6 +118,7 @@ option(SLANG_ENABLE_PREBUILT_BINARIES "Enable using prebuilt binaries" ON) option(SLANG_ENABLE_GFX "Enable gfx targets" ON) option(SLANG_ENABLE_SLANGD "Enable language server target" ON) option(SLANG_ENABLE_SLANGC "Enable standalone compiler target" ON) +option(SLANG_ENABLE_SLANGI "Enable Slang interpreter target" ON) option(SLANG_ENABLE_SLANGRT "Enable runtime target" ON) option( SLANG_ENABLE_SLANG_GLSLANG diff --git a/docs/building.md b/docs/building.md index 6597c5199..de76d2825 100644 --- a/docs/building.md +++ b/docs/building.md @@ -159,6 +159,7 @@ See the [documentation on testing](../tools/slang-test/README.md) for more infor | `SLANG_ENABLE_GFX` | `TRUE` | Enable gfx targets | | `SLANG_ENABLE_SLANGD` | `TRUE` | Enable language server target | | `SLANG_ENABLE_SLANGC` | `TRUE` | Enable standalone compiler target | +| `SLANG_ENABLE_SLANGI` | `TRUE` | Enable Slang interpreter target | | `SLANG_ENABLE_SLANGRT` | `TRUE` | Enable runtime target | | `SLANG_ENABLE_SLANG_GLSLANG` | `TRUE` | Enable glslang dependency and slang-glslang wrapper target | | `SLANG_ENABLE_TESTS` | `TRUE` | Enable test targets, requires SLANG_ENABLE_GFX, SLANG_ENABLE_SLANGD and SLANG_ENABLE_SLANGRT | diff --git a/docs/user-guide/a3-02-reference-capability-atoms.md b/docs/user-guide/a3-02-reference-capability-atoms.md index a9455761c..ea3e5952a 100644 --- a/docs/user-guide/a3-02-reference-capability-atoms.md +++ b/docs/user-guide/a3-02-reference-capability-atoms.md @@ -45,6 +45,9 @@ Targets `wgsl` > Represents the WebGPU shading language code generation target. +`slangvm` +> Represents the Slang VM bytecode target. 
+ Stages ---------------------- *Capabilities to specify code generation stages (`vertex`, `fragment`...)* @@ -67,6 +70,9 @@ Stages `geometry` > Geometry shader stage +`dispatch` +> Dispatch shader stage + `pixel` > Pixel shader stage diff --git a/include/slang.h b/include/slang.h index 1f50c2648..e61c3cd7f 100644 --- a/include/slang.h +++ b/include/slang.h @@ -622,6 +622,8 @@ typedef uint32_t SlangSizeT; SLANG_WGSL, ///< WebGPU shading language SLANG_WGSL_SPIRV_ASM, ///< SPIR-V assembly via WebGPU shading language SLANG_WGSL_SPIRV, ///< SPIR-V via WebGPU shading language + + SLANG_HOST_VM, ///< Bytecode that can be interpreted by the Slang VM SLANG_TARGET_COUNT_OF, }; @@ -805,6 +807,7 @@ typedef uint32_t SlangSizeT; SLANG_STAGE_CALLABLE, SLANG_STAGE_MESH, SLANG_STAGE_AMPLIFICATION, + SLANG_STAGE_DISPATCH, // SLANG_STAGE_COUNT, @@ -4574,6 +4577,125 @@ SLANG_EXTERN_C SLANG_API void slang_shutdown(); */ SLANG_EXTERN_C SLANG_API const char* slang_getLastInternalErrorMessage(); +// Slang VM +namespace slang +{ + +enum class OperandDataType +{ + General = 0, // General data type, can be any type. + Int32 = 1, // 32-bit integer. + Int64 = 2, // 64-bit integer. + Float32 = 3, // 32-bit floating-point number. + Float64 = 4, // 64-bit floating-point number. + String = 5, // String data type, represented as a pointer to a null-terminated string. +}; + +struct VMExecOperand +{ + uint8_t** section; // Pointer to the section start pointer. + #if SLANG_PTR_IS_32 + uint32_t padding; + #endif + uint32_t type : 8; // type of the operand data. 
+ uint32_t size : 24; + uint32_t offset; + void* getPtr() const { return *section + offset; } + OperandDataType getType() const { return (OperandDataType)type; } +}; + +struct VMExecInstHeader; +class IByteCodeRunner; + +typedef void (*VMExtFunction)(IByteCodeRunner* context, VMExecInstHeader* inst, void* userData); +typedef void (*VMPrintFunc)(const char* message, void* userData); + +struct VMExecInstHeader +{ + VMExtFunction functionPtr; // Pointer to the function that executes this instruction. + #if SLANG_PTR_IS_32 + uint32_t padding; + #endif + uint32_t opcodeExtension; + uint32_t operandCount; + VMExecInstHeader* getNextInst() + { + return (VMExecInstHeader*)((VMExecOperand*)(this + 1) + operandCount); + } + VMExecOperand& getOperand(SlangInt index) const + { + return *((VMExecOperand*)(this + 1) + index); + } +}; + +struct ByteCodeFuncInfo +{ + uint32_t parameterCount; + uint32_t returnValueSize; +}; + +struct ByteCodeRunnerDesc +{ + /** The size of this structure, in bytes. + */ + size_t structSize = sizeof(ByteCodeRunnerDesc); +}; + +/// Represents a byte code runner that can execute Slang byte code. +class IByteCodeRunner : public ISlangUnknown +{ +public: + // {AFDAB195-361F-42CB-9513-9006261DD8CD} + SLANG_COM_INTERFACE(0xafdab195, 0x361f, 0x42cb, {0x95, 0x13, 0x90, 0x6, 0x26, 0x1d, 0xd8, 0xcd}) + + /// Load a byte code module into the execution context. + virtual SLANG_NO_THROW SlangResult SLANG_MCALL loadModule(IBlob* moduleBlob) = 0; + + /// Select a function for execution. + virtual SLANG_NO_THROW SlangResult SLANG_MCALL + selectFunctionByIndex(uint32_t functionIndex) = 0; + + virtual SLANG_NO_THROW int SLANG_MCALL findFunctionByName(const char* name) = 0; + + virtual SLANG_NO_THROW SlangResult SLANG_MCALL + getFunctionInfo(uint32_t index, ByteCodeFuncInfo* outInfo) = 0; + + /// Obtain the current working set memory for the selected function. + virtual SLANG_NO_THROW void* SLANG_MCALL getCurrentWorkingSet() = 0; + + /// Execute the selected function. 
+ virtual SLANG_NO_THROW SlangResult SLANG_MCALL + execute(void* argumentData, size_t argumentSize) = 0; + + /// Query the error string. + virtual SLANG_NO_THROW void SLANG_MCALL getErrorString(IBlob** outBlob) = 0; + + /// Retrieve the return value of the last executed function. + virtual SLANG_NO_THROW void* SLANG_MCALL getReturnValue(size_t* outValueSize) = 0; + + /// Set the user data for the external instruction handler. + virtual SLANG_NO_THROW void SLANG_MCALL setExtInstHandlerUserData(void* userData) = 0; + + /// Register an external function that can be called from the byte code. + virtual SLANG_NO_THROW SlangResult SLANG_MCALL + registerExtCall(const char* name, VMExtFunction functionPtr) = 0; + + /// Set a callback function to print messages from the byte code runner. + virtual SLANG_NO_THROW SlangResult SLANG_MCALL + setPrintCallback(VMPrintFunc callback, void* userData) = 0; +}; + +} // namespace slang + +/// Create a byte code runner that can execute Slang byte code. +SLANG_EXTERN_C SLANG_API SlangResult slang_createByteCodeRunner( + const slang::ByteCodeRunnerDesc* desc, + slang::IByteCodeRunner** outByteCodeRunner); + +/// Disassemble a Slang byte code blob into human-readable text. 
+SLANG_EXTERN_C SLANG_API SlangResult +slang_disassembleByteCode(slang::IBlob* moduleBlob, slang::IBlob** outDisassemblyBlob); + namespace slang { inline SlangResult createGlobalSession(slang::IGlobalSession** outGlobalSession) diff --git a/source/compiler-core/slang-artifact-desc-util.cpp b/source/compiler-core/slang-artifact-desc-util.cpp index c1575ba16..5f06853ec 100644 --- a/source/compiler-core/slang-artifact-desc-util.cpp +++ b/source/compiler-core/slang-artifact-desc-util.cpp @@ -325,6 +325,9 @@ SLANG_HIERARCHICAL_ENUM(ArtifactStyle, SLANG_ARTIFACT_STYLE, SLANG_ARTIFACT_STYL return Desc::make(Kind::Assembly, Payload::WGSL_SPIRV, Style::Kernel, 0); case SLANG_WGSL_SPIRV: return Desc::make(Kind::ObjectCode, Payload::WGSL_SPIRV, Style::Kernel, 0); + + case SLANG_HOST_VM: + return Desc::make(Kind::ObjectCode, Payload::UniversalCPU, Style::Host, 0); default: break; } diff --git a/source/core/slang-string-escape-util.cpp b/source/core/slang-string-escape-util.cpp index 42e757b23..0645d94ba 100644 --- a/source/core/slang-string-escape-util.cpp +++ b/source/core/slang-string-escape-util.cpp @@ -1184,4 +1184,19 @@ StringEscapeUtil::Handler* StringEscapeUtil::getHandler(Style style) return SLANG_OK; } +String StringEscapeUtil::escapeString(UnownedStringSlice input, StringEscapeUtil::Style style) +{ + StringBuilder sb; + auto handler = StringEscapeUtil::getHandler(style); + StringEscapeUtil::appendQuoted(handler, input, sb); + return sb.produceString(); +} + +String StringEscapeUtil::unescapeString(UnownedStringSlice input, StringEscapeUtil::Style style) +{ + StringBuilder sb; + auto handler = StringEscapeUtil::getHandler(style); + StringEscapeUtil::appendUnquoted(handler, input, sb); + return sb.produceString(); +} } // namespace Slang diff --git a/source/core/slang-string-escape-util.h b/source/core/slang-string-escape-util.h index 6e02d772d..ece8de79f 100644 --- a/source/core/slang-string-escape-util.h +++ b/source/core/slang-string-escape-util.h @@ -113,6 +113,9 
@@ struct StringEscapeUtil
         Handler* handler,
         const UnownedStringSlice& slice,
         StringBuilder& out);
+
+    static String escapeString(UnownedStringSlice input, Style style = Style::Slang);
+    static String unescapeString(UnownedStringSlice input, Style style = Style::Slang);
 };
diff --git a/source/core/slang-string-util.cpp b/source/core/slang-string-util.cpp
index ba234e30b..65758fd14 100644
--- a/source/core/slang-string-util.cpp
+++ b/source/core/slang-string-util.cpp
@@ -390,6 +390,191 @@ UnownedStringSlice StringUtil::getAtInSplit(
     return builder;
 }
 
+/// Read the next argument as a `T` and advance `argIndex`.
+/// When `argIndex` is already past the end of `ptrToArgs`, nothing is consumed and a
+/// value-initialized `T` is returned.
+template<typename T>
+static T readValue(ArrayView<const void*> ptrToArgs, Count& argIndex)
+{
+    if (argIndex < ptrToArgs.getCount())
+    {
+        T value;
+        // memcpy avoids assuming the argument storage is suitably aligned for `T`.
+        memcpy(&value, ptrToArgs[argIndex], sizeof(T));
+        argIndex++;
+        return value;
+    }
+    return T();
+}
+
+/// Expand a printf-style `format`, taking the arguments from `ptrToArgs`, where
+/// `ptrToArgs[i]` points at the storage of the i-th argument.
+///
+/// Recognized conversions (the conversion character is matched case-insensitively,
+/// except for `%s`):
+///   - `%s`, only directly after the `%`: the argument storage holds a `const char*`.
+///   - `d x i u o c`: the argument is read as `int`, or as `int64_t` after `ll`.
+///   - `e f g`: the argument is read as `float`, or as `double` after `l` or `ll`.
+///   - `%%`: a literal `%`.
+/// An optional `-`, width and `.precision` are forwarded to `StringUtil::appendFormat`
+/// as part of the specification. `%n` and unrecognized conversions produce no output
+/// and consume no argument.
+///
+/// Returns an empty string when `format` is null. Arguments missing from `ptrToArgs`
+/// are formatted as zero (`%s` appends nothing). A specification that is cut off by
+/// the end of `format` before its conversion character is dropped.
+String StringUtil::makeStringWithFormatFromArgArray(
+    const char* format,
+    ArrayView<const void*> ptrToArgs)
+{
+    if (!format)
+    {
+        return String();
+    }
+    StringBuilder builder;
+    const char* ptr = format;
+    Count argIndex = 0;
+    auto consumeString = [&]()
+    {
+        if (argIndex < ptrToArgs.getCount())
+        {
+            const char* strPtr = *(const char**)ptrToArgs[argIndex];
+            argIndex++;
+            if (strPtr)
+            {
+                // Append the string to the builder
+                builder.append(strPtr);
+            }
+        }
+    };
+// Step to the next character while still inside a conversion specification. Reaching
+// the terminator there means the specification is incomplete, so the text built so
+// far is returned as the result.
+#define ADVANCE_PTR                     \
+    ptr++;                              \
+    if (!*ptr)                          \
+    {                                   \
+        return builder.produceString(); \
+    }
+
+    while (*ptr)
+    {
+        if (*ptr == '%')
+        {
+            const char* formatStart = ptr;
+            ADVANCE_PTR;
+            if (*ptr == 's')
+            {
+                // If we have a %s, then we want to append the data
+                consumeString();
+                // Move past the 's'; the loop condition handles the end of the string.
+                ptr++;
+                continue;
+            }
+            if (*ptr == '-')
+            {
+                // Left-justification flag: keep parsing the specification.
+                ADVANCE_PTR;
+            }
+            while (CharUtil::isDigit(*ptr))
+            {
+                // Skip the digits of the field width.
+                ADVANCE_PTR;
+            }
+            if (*ptr == '.')
+            {
+                ADVANCE_PTR;
+                while (CharUtil::isDigit(*ptr))
+                {
+                    // Skip the digits of the precision after the '.'
+                    ADVANCE_PTR;
+                }
+            }
+            // Number of 'l'/'L' length modifiers seen: 0, 1 or 2.
+            int isLong = 0;
+            if (*ptr == 'l' || *ptr == 'L')
+            {
+                ADVANCE_PTR;
+                isLong = 1;
+                if (*ptr == 'l' || *ptr == 'L')
+                {
+                    ADVANCE_PTR;
+                    isLong = 2;
+                }
+            }
+            // Every ADVANCE_PTR above returns on the terminator, so `*ptr` is a real
+            // character here.
+            const char typeChar = *ptr;
+            // Plain step, not ADVANCE_PTR: the specification is complete once its
+            // conversion character has been read. Returning early when that character
+            // is the last one in `format` (e.g. "x=%d" or "100%%") would skip the
+            // switch below and silently drop the value.
+            ptr++;
+            String formatStr = UnownedStringSlice(formatStart, ptr);
+            switch (CharUtil::toLower(typeChar))
+            {
+            case 'd':
+            case 'x':
+            case 'i':
+            case 'u':
+            case 'o':
+            case 'c':
+                // NOTE(review): with a single 'l' the specification is forwarded as
+                // e.g. "%ld" while the argument is still read and passed as `int`;
+                // confirm that is what `appendFormat` expects on LP64 targets.
+                if (isLong == 2)
+                {
+                    StringUtil::appendFormat(
+                        builder,
+                        formatStr.getBuffer(),
+                        readValue<int64_t>(ptrToArgs, argIndex));
+                }
+                else
+                {
+                    StringUtil::appendFormat(
+                        builder,
+                        formatStr.getBuffer(),
+                        readValue<int>(ptrToArgs, argIndex));
+                }
+                break;
+            case 'e':
+            case 'f':
+            case 'g':
+                if (isLong != 0)
+                {
+                    StringUtil::appendFormat(
+                        builder,
+                        formatStr.getBuffer(),
+                        readValue<double>(ptrToArgs, argIndex));
+                }
+                else
+                {
+                    StringUtil::appendFormat(
+                        builder,
+                        formatStr.getBuffer(),
+                        readValue<float>(ptrToArgs, argIndex));
+                }
+                break;
+            case 'n':
+                break;
+            case '%':
+                // If we have a '%%' then we want to append a single '%'
+                builder.appendChar('%');
+                continue;
+            }
+        }
+        else
+        {
+            // Just append the character
+            builder.appendChar(*ptr);
+            ptr++;
+        }
+    }
+    return builder.produceString();
+}
+#undef ADVANCE_PTR
+
+
 /* static */ UnownedStringSlice StringUtil::getSlice(ISlangBlob* blob)
 {
     if (blob)
diff --git a/source/core/slang-string-util.h b/source/core/slang-string-util.h
index 9e7b2d65a..4f4368ba3 100644
--- a/source/core/slang-string-util.h
+++ b/source/core/slang-string-util.h
@@ -135,6 +135,11 @@ struct StringUtil
     /// Create a string from the format string applying args (like sprintf)
     static String makeStringWithFormat(const char* format, ...);
 
+    /// Create a string from the format string and arguments in a buffer.
+ static String makeStringWithFormatFromArgArray( + const char* format, + ArrayView<const void*> ptrToArgs); + /// Given a string held in a blob, returns as a String /// Returns an empty string if blob is nullptr static String getString(ISlangBlob* blob); diff --git a/source/core/slang-type-text-util.cpp b/source/core/slang-type-text-util.cpp index 01ce0d1ac..2261ed614 100644 --- a/source/core/slang-type-text-util.cpp +++ b/source/core/slang-type-text-util.cpp @@ -87,6 +87,7 @@ static const TypeTextUtil::CompileTargetInfo s_compileTargetInfos[] = { "wgsl-spirv-asm,wgsl-spirv-assembly", "SPIR-V assembly via WebGPU shading language"}, {SLANG_WGSL_SPIRV, "wgsl-spirv", "wgsl-spirv", "SPIR-V via WebGPU shading language"}, + {SLANG_HOST_VM, "slang-vm", "slangvm,slang-vm", "Slang VM byte code"}, }; static const NamesDescriptionValue s_languageInfos[] = { diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang index 03321bfaf..44b9a8860 100644 --- a/source/slang/hlsl.meta.slang +++ b/source/slang/hlsl.meta.slang @@ -12155,6 +12155,7 @@ vector<T, N> powr(vector<T, N> x, vector<T, N> y) /// } /// ``` [require(cpp_cuda_glsl_hlsl_spirv, printf)] +[require(slangvm)] __intrinsic_op($(kIROp_Printf)) void printf<each T>(NativeString format, expand each T args); diff --git a/source/slang/slang-capabilities.capdef b/source/slang/slang-capabilities.capdef index f4ae94978..8f6aa254c 100644 --- a/source/slang/slang-capabilities.capdef +++ b/source/slang/slang-capabilities.capdef @@ -115,6 +115,10 @@ def spirv : target; /// [Target] def wgsl : target + textualTarget; +/// Represents the Slang VM bytecode target. +/// [Target] +def slangvm : target; + // Capabilities that stand for target SPIR-V versions for the GLSL backend. // These are not compilation targets. We will convert `_spirv_*` to `glsl_spirv_*` during compilation. 
@@ -428,6 +432,10 @@ def domain : stage; /// [Stage] def geometry : stage; +/// Dispatch shader stage +/// [Stage] +def dispatch : stage; + def _raygen : stage; alias _raygeneration = _raygen; def _intersection : stage; diff --git a/source/slang/slang-check-shader.cpp b/source/slang/slang-check-shader.cpp index a8c3aee15..4547281e1 100644 --- a/source/slang/slang-check-shader.cpp +++ b/source/slang/slang-check-shader.cpp @@ -448,6 +448,7 @@ void validateEntryPoint(EntryPoint* entryPoint, DiagnosticSink* sink) } bool canHaveVaryingInput = false; + bool shouldWarnOnNonUniformParam = true; switch (stage) { case Stage::Vertex: @@ -462,6 +463,9 @@ void validateEntryPoint(EntryPoint* entryPoint, DiagnosticSink* sink) case Stage::Domain: canHaveVaryingInput = true; break; + case Stage::Dispatch: + shouldWarnOnNonUniformParam = false; + break; default: break; } @@ -499,10 +503,13 @@ void validateEntryPoint(EntryPoint* entryPoint, DiagnosticSink* sink) // support varying input/output. We will automatically convert it to a 'uniform' parameter, // and diagnose a warning. 
addModifier(param, getCurrentASTBuilder()->create<HLSLUniformModifier>()); - sink->diagnose( - param, - Diagnostics::nonUniformEntryPointParameterTreatedAsUniform, - param->getName()); + if (shouldWarnOnNonUniformParam) + { + sink->diagnose( + param, + Diagnostics::nonUniformEntryPointParameterTreatedAsUniform, + param->getName()); + } } for (auto target : linkage->targets) diff --git a/source/slang/slang-compiler.cpp b/source/slang/slang-compiler.cpp index 8e9b8f430..15f22630c 100644 --- a/source/slang/slang-compiler.cpp +++ b/source/slang/slang-compiler.cpp @@ -468,6 +468,8 @@ Stage getStageFromAtom(CapabilityAtom atom) return Stage::Miss; case CapabilityAtom::_callable: return Stage::Callable; + case CapabilityAtom::dispatch: + return Stage::Dispatch; default: SLANG_UNEXPECTED("unknown stage atom"); UNREACHABLE_RETURN(Stage::Unknown); @@ -1766,6 +1768,8 @@ SlangResult emitSPIRVForEntryPointsDirectly( CodeGenContext* codeGenContext, ComPtr<IArtifact>& outArtifact); +SlangResult emitHostVMCode(CodeGenContext* codeGenContext, ComPtr<IArtifact>& outArtifact); + static CodeGenTarget _getIntermediateTarget(CodeGenTarget target) { switch (target) @@ -1835,7 +1839,9 @@ SlangResult CodeGenContext::_emitEntryPoints(ComPtr<IArtifact>& outArtifact) case CodeGenTarget::WGSLSPIRV: SLANG_RETURN_ON_FAIL(emitWithDownstreamForEntryPoints(outArtifact)); return SLANG_OK; - + case CodeGenTarget::HostVM: + SLANG_RETURN_ON_FAIL(emitHostVMCode(this, outArtifact)); + return SLANG_OK; default: break; } @@ -1887,6 +1893,7 @@ SlangResult CodeGenContext::emitEntryPoints(ComPtr<IArtifact>& outArtifact) case CodeGenTarget::HostExecutable: case CodeGenTarget::HostSharedLibrary: case CodeGenTarget::WGSLSPIRVAssembly: + case CodeGenTarget::HostVM: { SLANG_RETURN_ON_FAIL(_emitEntryPoints(outArtifact)); diff --git a/source/slang/slang-compiler.h b/source/slang/slang-compiler.h index bfae6e400..27c368738 100644 --- a/source/slang/slang-compiler.h +++ b/source/slang/slang-compiler.h @@ -92,6 +92,7 
@@ enum class CodeGenTarget : SlangCompileTargetIntegral
     WGSL = SLANG_WGSL,
     WGSLSPIRVAssembly = SLANG_WGSL_SPIRV_ASM,
     WGSLSPIRV = SLANG_WGSL_SPIRV,
+    HostVM = SLANG_HOST_VM,
     CountOf = SLANG_TARGET_COUNT_OF,
 };
 
@@ -1492,7 +1493,7 @@ public:
         {
             return SLANG_E_INVALID_ARG;
         }
-
+        SLANG_AST_BUILDER_RAII(m_astBuilder);
         ComPtr<slang::IEntryPoint> entryPoint(findEntryPointByName(UnownedStringSlice(name)));
         if ((!entryPoint))
             return SLANG_FAIL;
@@ -1511,7 +1512,6 @@ public:
         {
             return SLANG_E_INVALID_ARG;
         }
-
         ComPtr<slang::IEntryPoint> entryPoint(
             findAndCheckEntryPoint(UnownedStringSlice(name), stage, outDiagnostics));
         if ((!entryPoint))
diff --git a/source/slang/slang-emit-slang.cpp b/source/slang/slang-emit-slang.cpp
new file mode 100644
index 000000000..175f7ffc5
--- /dev/null
+++ b/source/slang/slang-emit-slang.cpp
@@ -0,0 +1,19 @@
+#include "slang-emit-slang.h"
+
+namespace Slang
+{
+
+// Placeholder: Slang declaration emission is not implemented yet. All arguments are
+// ignored, `outSlangDeclaration` is left untouched, and SLANG_OK is returned.
+SlangResult emitSlangDeclarationsForEntryPoints(
+    CodeGenContext* codeGenContext,
+    LinkedIR& linkedIR,
+    String& outSlangDeclaration)
+{
+    SLANG_UNUSED(codeGenContext);
+    SLANG_UNUSED(linkedIR);
+    SLANG_UNUSED(outSlangDeclaration);
+    return SLANG_OK;
+}
+
+} // namespace Slang
diff --git a/source/slang/slang-emit-slang.h b/source/slang/slang-emit-slang.h
new file mode 100644
index 000000000..964e1e52e
--- /dev/null
+++ b/source/slang/slang-emit-slang.h
@@ -0,0 +1,19 @@
+#ifndef SLANG_EMIT_SLANG_H
+#define SLANG_EMIT_SLANG_H
+
+#include "slang-emit-base.h"
+#include "slang-ir-link.h"
+#include "slang-vm-bytecode.h"
+
+namespace Slang
+{
+/// Judging by its name, meant to emit Slang declarations for the entry points of
+/// `linkedIR` into `outSlangDeclaration`. The current definition is a stub that
+/// ignores its arguments and returns SLANG_OK.
+SlangResult emitSlangDeclarationsForEntryPoints(
+    CodeGenContext* codeGenContext,
+    LinkedIR& linkedIR,
+    String& outSlangDeclaration);
+}
+
+#endif
diff --git a/source/slang/slang-emit-vm.cpp b/source/slang/slang-emit-vm.cpp
new file mode 100644
index 000000000..fc0b4432e
--- /dev/null
+++ b/source/slang/slang-emit-vm.cpp
@@ -0,0 +1,1266 @@
+#include "slang-emit-vm.h"
+
+#include "slang-ir-call-graph.h"
+#include "slang-ir-layout.h"
+#include "slang-ir-util.h"
+
+using namespace slang;
+
+namespace Slang
+{
+/// Translates Slang IR into VM byte code by appending to a `VMByteCodeBuilder`.
+class ByteCodeEmitter
+{
+public:
+    // Per-instruction name cache and per-base-name counter used by getName().
+    Dictionary<IRInst*, String> mapInstToName;
+    Dictionary<String, int> mapNameToUniqueId;
+    // Operand already assigned to an instruction (working-set slot or constant).
+    Dictionary<IRInst*, VMOperand> mapInstToOperand;
+    // De-duplication of string literals added through addStringLiteral().
+    Dictionary<UnownedStringSlice, VMOperand> mapStringToOperand;
+    /// Key for de-duplicating integer constants: the value widened to 64 bits plus
+    /// the byte size of the original integer type.
+    struct ConstKey
+    {
+        uint64_t value;
+        uint32_t size;
+        bool operator==(const ConstKey& other) const
+        {
+            return value == other.value && size == other.size;
+        }
+        bool operator!=(const ConstKey& other) const { return !(*this == other); }
+        HashCode getHashCode() const { return combineHash(value, size); }
+    };
+    Dictionary<ConstKey, VMOperand> mapConstantIntToOperand;
+    Dictionary<IRFunc*, int> mapFuncToId;
+
+    VMByteCodeBuilder& byteCodeBuilder;
+    CodeGenContext* codeGenContext;
+
+    ByteCodeEmitter(VMByteCodeBuilder& builder, CodeGenContext* codeGenContext)
+        : byteCodeBuilder(builder), codeGenContext(codeGenContext)
+    {
+    }
+
+    /// Return a display name for `inst`, cached per instruction.
+    ///
+    /// The base name is the name hint if present, otherwise the mangled linkage
+    /// name, otherwise the opcode name. The first instruction to use a base name
+    /// gets it verbatim; later ones get `<base>_<n>` with n = 1, 2, ...
+    /// NOTE(review): suffixed names are not themselves registered, so they could
+    /// coincide with another instruction's base name.
+    String getName(IRInst* inst)
+    {
+        String name;
+        if (mapInstToName.tryGetValue(inst, name))
+            return name;
+
+        if (auto nameDecor = inst->findDecoration<IRNameHintDecoration>())
+        {
+            name = nameDecor->getName();
+        }
+        else if (auto linkageDecor = inst->findDecoration<IRLinkageDecoration>())
+        {
+            name = linkageDecor->getMangledName();
+        }
+        else
+        {
+            name = getIROpInfo(inst->getOp()).name;
+        }
+        if (int* id = mapNameToUniqueId.tryGetValue(name))
+        {
+            (*id)++;
+            name = name + "_" + String(*id);
+        }
+        else
+        {
+            mapNameToUniqueId[name] = 0;
+        }
+        mapInstToName[inst] = name;
+        return name;
+    }
+
+    /// Records that the operand stored at byte offset `offsetToOperand` in a
+    /// function's code refers to `block` (emitted for jump targets; presumably
+    /// patched once block offsets are known).
+    struct InstRelocationEntry
+    {
+        Index offsetToOperand;
+        IRBlock* block;
+    };
+
+    /// Round `value` up to the next multiple of `alignment` (which must be non-zero).
+    template<typename T>
+    static T alignUp(T value, T alignment)
+    {
+        return (value + alignment - 1) / alignment * alignment;
+    }
+
+    /// Reserve `size` bytes, aligned to `alignment`, at the end of the function's
+    /// working set and return an operand describing that slot.
+    VMOperand allocReg(VMByteCodeFunctionBuilder& funcBuilder, size_t size, size_t alignment)
+    {
+        // A zero alignment would divide by zero in alignUp; treat it as "no alignment".
+        if (alignment == 0)
+            alignment = 1;
+
+        // Align first so that the offset recorded below is the aligned one.
+        funcBuilder.workingSetSizeInBytes =
+            alignUp(funcBuilder.workingSetSizeInBytes, (uint32_t)alignment);
+
+        VMOperand operand;
+        operand.sectionId = kSlangByteCodeSectionWorkingSet;
+        operand.offset = funcBuilder.workingSetSizeInBytes;
+        operand.size = size;
+        funcBuilder.workingSetSizeInBytes += (uint32_t)size;
+        return operand;
+    }
+
+    /// Return the working-set slot of `inst`, allocating one sized by the natural
+    /// layout of the instruction's data type on first use.
+    VMOperand ensureWorkingsetMemory(VMByteCodeFunctionBuilder& funcBuilder, IRInst* inst)
+    {
+        VMOperand operand;
+
+        if (mapInstToOperand.tryGetValue(inst, operand))
+            return operand;
+
+        IRSizeAndAlignment sizeAlignment = {};
+        getNaturalSizeAndAlignment(
+            codeGenContext->getTargetProgram()->getOptionSet(),
+            inst->getDataType(),
+            &sizeAlignment);
+        operand = allocReg(funcBuilder, sizeAlignment.size, sizeAlignment.alignment);
+        mapInstToOperand[inst] = operand;
+        return operand;
+    }
+
+    /// Intern `str` and return its operand. The bytes are appended to the constant
+    /// section followed by a NUL terminator; the operand's `offset` is the index of
+    /// the string in `stringOffsets`, not a byte offset.
+    VMOperand addStringLiteral(UnownedStringSlice str)
+    {
+        if (auto operand = mapStringToOperand.tryGetValue(str))
+            return *operand;
+        VMOperand operand;
+        operand.sectionId = kSlangByteCodeSectionStrings;
+        operand.offset = (uint32_t)byteCodeBuilder.stringOffsets.getCount();
+
+        byteCodeBuilder.stringOffsets.add((uint32_t)byteCodeBuilder.constantSection.getCount());
+        byteCodeBuilder.constantSection.addRange((uint8_t*)str.begin(), str.getLength());
+        byteCodeBuilder.constantSection.add(0);
+        operand.setType(OperandDataType::String);
+        operand.size = 0;
+        mapStringToOperand[str] = operand;
+        return operand;
+    }
+
+    /// Zero-pad the constant section until its size is a multiple of `alignment`.
+    void alignConstSection(int alignment)
+    {
+        // Nothing to pad for alignment 1; also avoids a modulo by zero (or by a
+        // negative value) when no usable alignment was computed.
+        if (alignment <= 1)
+            return;
+        int rem = (int)byteCodeBuilder.constantSection.getCount() % alignment;
+        if (rem != 0)
+        {
+            int paddingSize = alignment - rem;
+            for (int i = 0; i < paddingSize; i++)
+            {
+                byteCodeBuilder.constantSection.add(0);
+            }
+        }
+    }
+
+    /// Add an integer constant to the constant section, reusing an existing entry
+    /// with the same value and byte size. 4-byte values are typed Int32, 8-byte
+    /// values Int64, anything else General.
+    template<typename IntType>
+    VMOperand addConstantValue(IntType value)
+    {
+        ConstKey key;
+        key.value = value;
+        key.size = (uint32_t)sizeof(IntType);
+        if (auto operand = mapConstantIntToOperand.tryGetValue(key))
+            return *operand;
+        VMOperand operand;
+        operand.sectionId = kSlangByteCodeSectionConstants;
+        // align constantSection
+        alignConstSection((int)sizeof(IntType));
+        operand.offset = (uint32_t)byteCodeBuilder.constantSection.getCount();
+        byteCodeBuilder.constantSection.addRange((uint8_t*)&value, sizeof(value));
+
+        operand.size = sizeof(IntType);
+        if (operand.size == 4)
+            operand.setType(OperandDataType::Int32);
+        else if (operand.size == 8)
+            operand.setType(OperandDataType::Int64);
+        else
+            operand.setType(OperandDataType::General);
+
+        // Cache only once size and type are filled in, so that a later cache hit
+        // returns exactly the operand returned here.
+        mapConstantIntToOperand[key] = operand;
+        return operand;
+    }
+
+    /// Add the IR constant `inst` to the constant section and return its operand.
+    /// String literals are forwarded to addStringLiteral().
+    /// NOTE(review): opcodes not listed in the switch (e.g. a bool literal, if one
+    /// can reach here) append no data, so the returned offset would point at
+    /// whatever is added next.
+    VMOperand addConstantValue(IRConstant* inst)
+    {
+        VMOperand operand;
+        operand.sectionId = kSlangByteCodeSectionConstants;
+
+        // Align constantSection.
+        IRSizeAndAlignment sizeAlignment = {};
+        getNaturalSizeAndAlignment(
+            codeGenContext->getTargetProgram()->getOptionSet(),
+            inst->getDataType(),
+            &sizeAlignment);
+        alignConstSection(sizeAlignment.alignment);
+
+        operand.offset = (uint32_t)byteCodeBuilder.constantSection.getCount();
+        operand.size = sizeAlignment.size;
+
+        switch (inst->getOp())
+        {
+        case kIROp_StringLit:
+            {
+                return addStringLiteral(static_cast<IRStringLit*>(inst)->getStringSlice());
+            }
+        case kIROp_IntLit:
+            {
+                int64_t value = static_cast<IRIntLit*>(inst)->getValue();
+                // Copies the first `size` bytes of the 64-bit value (its low-order
+                // bytes on a little-endian host).
+                byteCodeBuilder.constantSection.addRange((uint8_t*)&value, sizeAlignment.size);
+                // `size` is in bytes, so a 64-bit integer has size 8 (not 64).
+                // NOTE(review): literals narrower than 4 bytes keep the Int32 tag the
+                // original code gave them; confirm that is what the VM expects.
+                if (sizeAlignment.size == 8)
+                {
+                    operand.setType(OperandDataType::Int64);
+                }
+                else
+                {
+                    operand.setType(OperandDataType::Int32);
+                }
+                break;
+            }
+        case kIROp_FloatLit:
+            {
+                auto value = static_cast<IRFloatLit*>(inst)->getValue();
+                if (inst->getDataType()->getOp() == kIROp_HalfType)
+                {
+                    // Half constants are stored converted; no operand type is set.
+                    auto halfValue = FloatToHalf((float)value);
+                    byteCodeBuilder.constantSection.addRange(
+                        (uint8_t*)&halfValue,
+                        sizeof(halfValue));
+                }
+                else if (inst->getDataType()->getOp() == kIROp_FloatType)
+                {
+                    float floatValue = (float)value;
+                    byteCodeBuilder.constantSection.addRange(
+                        (uint8_t*)&floatValue,
+                        sizeof(floatValue));
+                    operand.setType(OperandDataType::Float32);
+                }
+                else
+                {
+                    byteCodeBuilder.constantSection.addRange((uint8_t*)&value, sizeof(value));
+                    operand.setType(OperandDataType::Float64);
+                }
+                break;
+            }
+        case kIROp_PtrLit:
+            {
+                // The pointer literal's payload is read through the integer-literal
+                // accessor and stored as 8 bytes.
+                int64_t value = static_cast<IRIntLit*>(inst)->getValue();
+                byteCodeBuilder.constantSection.addRange((uint8_t*)&value, sizeof(value));
+                break;
+            }
+        case kIROp_VoidLit:
+            break;
+        }
+        return operand;
+    }
+
+    /// Return the operand for `inst` when it is used as an instruction argument.
+    /// Cached operands are returned as-is; constants are added to the constant
+    /// section on first use; anything else is reported via SLANG_UNEXPECTED.
+    VMOperand ensureInst(IRInst* inst)
+    {
+        VMOperand operand;
+        if (mapInstToOperand.tryGetValue(inst, operand))
+            return operand;
+
+        if (auto constantInst = as<IRConstant>(inst))
+        {
+            operand = addConstantValue(constantInst);
+            mapInstToOperand[inst] = operand;
+        }
+        else
+        {
+            SLANG_UNEXPECTED("unsupported global inst for vm bytecode emit");
+        }
+        return operand;
+    }
+
+    /// Append one instruction to `funcBuilder.code`: a `VMInstHeader` followed by
+    /// each operand, both copied bytewise. The instruction's starting byte offset
+    /// is recorded in `funcBuilder.instOffsets`.
+    void writeInst(
+        VMByteCodeFunctionBuilder& funcBuilder,
+        VMOp op,
+        uint32_t extOp,
+        ArrayView<VMOperand> operands)
+    {
+        VMInstHeader instHeader;
+        instHeader.opcode = op;
+        instHeader.opcodeExtension = extOp;
+        instHeader.operandCount = (uint16_t)operands.getCount();
+        funcBuilder.instOffsets.add(funcBuilder.code.getCount());
+        funcBuilder.code.addRange(reinterpret_cast<uint8_t*>(&instHeader), sizeof(instHeader));
+        for (auto operand : operands)
+        {
+            funcBuilder.code.addRange(reinterpret_cast<uint8_t*>(&operand), sizeof(operand));
+        }
+    }
+
+    /// Convenience overload: instruction without operands.
+    void writeInst(VMByteCodeFunctionBuilder& funcBuilder, VMOp op, uint32_t extOp)
+    {
+        writeInst(funcBuilder, op, extOp, ArrayView<VMOperand>());
+    }
+
+    /// Convenience overload: instruction with one operand.
+    void writeInst(
+        VMByteCodeFunctionBuilder& funcBuilder,
+        VMOp op,
+        uint32_t extOp,
+        VMOperand operand)
+    {
+        writeInst(funcBuilder, op, extOp, makeArrayViewSingle(operand));
+    }
+
+    /// Convenience overload: instruction with two operands.
+    void writeInst(
+        VMByteCodeFunctionBuilder& funcBuilder,
+        VMOp op,
+        uint32_t extOp,
+        VMOperand operand1,
+        VMOperand operand2)
+    {
+        writeInst(funcBuilder, op, extOp, makeArray(operand1, operand2).getView());
+    }
+
+    /// Convenience overload: instruction with three operands.
+    void writeInst(
+        VMByteCodeFunctionBuilder& funcBuilder,
+        VMOp op,
+        uint32_t extOp,
+        VMOperand operand1,
+        VMOperand operand2,
+
VMOperand operand3) + { + writeInst(funcBuilder, op, extOp, makeArray(operand1, operand2, operand3).getView()); + } + + uint32_t getExtCode(IRInst* type) + { + ArithmeticExtCode extCode = {}; + if (auto vecType = as<IRVectorType>(type)) + { + extCode.vectorSize = getIntVal(vecType->getElementCount()); + type = vecType->getElementType(); + } + else if (auto matType = as<IRMatrixType>(type)) + { + extCode.vectorSize = + getIntVal(matType->getRowCount()) * getIntVal(matType->getColumnCount()); + type = matType->getElementType(); + } + switch (type->getOp()) + { + case kIROp_IntType: + case kIROp_BoolType: + extCode.scalarType = kSlangByteCodeScalarTypeSignedInt; + extCode.scalarBitWidth = 2; + break; + case kIROp_Int8Type: + extCode.scalarType = kSlangByteCodeScalarTypeSignedInt; + extCode.scalarBitWidth = 0; + break; + case kIROp_Int16Type: + extCode.scalarType = kSlangByteCodeScalarTypeSignedInt; + extCode.scalarBitWidth = 1; + break; + case kIROp_Int64Type: + case kIROp_IntPtrType: + extCode.scalarType = kSlangByteCodeScalarTypeSignedInt; + extCode.scalarBitWidth = 3; + break; + case kIROp_UIntType: + extCode.scalarType = kSlangByteCodeScalarTypeUnsignedInt; + extCode.scalarBitWidth = 2; + break; + case kIROp_UInt8Type: + extCode.scalarType = kSlangByteCodeScalarTypeUnsignedInt; + extCode.scalarBitWidth = 0; + break; + case kIROp_UInt16Type: + extCode.scalarType = kSlangByteCodeScalarTypeUnsignedInt; + extCode.scalarBitWidth = 1; + break; + case kIROp_UInt64Type: + case kIROp_UIntPtrType: + case kIROp_PtrType: + case kIROp_OutType: + case kIROp_InOutType: + case kIROp_RefType: + case kIROp_NativePtrType: + extCode.scalarType = kSlangByteCodeScalarTypeUnsignedInt; + extCode.scalarBitWidth = 3; + break; + case kIROp_FloatType: + extCode.scalarType = kSlangByteCodeScalarTypeFloat; + extCode.scalarBitWidth = 2; + break; + case kIROp_HalfType: + extCode.scalarType = kSlangByteCodeScalarTypeFloat; + extCode.scalarBitWidth = 1; + break; + case kIROp_DoubleType: + 
extCode.scalarType = kSlangByteCodeScalarTypeFloat; + extCode.scalarBitWidth = 3; + break; + default: + SLANG_UNEXPECTED("Unsupported type for arithmetic operation"); + } + uint32_t result; + memcpy(&result, &extCode, sizeof(extCode)); + return result; + } + + VMInstHeader translateArithmeticOp(IRInst* inst) + { + VMInstHeader opInfo = {}; + + switch (inst->getOp()) + { + case kIROp_Add: + opInfo.opcode = VMOp::Add; + break; + case kIROp_Sub: + opInfo.opcode = VMOp::Sub; + break; + case kIROp_Mul: + opInfo.opcode = VMOp::Mul; + break; + case kIROp_Div: + opInfo.opcode = VMOp::Div; + break; + case kIROp_IRem: + case kIROp_FRem: + opInfo.opcode = VMOp::Rem; + break; + case kIROp_Neg: + opInfo.opcode = VMOp::Neg; + break; + case kIROp_And: + opInfo.opcode = VMOp::And; + break; + case kIROp_Or: + opInfo.opcode = VMOp::Or; + break; + case kIROp_Not: + opInfo.opcode = VMOp::Not; + break; + case kIROp_BitAnd: + opInfo.opcode = VMOp::BitAnd; + break; + case kIROp_BitOr: + opInfo.opcode = VMOp::BitOr; + break; + case kIROp_BitXor: + opInfo.opcode = VMOp::BitXor; + break; + case kIROp_BitNot: + opInfo.opcode = VMOp::BitNot; + break; + case kIROp_Lsh: + opInfo.opcode = VMOp::Shl; + break; + case kIROp_Rsh: + opInfo.opcode = VMOp::Shr; + break; + case kIROp_Less: + opInfo.opcode = VMOp::Less; + break; + case kIROp_Leq: + opInfo.opcode = VMOp::Leq; + break; + case kIROp_Greater: + opInfo.opcode = VMOp::Greater; + break; + case kIROp_Geq: + opInfo.opcode = VMOp::Geq; + break; + case kIROp_Eql: + opInfo.opcode = VMOp::Equal; + break; + case kIROp_Neq: + opInfo.opcode = VMOp::Neq; + break; + default: + SLANG_UNEXPECTED("Unsupported operation"); + break; + } + opInfo.opcodeExtension = getExtCode(inst->getOperand(0)->getDataType()); + return opInfo; + } + + void emitCast(VMByteCodeFunctionBuilder& funcBuilder, VMOp op, IRInst* inst) + { + auto extCode1 = getExtCode(inst->getDataType()); + auto extCode2 = getExtCode(inst->getOperand(0)->getDataType()); + auto extCode = extCode1 | 
(extCode2 << 16); + writeInst( + funcBuilder, + op, + extCode, + ensureWorkingsetMemory(funcBuilder, inst), + ensureInst(inst->getOperand(0))); + } + + void emitInst( + VMByteCodeFunctionBuilder& funcBuilder, + IRInst* inst, + List<InstRelocationEntry>& relocations) + { + switch (inst->getOp()) + { + case kIROp_undefined: + { + ensureWorkingsetMemory(funcBuilder, inst); + } + break; + case kIROp_Param: + { + auto operand = ensureWorkingsetMemory(funcBuilder, inst); + if (isFirstBlock(inst->getParent())) + { + funcBuilder.parameterOffsets.add(operand.offset); + IRSizeAndAlignment sizeAlignment = {}; + getNaturalSizeAndAlignment( + codeGenContext->getTargetProgram()->getOptionSet(), + inst->getDataType(), + &sizeAlignment); + funcBuilder.parameterSize = + operand.offset + (uint32_t)sizeAlignment.getStride(); + } + } + break; + case kIROp_Var: + { + IRBuilder builder(inst); + auto type = tryGetPointedToType(&builder, inst->getDataType()); + IRSizeAndAlignment sizeAlignment = {}; + getNaturalSizeAndAlignment( + codeGenContext->getTargetProgram()->getOptionSet(), + type, + &sizeAlignment); + auto varStorage = allocReg( + funcBuilder, + (size_t)sizeAlignment.size, + (size_t)sizeAlignment.alignment); + writeInst( + funcBuilder, + VMOp::GetWorkingSetPtr, + varStorage.offset, + ensureWorkingsetMemory(funcBuilder, inst)); + } + break; + case kIROp_Load: + { + IRSizeAndAlignment sizeAlignment = {}; + getNaturalSizeAndAlignment( + codeGenContext->getTargetProgram()->getOptionSet(), + inst->getDataType(), + &sizeAlignment); + writeInst( + funcBuilder, + VMOp::Load, + (uint32_t)sizeAlignment.getStride(), + ensureWorkingsetMemory(funcBuilder, inst), + ensureInst(inst->getOperand(0))); + } + break; + case kIROp_Store: + { + IRSizeAndAlignment sizeAlignment = {}; + getNaturalSizeAndAlignment( + codeGenContext->getTargetProgram()->getOptionSet(), + inst->getOperand(1)->getDataType(), + &sizeAlignment); + writeInst( + funcBuilder, + VMOp::Store, + (uint32_t)sizeAlignment.getStride(), 
+ ensureInst(inst->getOperand(0)), + ensureInst(inst->getOperand(1))); + } + break; + case kIROp_Add: + case kIROp_Sub: + case kIROp_Mul: + case kIROp_Div: + case kIROp_And: + case kIROp_FRem: + case kIROp_IRem: + case kIROp_Or: + case kIROp_BitAnd: + case kIROp_BitOr: + case kIROp_BitXor: + case kIROp_Lsh: + case kIROp_Rsh: + case kIROp_Less: + case kIROp_Leq: + case kIROp_Greater: + case kIROp_Geq: + case kIROp_Eql: + case kIROp_Neq: + { + auto opInfo = translateArithmeticOp(inst); + IRSizeAndAlignment sizeAlignment = {}; + getNaturalSizeAndAlignment( + codeGenContext->getTargetProgram()->getOptionSet(), + inst->getDataType(), + &sizeAlignment); + writeInst( + funcBuilder, + opInfo.opcode, + opInfo.opcodeExtension, + ensureWorkingsetMemory(funcBuilder, inst), + ensureInst(inst->getOperand(0)), + ensureInst(inst->getOperand(1))); + } + break; + case kIROp_Neg: + case kIROp_Not: + case kIROp_BitNot: + { + auto opInfo = translateArithmeticOp(inst); + IRSizeAndAlignment sizeAlignment = {}; + getNaturalSizeAndAlignment( + codeGenContext->getTargetProgram()->getOptionSet(), + inst->getDataType(), + &sizeAlignment); + writeInst( + funcBuilder, + opInfo.opcode, + opInfo.opcodeExtension, + ensureWorkingsetMemory(funcBuilder, inst), + ensureInst(inst->getOperand(0))); + } + break; + case kIROp_unconditionalBranch: + case kIROp_loop: + { + // Write phi arguments into param registers. 
+ auto branch = as<IRUnconditionalBranch>(inst); + auto params = branch->getTargetBlock()->getParams(); + List<IRInst*> paramList; + for (auto param : params) + { + paramList.add(param); + } + if (paramList.getCount() != (Index)branch->getArgCount()) + { + SLANG_UNEXPECTED("Invalid number of arguments for branch instruction"); + } + for (UInt i = 0; i < branch->getArgCount(); i++) + { + auto arg = branch->getArg(i); + auto param = paramList[i]; + auto paramReg = ensureWorkingsetMemory(funcBuilder, param); + IRSizeAndAlignment sizeAlignment = {}; + getNaturalSizeAndAlignment( + codeGenContext->getTargetProgram()->getOptionSet(), + param->getDataType(), + &sizeAlignment); + writeInst( + funcBuilder, + VMOp::Copy, + (uint32_t)sizeAlignment.getStride(), + paramReg, + ensureInst(arg)); + } + // Write jump inst. + VMOperand relocOperand = {}; + writeInst(funcBuilder, VMOp::Jump, 0, relocOperand); + InstRelocationEntry entry; + entry.block = (IRBlock*)inst->getOperand(0); + entry.offsetToOperand = funcBuilder.code.getCount() - sizeof(VMOperand); + relocations.add(entry); + } + break; + case kIROp_ifElse: + { + VMOperand relocOperand = {}; + writeInst( + funcBuilder, + VMOp::JumpIf, + 0, + ensureInst(inst->getOperand(0)), + relocOperand, + relocOperand); + InstRelocationEntry entry; + entry.block = (IRBlock*)inst->getOperand(1); + entry.offsetToOperand = funcBuilder.code.getCount() - sizeof(VMOperand) * 2; + relocations.add(entry); + entry.block = (IRBlock*)inst->getOperand(2); + entry.offsetToOperand = funcBuilder.code.getCount() - sizeof(VMOperand); + relocations.add(entry); + } + break; + case kIROp_Call: + { + auto callInst = as<IRCall>(inst); + auto callee = as<IRFunc>(callInst->getCallee()); + UnownedStringSlice def; + IRInst* intrinsicInst; + if (findTargetIntrinsicDefinition( + callee, + codeGenContext->getTargetCaps(), + def, + intrinsicInst)) + { + auto calleeOperand = addStringLiteral(def); + List<VMOperand> operands; + 
operands.add(ensureWorkingsetMemory(funcBuilder, inst)); + operands.add(calleeOperand); + for (UInt i = 0; i < callInst->getArgCount(); ++i) + { + operands.add(ensureInst(callInst->getArg(i))); + } + writeInst(funcBuilder, VMOp::CallExt, 0, operands.getArrayView()); + break; + } + List<VMOperand> operands; + int calleeId = -1; + mapFuncToId.tryGetValue(callee, calleeId); + // An assert is compiled out of release builds, which would silently encode + // function id 0xFFFFFFFF into the call operand; fail loudly instead. + if (calleeId == -1) + { + SLANG_UNEXPECTED("VM bytecode gen: callee has no assigned function id"); + } + VMOperand calleeOperand = {}; + calleeOperand.sectionId = kSlangByteCodeSectionFuncs; + calleeOperand.offset = calleeId; + calleeOperand.setType(OperandDataType::Int32); + operands.add(ensureWorkingsetMemory(funcBuilder, inst)); + operands.add(calleeOperand); + for (UInt i = 0; i < callInst->getArgCount(); ++i) + { + operands.add(ensureInst(callInst->getArg(i))); + } + IRSizeAndAlignment sizeAlignment = {}; + getNaturalSizeAndAlignment( + codeGenContext->getTargetProgram()->getOptionSet(), + inst->getDataType(), + &sizeAlignment); + writeInst( + funcBuilder, + VMOp::Call, + (uint32_t)sizeAlignment.getStride(), + operands.getArrayView()); + } + break; + case kIROp_MissingReturn: + case kIROp_Return: + { + auto returnInst = as<IRReturn>(inst); + if (returnInst && returnInst->getVal()->getOp() != kIROp_VoidLit) + { + IRSizeAndAlignment sizeAlignment = {}; + getNaturalSizeAndAlignment( + codeGenContext->getTargetProgram()->getOptionSet(), + returnInst->getVal()->getDataType(), + &sizeAlignment); + writeInst( + funcBuilder, + VMOp::Ret, + (uint32_t)sizeAlignment.getStride(), + ensureInst(returnInst->getOperand(0))); + } + else + { + writeInst(funcBuilder, VMOp::Ret, 0); + } + } + break; + case kIROp_GetElementPtr: + { + auto getElemInst = as<IRGetElementPtr>(inst); + auto base = getElemInst->getBase(); + auto index = getElemInst->getIndex(); + IRBuilder builder(inst); + auto elementType = tryGetPointedToType(&builder, getElemInst->getDataType()); + IRSizeAndAlignment sizeAlignment = {}; + getNaturalSizeAndAlignment( + codeGenContext->getTargetProgram()->getOptionSet(), 
+ elementType, + &sizeAlignment); + auto stride = sizeAlignment.getStride(); + auto baseOperand = ensureInst(base); + auto indexOperand = ensureInst(index); + writeInst( + funcBuilder, + VMOp::GetElementPtr, + (uint32_t)stride, + ensureWorkingsetMemory(funcBuilder, inst), + baseOperand, + indexOperand); + } + break; + case kIROp_FieldAddress: + { + auto fieldAddrInst = as<IRFieldAddress>(inst); + auto base = fieldAddrInst->getBase(); + auto fieldKey = (IRStructKey*)fieldAddrInst->getField(); + IRBuilder builder(base); + + auto structType = + as<IRStructType>(tryGetPointedToType(&builder, base->getDataType())); + IRIntegerValue offset = 0; + auto field = findStructField(structType, fieldKey); + getNaturalOffset( + codeGenContext->getTargetProgram()->getOptionSet(), + field, + &offset); + + writeInst( + funcBuilder, + VMOp::Add, + getExtCode(inst->getDataType()), + ensureWorkingsetMemory(funcBuilder, inst), + ensureInst(base), + addConstantValue((uint64_t)offset)); + } + break; + case kIROp_GetOffsetPtr: + { + auto getOffsetPtrInst = as<IRGetOffsetPtr>(inst); + auto base = getOffsetPtrInst->getBase(); + auto offset = getOffsetPtrInst->getOffset(); + IRSizeAndAlignment sizeAlignment = {}; + IRBuilder builder(inst); + auto elementType = tryGetPointedToType(&builder, getOffsetPtrInst->getDataType()); + getNaturalSizeAndAlignment( + codeGenContext->getTargetProgram()->getOptionSet(), + elementType, + &sizeAlignment); + writeInst( + funcBuilder, + VMOp::OffsetPtr, + (uint32_t)sizeAlignment.getStride(), + ensureWorkingsetMemory(funcBuilder, inst), + ensureInst(base), + ensureInst(offset)); + } + break; + case kIROp_FieldExtract: + { + auto fieldExtractInst = as<IRFieldExtract>(inst); + auto base = fieldExtractInst->getBase(); + auto fieldKey = (IRStructKey*)fieldExtractInst->getField(); + + auto structType = as<IRStructType>(base->getDataType()); + IRIntegerValue offset = 0; + auto field = findStructField(structType, fieldKey); + getNaturalOffset( + 
codeGenContext->getTargetProgram()->getOptionSet(), + field, + &offset); + + auto baseOperand = ensureInst(base); + baseOperand.offset += (uint32_t)offset; + mapInstToOperand[inst] = baseOperand; + } + break; + case kIROp_GetElement: + { + auto getElemInst = as<IRGetElement>(inst); + auto base = getElemInst->getBase(); + auto index = getElemInst->getIndex(); + auto elementType = getElemInst->getDataType(); + IRSizeAndAlignment sizeAlignment = {}; + getNaturalSizeAndAlignment( + codeGenContext->getTargetProgram()->getOptionSet(), + elementType, + &sizeAlignment); + auto stride = sizeAlignment.getStride(); + auto baseOperand = ensureInst(base); + if (as<IRIntLit>(index)) + { + baseOperand.offset += (uint32_t)(stride * getIntVal(index)); + mapInstToOperand[inst] = baseOperand; + break; + } + writeInst( + funcBuilder, + VMOp::GetElement, + (uint32_t)stride, + ensureWorkingsetMemory(funcBuilder, inst), + baseOperand, + ensureInst(index)); + } + break; + case kIROp_BitCast: + { + auto operand = ensureInst(inst->getOperand(0)); + mapInstToOperand[inst] = operand; + } + break; + case kIROp_IntCast: + case kIROp_CastIntToPtr: + case kIROp_CastPtrToInt: + case kIROp_CastIntToFloat: + case kIROp_CastFloatToInt: + case kIROp_FloatCast: + emitCast(funcBuilder, VMOp::Cast, inst); + break; + case kIROp_swizzle: + { + auto swizzleInst = as<IRSwizzle>(inst); + auto base = swizzleInst->getBase(); + auto baseOperand = ensureInst(base); + auto count = (uint32_t)swizzleInst->getElementCount(); + List<VMOperand> operands; + operands.add(ensureWorkingsetMemory(funcBuilder, inst)); + operands.add(baseOperand); + for (UInt i = 0; i < count; ++i) + { + auto index = (uint32_t)getIntVal(swizzleInst->getElementIndex(i)); + // Zero-initialize so the fields not assigned below are deterministic + // when the operand is written into the bytecode stream. + VMOperand operand = {}; + operand.sectionId = kSlangByteCodeSectionImmediate; + operand.offset = index; + operands.add(operand); + } + writeInst( + funcBuilder, + VMOp::Swizzle, + getExtCode(inst->getDataType()), + operands.getArrayView()); + } + break; + case kIROp_MakeArray: + 
{ + auto result = ensureWorkingsetMemory(funcBuilder, inst); + auto arrayType = as<IRArrayTypeBase>(inst->getDataType()); + auto elementType = arrayType->getElementType(); + IRSizeAndAlignment sizeAlignment = {}; + getNaturalSizeAndAlignment( + codeGenContext->getTargetProgram()->getOptionSet(), + elementType, + &sizeAlignment); + auto stride = (uint32_t)sizeAlignment.getStride(); + for (UInt i = 0; i < inst->getOperandCount(); ++i) + { + VMOperand elementOperand = result; + elementOperand.offset += (uint32_t)(stride * i); + writeInst( + funcBuilder, + VMOp::Copy, + stride, + elementOperand, + ensureInst(inst->getOperand(i))); + } + } + break; + case kIROp_MakeArrayFromElement: + { + auto result = ensureWorkingsetMemory(funcBuilder, inst); + auto arrayType = as<IRArrayTypeBase>(inst->getDataType()); + auto elementType = arrayType->getElementType(); + IRSizeAndAlignment sizeAlignment = {}; + getNaturalSizeAndAlignment( + codeGenContext->getTargetProgram()->getOptionSet(), + elementType, + &sizeAlignment); + auto stride = (uint32_t)sizeAlignment.getStride(); + for (Index i = 0; i < getIntVal(arrayType->getElementCount()); ++i) + { + VMOperand elementOperand = result; + elementOperand.offset += (uint32_t)(stride * i); + writeInst( + funcBuilder, + VMOp::Copy, + stride, + elementOperand, + ensureInst(inst->getOperand(0))); + } + } + break; + case kIROp_MakeStruct: + { + auto result = ensureWorkingsetMemory(funcBuilder, inst); + auto structType = as<IRStructType>(inst->getDataType()); + List<IRStructField*> fields; + for (auto field : structType->getFields()) + { + fields.add(field); + } + for (UInt i = 0; i < inst->getOperandCount(); ++i) + { + auto field = fields[i]; + IRIntegerValue offset = 0; + getNaturalOffset( + codeGenContext->getTargetProgram()->getOptionSet(), + field, + &offset); + IRSizeAndAlignment sizeAlignment = {}; + getNaturalSizeAndAlignment( + codeGenContext->getTargetProgram()->getOptionSet(), + field->getFieldType(), + &sizeAlignment); + VMOperand 
elementOperand = result; + elementOperand.offset += (uint32_t)offset; + writeInst( + funcBuilder, + VMOp::Copy, + (uint32_t)sizeAlignment.getStride(), + elementOperand, + ensureInst(inst->getOperand(i))); + } + } + break; + case kIROp_MakeVector: + case kIROp_MakeMatrix: + { + auto result = ensureWorkingsetMemory(funcBuilder, inst); + for (UInt i = 0; i < inst->getOperandCount(); ++i) + { + VMOperand elementOperand = result; + IRSizeAndAlignment sizeAlignment = {}; + getNaturalSizeAndAlignment( + codeGenContext->getTargetProgram()->getOptionSet(), + inst->getOperand(i)->getDataType(), + &sizeAlignment); + writeInst( + funcBuilder, + VMOp::Copy, + (uint32_t)sizeAlignment.getStride(), + elementOperand, + ensureInst(inst->getOperand(i))); + result.offset += (uint32_t)sizeAlignment.getStride(); + } + } + break; + case kIROp_MakeVectorFromScalar: + { + auto result = ensureWorkingsetMemory(funcBuilder, inst); + auto vectorType = as<IRVectorType>(inst->getDataType()); + IRSizeAndAlignment sizeAlignment = {}; + getNaturalSizeAndAlignment( + codeGenContext->getTargetProgram()->getOptionSet(), + vectorType->getElementType(), + &sizeAlignment); + auto stride = (uint32_t)sizeAlignment.getStride(); + for (Index i = 0; i < getIntVal(vectorType->getElementCount()); ++i) + { + VMOperand elementOperand = result; + elementOperand.offset += (uint32_t)(stride * i); + writeInst( + funcBuilder, + VMOp::Copy, + stride, + elementOperand, + ensureInst(inst->getOperand(0))); + } + } + break; + case kIROp_MakeMatrixFromScalar: + { + auto result = ensureWorkingsetMemory(funcBuilder, inst); + auto matrixType = as<IRMatrixType>(inst->getDataType()); + IRSizeAndAlignment sizeAlignment = {}; + getNaturalSizeAndAlignment( + codeGenContext->getTargetProgram()->getOptionSet(), + matrixType->getElementType(), + &sizeAlignment); + auto stride = (uint32_t)sizeAlignment.getStride(); + for (Index i = 0; i < getIntVal(matrixType->getRowCount()); ++i) + { + for (Index j = 0; j < 
getIntVal(matrixType->getColumnCount()); ++j) + { + writeInst( + funcBuilder, + VMOp::Copy, + stride, + result, + ensureInst(inst->getOperand(0))); + result.offset += stride; + } + } + } + break; + case kIROp_Printf: + { + List<VMOperand> operands; + operands.add(ensureInst(inst->getOperand(0))); + auto tuple = inst->getOperand(1); + if (auto makeTuple = as<IRMakeStruct>(tuple)) + { + for (UInt i = 0; i < makeTuple->getOperandCount(); i++) + { + operands.add(ensureInst(makeTuple->getOperand(i))); + } + } + else + { + // If not a tuple, it should be a single value. + operands.add(ensureInst(tuple)); + } + writeInst(funcBuilder, VMOp::Print, 0, operands.getArrayView()); + } + break; + default: + SLANG_UNIMPLEMENTED_X("VM bytecode gen for inst."); + } + } + + void emitFunction(IRFunc* func) + { + VMByteCodeFunctionBuilder funcBuilder; + funcBuilder.name = addStringLiteral(getName(func).getUnownedSlice()); + + IRSizeAndAlignment sizeAlignment = {}; + getNaturalSizeAndAlignment( + codeGenContext->getTargetProgram()->getOptionSet(), + func->getResultType(), + &sizeAlignment); + funcBuilder.resultSize = (uint32_t)sizeAlignment.getStride(); + + Dictionary<IRBlock*, Index> mapBlockToByteOffset; + List<InstRelocationEntry> relocations; + + for (auto block : func->getBlocks()) + { + mapBlockToByteOffset[block] = funcBuilder.code.getCount(); + + for (auto inst : block->getChildren()) + { + funcBuilder.instOffsets.add(funcBuilder.code.getCount()); + emitInst(funcBuilder, inst, relocations); + } + } + + // Apply relocations for jump targets. 
+ for (auto reloc : relocations) + { + Index offset = mapBlockToByteOffset.getValue(reloc.block); + uint8_t* codePtr = (funcBuilder.code.getBuffer() + reloc.offsetToOperand); + VMOperand* operand = (VMOperand*)codePtr; + operand->sectionId = kSlangByteCodeSectionInsts; + operand->offset = (uint32_t)offset; + } + funcBuilder.workingSetSizeInBytes = + alignUp(funcBuilder.workingSetSizeInBytes, (uint32_t)sizeof(uint64_t)); + + byteCodeBuilder.functions.add(funcBuilder); + } + + void emitEntryPoints(LinkedIR& linkedIR) + { + Dictionary<IRInst*, HashSet<IRFunc*>> referencingEntryPoints; + buildEntryPointReferenceGraph(referencingEntryPoints, linkedIR.module); + OrderedHashSet<IRFunc*> entryPointSet; + for (auto entryPoint : linkedIR.entryPoints) + { + auto entryPointDecor = entryPoint->findDecoration<IREntryPointDecoration>(); + if (!entryPointDecor) + continue; + if (entryPointDecor->getProfile().getStage() != Stage::Dispatch) + continue; + entryPointSet.add(entryPoint); + } + + List<IRFunc*> functionsToEmit; + + // Queue all Dispatch-stage entry points first. + for (auto entryPoint : entryPointSet) + { + // Only queued here; nothing is emitted inside this loop. + functionsToEmit.add(entryPoint); + } + + // Queue the remaining functions, if they are referenced by a selected entry point. + for (auto globalInst : linkedIR.module->getGlobalInsts()) + { + auto func = as<IRFunc>(globalInst); + + if (!func) + continue; + + // Skip if already queued as an entry point. + if (entryPointSet.contains(func)) + continue; + + HashSet<IRFunc*>* entryPointRefs = referencingEntryPoints.tryGetValue(func); + if (!entryPointRefs) + continue; + + // If the function is referenced by any selected entry point, queue it. + bool referencedByHostEntryPoint = false; + for (auto entryPoint : *entryPointRefs) + { + if (entryPointSet.contains(entryPoint)) + { + referencedByHostEntryPoint = true; + break; + } + } + if (referencedByHostEntryPoint) + { + functionsToEmit.add(func); + } + } + + // Emit all functions. 
+ for (Index i = 0; i < functionsToEmit.getCount(); i++) + { + mapFuncToId[functionsToEmit[i]] = (int)i; + } + for (auto func : functionsToEmit) + { + emitFunction(func); + } + } +}; + +SlangResult emitVMByteCodeForEntryPoints( + CodeGenContext* codeGenContext, + LinkedIR& linkedIR, + VMByteCodeBuilder& byteCode) +{ + ByteCodeEmitter emitter(byteCode, codeGenContext); + emitter.emitEntryPoints(linkedIR); + return SLANG_OK; +} + +SlangResult VMByteCodeBuilder::serialize(slang::IBlob** outBlob) +{ + OwnedMemoryStream ms(FileAccess::Write); + ms.write(&kSlangByteCodeFourCC, sizeof(uint32_t)); + ms.write(&kSlangByteCodeVersion, sizeof(uint32_t)); + + // Write functions section. + ms.write(&kSlangByteCodeFunctionsFourCC, sizeof(uint32_t)); + uint32_t functionChunkSizeStart = (uint32_t)ms.getPosition(); + uint32_t zero = 0; + ms.write(&zero, sizeof(uint32_t)); // Reserve space for function chunk size. + + uint32_t functionCount = (uint32_t)functions.getCount(); + ms.write(&functionCount, sizeof(uint32_t)); + // Reserve space for function offsets. 
+ auto functionOffsetStart = ms.getPosition(); + for (uint32_t i = 0; i < functionCount; ++i) + { + ms.write(&zero, sizeof(uint32_t)); + } + List<uint32_t> functionOffsets; + for (uint32_t i = 0; i < functionCount; ++i) + { + functionOffsets.add((uint32_t)ms.getPosition()); + + auto& function = functions[i]; + // Zero-initialize: the header is written out as raw bytes, so anything + // not assigned below must not carry indeterminate stack contents. + VMFuncHeader funcHeader = {}; + funcHeader.name = function.name; + funcHeader.codeSize = (uint32_t)function.code.getCount(); + funcHeader.parameterCount = (uint32_t)function.parameterOffsets.getCount(); + funcHeader.workingSetSizeInBytes = function.workingSetSizeInBytes; + funcHeader.returnValueSizeInBytes = function.resultSize; + funcHeader.parameterSizeInBytes = function.parameterSize; + ms.write(&funcHeader, sizeof(funcHeader)); + ms.write( + function.parameterOffsets.getBuffer(), + sizeof(uint32_t) * function.parameterOffsets.getCount()); + + ms.write(function.code.begin(), funcHeader.codeSize); + } + uint32_t functionChunkSize = + (uint32_t)(ms.getPosition() - functionChunkSizeStart - sizeof(uint32_t)); + + // Write kernel Blob section. + // The section header is always written; a builder without a kernel blob + // produces an empty (size 0) section instead of dereferencing null. + ms.write(&kSlangByteCodeKernelBlobFourCC, sizeof(uint32_t)); + uint32_t kernelBlobSize = + kernelBlob.get() != nullptr ? (uint32_t)kernelBlob->getBufferSize() : 0; + ms.write(&kernelBlobSize, sizeof(uint32_t)); + if (kernelBlobSize != 0) + { + ms.write(kernelBlob->getBufferPointer(), kernelBlobSize); + } + + // Write constant section. + ms.write(&kSlangByteCodeConstantsFourCC, sizeof(uint32_t)); + uint32_t constantBlobSize = (uint32_t)constantSection.getCount(); + ms.write(&constantBlobSize, sizeof(uint32_t)); + uint32_t stringCount = (uint32_t)stringOffsets.getCount(); + ms.write(&stringCount, sizeof(uint32_t)); + ms.write(stringOffsets.getBuffer(), sizeof(uint32_t) * stringCount); + ms.write(constantSection.begin(), constantBlobSize); + + auto blob = RawBlob::create(ms.getContents().getBuffer(), ms.getContents().getCount()); + + // Patch in the function chunk size. 
+ uint32_t* functionChunkSizePtr = + (uint32_t*)((uint8_t*)blob->getBufferPointer() + functionChunkSizeStart); + *functionChunkSizePtr = functionChunkSize; + + // Patch in the function offsets. + auto funcOffsetTable = (uint32_t*)((uint8_t*)blob->getBufferPointer() + functionOffsetStart); + for (uint32_t i = 0; i < functionCount; ++i) + { + funcOffsetTable[i] = functionOffsets[i]; + } + + *outBlob = blob.detach(); + return SLANG_OK; +} + +} // namespace Slang diff --git a/source/slang/slang-emit-vm.h b/source/slang/slang-emit-vm.h new file mode 100644 index 000000000..ee412c752 --- /dev/null +++ b/source/slang/slang-emit-vm.h @@ -0,0 +1,38 @@ +#ifndef SLANG_EMIT_VM_H +#define SLANG_EMIT_VM_H + +#include "slang-emit-base.h" +#include "slang-ir-link.h" +#include "slang-vm-bytecode.h" + +namespace Slang +{ + +struct VMByteCodeFunctionBuilder +{ + VMOperand name = {}; + uint32_t workingSetSizeInBytes = 0; + List<uint8_t> code; + List<Index> instOffsets; + List<uint32_t> parameterOffsets; + uint32_t resultSize = 0; + uint32_t parameterSize = 0; +}; + +struct VMByteCodeBuilder +{ + List<VMByteCodeFunctionBuilder> functions; + ComPtr<slang::IBlob> kernelBlob; + + List<uint8_t> constantSection; + List<uint32_t> stringOffsets; + SlangResult serialize(slang::IBlob** outBlob); +}; + +SlangResult emitVMByteCodeForEntryPoints( + CodeGenContext* codeGenContext, + LinkedIR& linkedIR, + VMByteCodeBuilder& byteCode); +} // namespace Slang + +#endif diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp index 9cfa7dcae..aa7387e22 100644 --- a/source/slang/slang-emit.cpp +++ b/source/slang/slang-emit.cpp @@ -15,8 +15,10 @@ #include "slang-emit-glsl.h" #include "slang-emit-hlsl.h" #include "slang-emit-metal.h" +#include "slang-emit-slang.h" #include "slang-emit-source-writer.h" #include "slang-emit-torch.h" +#include "slang-emit-vm.h" #include "slang-emit-wgsl.h" #include "slang-ir-any-value-inference.h" #include "slang-ir-autodiff.h" @@ -116,6 +118,7 @@ #include 
"slang-syntax.h" #include "slang-type-layout.h" #include "slang-visitor.h" +#include "slang-vm-bytecode.h" #include <assert.h> @@ -825,6 +828,7 @@ Result linkAndOptimizeIR( switch (target) { case CodeGenTarget::HostCPPSource: + case CodeGenTarget::HostVM: break; case CodeGenTarget::CUDASource: collectOptiXEntryPointUniformParams(irModule); @@ -859,6 +863,7 @@ Result linkAndOptimizeIR( case CodeGenTarget::HostCPPSource: case CodeGenTarget::CPPSource: case CodeGenTarget::CUDASource: + case CodeGenTarget::HostVM: break; } @@ -1154,6 +1159,15 @@ Result linkAndOptimizeIR( cleanupGenerics(targetProgram, irModule, sink); dumpIRIfEnabled(codeGenContext, irModule, "AFTER-LOWER-GENERICS"); + // Don't need to run any further target-dependent passes if we are generating code + // for host vm. + if (target == CodeGenTarget::HostVM) + { + performForceInlining(irModule); + simplifyIR(targetProgram, irModule, defaultIRSimplificationOptions, sink); + return SLANG_OK; + } + // After dynamic dispatch logic is resolved into ordinary function calls, // we can now run our stage specialization logic. 
if (requiredLoweringPassSet.specializeStageSwitch) @@ -2329,4 +2343,47 @@ SlangResult emitSPIRVForEntryPointsDirectly( return SLANG_OK; } +SlangResult emitHostVMCode(CodeGenContext* codeGenContext, ComPtr<IArtifact>& outArtifact) +{ + LinkedIR linkedIR; + LinkingAndOptimizationOptions linkingAndOptimizationOptions; + SLANG_RETURN_ON_FAIL( + linkAndOptimizeIR(codeGenContext, linkingAndOptimizationOptions, linkedIR)); + + VMByteCodeBuilder byteCode; + SLANG_RETURN_ON_FAIL(emitVMByteCodeForEntryPoints(codeGenContext, linkedIR, byteCode)); + + String slangDeclaration; + SLANG_RETURN_ON_FAIL( + emitSlangDeclarationsForEntryPoints(codeGenContext, linkedIR, slangDeclaration)); + + slang::SessionDesc sessionDesc = {}; + ComPtr<slang::ISession> slangSession; + SLANG_RETURN_ON_FAIL( + codeGenContext->getSession()->createSession(sessionDesc, slangSession.writeRef())); + auto linkage = static_cast<Linkage*>(slangSession.get()); + + ComPtr<ISlangBlob> diagnostics; + auto module = slangSession->loadModuleFromSource( + "kernel", + "kernel.slang", + StringBlob::create(slangDeclaration), + diagnostics.writeRef()); + if (!module) + return SLANG_FAIL; + RefPtr<Module> newModule = new Module(linkage); + newModule->setModuleDecl(static_cast<Module*>(module)->getModuleDecl()); + newModule->setIRModule(linkedIR.module); + newModule->setName("kernels"); + SLANG_RETURN_ON_FAIL(newModule->serialize(byteCode.kernelBlob.writeRef())); + + ComPtr<slang::IBlob> byteCodeBlob; + SLANG_RETURN_ON_FAIL(byteCode.serialize(byteCodeBlob.writeRef())); + + outArtifact = ArtifactUtil::createArtifactForCompileTarget(SLANG_HOST_VM); + outArtifact->addRepresentationUnknown(byteCodeBlob); + + return SLANG_OK; +} + } // namespace Slang diff --git a/source/slang/slang-ir-dce.cpp b/source/slang/slang-ir-dce.cpp index 5e1f2eb21..0b89baf64 100644 --- a/source/slang/slang-ir-dce.cpp +++ b/source/slang/slang-ir-dce.cpp @@ -283,16 +283,6 @@ struct DeadCodeEliminationContext } }; -bool isFirstBlock(IRInst* inst) -{ - 
auto block = as<IRBlock>(inst); - if (!block) - return false; - if (!block->getParent()) - return false; - return block->getParent()->getFirstBlock() == block; -} - bool isPtrUsed(IRInst* ptrInst) { for (auto use = ptrInst->firstUse; use; use = use->nextUse) diff --git a/source/slang/slang-ir-util.cpp b/source/slang/slang-ir-util.cpp index 4919850eb..36aec22b5 100644 --- a/source/slang/slang-ir-util.cpp +++ b/source/slang/slang-ir-util.cpp @@ -2232,4 +2232,14 @@ UnownedStringSlice getMangledName(IRInst* inst) return UnownedStringSlice(); } +bool isFirstBlock(IRInst* inst) +{ + auto block = as<IRBlock>(inst); + if (!block) + return false; + if (!block->getParent()) + return false; + return block->getParent()->getFirstBlock() == block; +} + } // namespace Slang diff --git a/source/slang/slang-ir-util.h b/source/slang/slang-ir-util.h index 0a8bc9b1d..b111f8abf 100644 --- a/source/slang/slang-ir-util.h +++ b/source/slang/slang-ir-util.h @@ -387,6 +387,7 @@ void legalizeDefUse(IRGlobalValueWithCode* func); UnownedStringSlice getMangledName(IRInst* inst); +bool isFirstBlock(IRInst* inst); } // namespace Slang #endif diff --git a/source/slang/slang-options.cpp b/source/slang/slang-options.cpp index 3c0bcf8db..d73bc4307 100644 --- a/source/slang/slang-options.cpp +++ b/source/slang/slang-options.cpp @@ -3615,6 +3615,7 @@ SlangResult OptionsParser::_parse(int argc, char const* const* argv) case CodeGenTarget::MetalLibAssembly: case CodeGenTarget::Metal: case CodeGenTarget::WGSL: + case CodeGenTarget::HostVM: rawOutput.isWholeProgram = true; break; case CodeGenTarget::SPIRV: diff --git a/source/slang/slang-profile-defs.h b/source/slang/slang-profile-defs.h index 29e2bbf2e..4781d5bc3 100644 --- a/source/slang/slang-profile-defs.h +++ b/source/slang/slang-profile-defs.h @@ -73,7 +73,7 @@ PROFILE_STAGE(Callable, callable, SLANG_STAGE_CALLABLE) PROFILE_STAGE(Mesh, mesh, SLANG_STAGE_MESH) PROFILE_STAGE(Amplification, amplification, SLANG_STAGE_AMPLIFICATION) - 
+PROFILE_STAGE(Dispatch, dispatch, SLANG_STAGE_DISPATCH) // Note: HLSL and Direct3D convention erroneously uses the term "Pixel Shader" // for the thing that shades *fragments*. Slang strives to treat the more correct diff --git a/source/slang/slang-type-layout.cpp b/source/slang/slang-type-layout.cpp index 63bcd1ba2..68df665d4 100644 --- a/source/slang/slang-type-layout.cpp +++ b/source/slang/slang-type-layout.cpp @@ -2254,6 +2254,7 @@ LayoutRulesFamilyImpl* getDefaultLayoutRulesFamilyForTarget(TargetRequest* targe case CodeGenTarget::ShaderSharedLibrary: case CodeGenTarget::CPPSource: case CodeGenTarget::CSource: + case CodeGenTarget::HostVM: { // For now lets use some fairly simple CPU binding rules diff --git a/source/slang/slang-vm-bytecode.cpp b/source/slang/slang-vm-bytecode.cpp new file mode 100644 index 000000000..1eafa4e57 --- /dev/null +++ b/source/slang/slang-vm-bytecode.cpp @@ -0,0 +1,424 @@ +#include "slang-vm-bytecode.h" + +#include "core/slang-blob.h" +#include "core/slang-stream.h" +#include "core/slang-string-escape-util.h" + +using namespace slang; + +namespace Slang +{ +// Reads a 32-bit tag from `stream`; fails unless all four bytes were read and +// the tag equals `expected`. +static SlangResult consumeFourCC(MemoryStreamBase& stream, uint32_t expected) +{ + uint32_t fourCC = 0; + size_t bytesRead = 0; + SLANG_RETURN_ON_FAIL(stream.read(&fourCC, sizeof(fourCC), bytesRead)); + // Same short-read check as readValue, so a truncated tag is never compared. + if (bytesRead != sizeof(fourCC) || fourCC != expected) + { + return SLANG_FAIL; + } + return SLANG_OK; +} + +template<typename T> +static SlangResult readValue(MemoryStreamBase& stream, T& value) +{ + size_t bytesRead = 0; + SLANG_RETURN_ON_FAIL(stream.read(&value, sizeof(T), bytesRead)); + if (bytesRead != sizeof(T)) + { + return SLANG_FAIL; // Not enough data + } + return SLANG_OK; +} + +static SlangResult readUInt32(MemoryStreamBase& stream, uint32_t& value) +{ + return readValue(stream, value); +} + +SlangResult initVMModule(uint8_t* code, uint32_t codeSize, VMModuleView* moduleView) +{ + MemoryStreamBase stream(FileAccess::Read, code, codeSize); + moduleView->code = code; + + // Check the FourCC + 
SLANG_RETURN_ON_FAIL(consumeFourCC(stream, kSlangByteCodeFourCC)); + + // Check the version + // Go through readUInt32 so a short read is rejected like for every other field. + uint32_t version = 0; + SLANG_RETURN_ON_FAIL(readUInt32(stream, version)); + if (version > kSlangByteCodeVersion) + { + return SLANG_FAIL; // Unsupported version + } + + // Read the function section + SLANG_RETURN_ON_FAIL(consumeFourCC(stream, kSlangByteCodeFunctionsFourCC)); + uint32_t functionSectionSize = 0; + SLANG_RETURN_ON_FAIL(readUInt32(stream, functionSectionSize)); + auto funcDataStart = stream.getPosition(); + if (functionSectionSize < sizeof(uint32_t)) // At least the function count + { + return SLANG_FAIL; // Invalid section size + } + if ((uint64_t)functionSectionSize > (uint64_t)codeSize - (uint64_t)funcDataStart) + { + return SLANG_FAIL; // Section extends past the end of the buffer + } + + SLANG_RETURN_ON_FAIL(readUInt32(stream, moduleView->functionCount)); + if (moduleView->functionCount > (functionSectionSize - sizeof(uint32_t)) / sizeof(uint32_t)) + { + return SLANG_FAIL; // Function offset table does not fit in the section + } + moduleView->functionOffsets = reinterpret_cast<uint32_t*>(code + stream.getPosition()); + + stream.seek(SeekOrigin::Start, funcDataStart + functionSectionSize); + + // Read the kernel blob section + SLANG_RETURN_ON_FAIL(consumeFourCC(stream, kSlangByteCodeKernelBlobFourCC)); + SLANG_RETURN_ON_FAIL(readUInt32(stream, moduleView->kernelBlobSize)); + if (moduleView->kernelBlobSize > codeSize - stream.getPosition()) + { + return SLANG_FAIL; // Invalid kernel blob size + } + moduleView->kernelBlob = code + stream.getPosition(); + stream.seek(SeekOrigin::Current, moduleView->kernelBlobSize); + + // Read the constants section + SLANG_RETURN_ON_FAIL(consumeFourCC(stream, kSlangByteCodeConstantsFourCC)); + SLANG_RETURN_ON_FAIL(readUInt32(stream, moduleView->constantBlobSize)); + if (moduleView->constantBlobSize < sizeof(uint32_t)) // At least the constant count + { + return SLANG_FAIL; // Invalid section size + } + SLANG_RETURN_ON_FAIL(readUInt32(stream, moduleView->stringCount)); + if (moduleView->stringCount > + ((uint64_t)codeSize - (uint64_t)stream.getPosition()) / sizeof(uint32_t)) + { + return SLANG_FAIL; // String offset table extends past the end of the buffer + } + moduleView->stringOffsets = reinterpret_cast<uint32_t*>(code + stream.getPosition()); + stream.seek(SeekOrigin::Current, moduleView->stringCount * sizeof(uint32_t)); + moduleView->constants = code + stream.getPosition(); + + for (uint32_t i = 0; 
+ i < moduleView->functionCount; i++) + { + // The offset comes from the blob itself; make sure a whole header lies + // inside the buffer before dereferencing it. + const uint32_t functionOffset = moduleView->functionOffsets[i]; + if (functionOffset > codeSize || codeSize - functionOffset < sizeof(VMFuncHeader)) + { + return SLANG_FAIL; // Invalid function offset + } + auto functionStart = code + functionOffset; + auto header = (VMFuncHeader*)(functionStart); + // The name operand is used as an index into the string offset table below. + if (header->name.offset >= moduleView->stringCount) + { + return SLANG_FAIL; // Invalid function name index + } + VMFunctionView functionView; + functionView.moduleView = moduleView; + functionView.header = (VMFuncHeader*)(functionStart); + functionView.paramOffsets = (uint32_t*)(functionStart + sizeof(VMFuncHeader)); + functionView.name = (const char*)moduleView->constants + + moduleView->stringOffsets[functionView.header->name.offset]; + functionView.functionCode = + (uint8_t*)functionView.paramOffsets + sizeof(uint32_t) * header->parameterCount; + functionView.functionCodeEnd = functionView.functionCode + functionView.header->codeSize; + moduleView->functionViews.add(functionView); + } + return SLANG_OK; +} + +StringBuilder& operator<<(StringBuilder& sb, VMOp op) +{ + switch (op) + { + case VMOp::Add: + sb << "add"; + break; + case VMOp::Sub: + sb << "sub"; + break; + case VMOp::Mul: + sb << "mul"; + break; + case VMOp::Div: + sb << "div"; + break; + case VMOp::Rem: + sb << "rem"; + break; + case VMOp::And: + sb << "and"; + break; + case VMOp::Or: + sb << "or"; + break; + case VMOp::BitXor: + sb << "bitxor"; + break; + case VMOp::BitNot: + sb << "bitnot"; + break; + case VMOp::Shl: + sb << "shl"; + break; + case VMOp::Shr: + sb << "shr"; + break; + case VMOp::Equal: + sb << "equal"; + break; + case VMOp::Neq: + sb << "neq"; + break; + case VMOp::Less: + sb << "less"; + break; + case VMOp::Leq: + sb << "leq"; + break; + case VMOp::Greater: + sb << "greater"; + break; + case VMOp::Geq: + sb << "geq"; + break; + case VMOp::Nop: + sb << "nop"; + break; + case VMOp::Neg: + sb << "neg"; + break; + case VMOp::Not: + sb << "not"; + break; + case VMOp::Jump: + sb << "jump"; + break; + case VMOp::JumpIf: + sb << "jumpif"; + break; + case VMOp::Dispatch: + sb << "dispatch"; + break; + case VMOp::Load: + sb << "load"; + break; + case VMOp::Store: + sb << "store"; + break; + case VMOp::Copy: + sb << "copy"; + break; + case 
VMOp::GetWorkingSetPtr: + sb << "get_working_set_ptr"; + break; + case VMOp::GetElementPtr: + sb << "get_element_ptr"; + break; + case VMOp::OffsetPtr: + sb << "offset_ptr"; + break; + case VMOp::GetElement: + sb << "get_element"; + break; + case VMOp::Cast: + sb << "cast"; + break; + case VMOp::CallExt: + sb << "call_ext"; + break; + case VMOp::Call: + sb << "call"; + break; + case VMOp::Swizzle: + sb << "swizzle"; + break; + case VMOp::Ret: + sb << "ret"; + break; + case VMOp::Print: + sb << "print"; + break; + default: + sb << "unknown_op(" << static_cast<uint32_t>(op) << ")"; + break; + } + return sb; +} + +StringBuilder& operator<<(StringBuilder& sb, ArithmeticExtCode extCode) +{ + switch (extCode.scalarType) + { + case kSlangByteCodeScalarTypeSignedInt: + sb << "i"; + break; + case kSlangByteCodeScalarTypeUnsignedInt: + sb << "u"; + break; + case kSlangByteCodeScalarTypeFloat: + sb << "f"; + break; + default: + sb << "x"; + break; + } + sb << (8 << extCode.scalarBitWidth); + if (extCode.vectorSize > 1) + { + sb << "v" << extCode.vectorSize; + } + return sb; +} + +void printVMInst(StringBuilder& sb, VMModuleView* moduleView, VMInstHeader* inst) +{ + auto lenBeforeOpCode = sb.getLength(); + sb << inst->opcode; + if (inst->opcodeExtension != 0) + { + switch (inst->opcode) + { + case VMOp::Add: + case VMOp::Sub: + case VMOp::Mul: + case VMOp::Div: + case VMOp::Rem: + case VMOp::And: + case VMOp::Or: + case VMOp::BitXor: + case VMOp::BitNot: + case VMOp::BitAnd: + case VMOp::BitOr: + case VMOp::Neg: + case VMOp::Not: + case VMOp::Shl: + case VMOp::Shr: + case VMOp::Equal: + case VMOp::Neq: + case VMOp::Less: + case VMOp::Leq: + case VMOp::Greater: + case VMOp::Geq: + { + ArithmeticExtCode extCode; + memcpy(&extCode, &inst->opcodeExtension, sizeof(extCode)); + sb << "." << extCode; + } + break; + case VMOp::Cast: + { + ArithmeticExtCode extCode; + memcpy(&extCode, &inst->opcodeExtension, sizeof(extCode)); + sb << "." 
<< extCode; + uint32_t fromCode = inst->opcodeExtension >> 16; + memcpy(&extCode, &fromCode, sizeof(extCode)); + sb << "." << extCode; + } + break; + default: + sb << "." << inst->opcodeExtension; + break; + } + } + auto opCodeLength = (int)(sb.getLength() - lenBeforeOpCode); + static const int kOpCodeColumnWidth = 20; + if (opCodeLength < kOpCodeColumnWidth) + { + for (int i = 0; i < kOpCodeColumnWidth - opCodeLength; i++) + { + sb << " "; + } + } + else + { + sb << " "; + } + for (uint32_t i = 0; i < inst->operandCount; i++) + { + if (i > 0) + sb << ", "; + auto operand = inst->getOperand(i); + switch (operand.sectionId) + { + case kSlangByteCodeSectionConstants: + switch (operand.getType()) + { + case OperandDataType::Int32: + { + int32_t val; + moduleView->getConstant<int32_t>(operand, val); + sb << "i32(" << val << ")"; + continue; + } + case OperandDataType::Int64: + { + int64_t val; + moduleView->getConstant<int64_t>(operand, val); + sb << "i64(" << val << ")"; + continue; + } + case OperandDataType::Float32: + { + float val; + moduleView->getConstant<float>(operand, val); + sb << "f32(" << val << ")"; + continue; + } + case OperandDataType::Float64: + { + double val; + moduleView->getConstant<double>(operand, val); + sb << "f32(" << val << ")"; + continue; + } + } + sb << "const:"; + break; + case kSlangByteCodeSectionInsts: + sb << "inst:"; + break; + case kSlangByteCodeSectionWorkingSet: + sb << "ws:"; + break; + case kSlangByteCodeSectionImmediate: + sb << "!"; + break; + case kSlangByteCodeSectionFuncs: + sb << moduleView->getFunction(operand.offset).name; + continue; + case kSlangByteCodeSectionStrings: + sb << "str:"; + if (operand.offset < moduleView->stringCount) + { + auto str = StringEscapeUtil::escapeString(UnownedStringSlice( + ((char*)moduleView->constants + moduleView->stringOffsets[operand.offset]))); + sb << str; + } + else + { + sb << "<invalid string index>"; + } + continue; + default: + sb << "section(" << operand.sectionId << ")@"; + 
break; + } + sb << String(inst->getOperand(i).offset, 16); + } +} + +StringBuilder& operator<<(StringBuilder& sb, VMModuleView& module) +{ + static const int addrColumnSize = 6; + for (uint32_t i = 0; i < module.functionCount; i++) + { + auto f = module.getFunction(i); + sb << "func " << f.name << ":\n"; + for (auto inst : f) + { + sb << " "; + auto loc = ((uint8_t*)inst - f.functionCode); + auto pos = sb.getLength(); + sb << String((uint32_t)loc, 16) << ": "; + auto addrLength = (int)(sb.getLength() - pos); + for (int j = 0; j < addrColumnSize - addrLength; j++) + { + sb << " "; + } + printVMInst(sb, &module, inst); + sb << "\n"; + } + } + return sb; +} + +VMFunctionView VMModuleView::getFunction(Index index) const +{ + if (index >= functionCount) + return {}; + return functionViews[index]; +} +} // namespace Slang diff --git a/source/slang/slang-vm-bytecode.h b/source/slang/slang-vm-bytecode.h new file mode 100644 index 000000000..5d030ee3a --- /dev/null +++ b/source/slang/slang-vm-bytecode.h @@ -0,0 +1,259 @@ +#ifndef SLANG_VM_BYTE_CODE_H +#define SLANG_VM_BYTE_CODE_H + +#include "core/slang-basic.h" +#include "core/slang-riff.h" +#include "slang-com-ptr.h" + +namespace Slang +{ + +/* +Slang ByteCode Module File Format + +# Header + - (4 bytes) FourCC: 'S', 'V', 'M', 'C' (kSlangByteCodeFourCC) + - (4 bytes uint) Version: 100 + +# Function Section + - (4 bytes) FourCC: 'S', 'V', 'F', 'N' (kSlangByteCodeFunctionsFourCC) + - (4 bytes uint) Function Count: number of functions in the module + - (uint array) Function Offsets: + array of "Function Count" 32-bit uints, storing byte offsets from + start of file for each function. + +## Function i: + - (32 bytes `VMFuncHeader`) Function metadata, describing the name and other + info needed for execution. VMFuncHeader::name is a `VMOperand` whose + sectionId = kSlangByteCodeSectionConstants, pointing to the constant section for the + function name. 
 - (uint32 * parameterCount array): array of "parameterCount" 32-bit uints, storing the byte
   offset of each parameter from the start of the function's working set.
 - (byte array) Code: array of `header.codeSize` bytes, containing the instruction stream for the
   function. Each instruction starts with a `VMInstHeader`, followed by
   `VMInstHeader::operandCount` `VMOperand` structs.

# Kernel Blob Section: binary data for the kernel blob
 - (4 bytes) FourCC: 'S', 'V', 'K', 'N' (kSlangByteCodeKernelBlobFourCC)
 - (4 bytes uint) Kernel Blob Size: size of the kernel blob in bytes.
 - (byte array) Kernel Blob: array of "Kernel Blob Size" bytes, containing the kernel blob data.

# Constants Section
 - (4 bytes) FourCC: 'S', 'V', 'C', 'S' (kSlangByteCodeConstantsFourCC)
 - (4 bytes uint) Constant Blob Size: size of the constant data in bytes (read into
   `VMModuleView::constantBlobSize`; must be at least 4).
 - (4 bytes uint) String Count: number of string literals in the constant section.
 - (uint array) String Offsets: array of "String Count" 32-bit uints, storing the byte offset of
   each string literal from the start of the constant blob (next item).
 - (byte array) Constants: raw constant data. Operands in the constants section address it by
   byte offset from the start of this blob.
+*/ + +static const int kSlangByteCodeVersion = 100; + +static const uint32_t kSlangByteCodeFourCC = SLANG_FOUR_CC('S', 'V', 'M', 'C'); +static const uint32_t kSlangByteCodeFunctionsFourCC = SLANG_FOUR_CC('S', 'V', 'F', 'N'); +static const uint32_t kSlangByteCodeKernelBlobFourCC = SLANG_FOUR_CC('S', 'V', 'K', 'N'); +static const uint32_t kSlangByteCodeConstantsFourCC = SLANG_FOUR_CC('S', 'V', 'C', 'S'); + +static const int kSlangByteCodeSectionWorkingSet = 0; +static const int kSlangByteCodeSectionConstants = 1; +static const int kSlangByteCodeSectionInsts = 2; +static const int kSlangByteCodeSectionImmediate = 3; +static const int kSlangByteCodeSectionFuncs = 4; +static const int kSlangByteCodeSectionStrings = 5; + + +enum class VMOp : uint32_t +{ + Nop, + Add, + Sub, + Mul, + Div, + Rem, + Neg, + And, + Or, + Not, + BitAnd, + BitOr, + BitNot, + BitXor, + Shl, + Shr, + Ret, + Less, + Leq, + Greater, + Geq, + Equal, + Neq, + Jump, + JumpIf, + Dispatch, + Load, + Store, + Copy, + GetWorkingSetPtr, + GetElementPtr, + OffsetPtr, + GetElement, + Swizzle, + Cast, + CallExt, + Call, + Print, +}; + +// Represents an operand in the VM bytecode. +// It consists of a section ID and a byte offset within that section. +struct VMOperand +{ + uint32_t sectionId; + uint32_t padding; // Padding to ensure section takes 8 bytes. sectionId will be replaced + // with actual pointers before execution. + uint32_t type : 8; // type of the operand data. + uint32_t size : 24; + uint32_t offset; + slang::OperandDataType getType() const { return slang::OperandDataType(type); } + void setType(slang::OperandDataType newType) { type = uint32_t(newType); } +}; + +struct VMInstHeader +{ + VMOp opcode; + uint32_t padding; // 32-bit padding, to ensure space are reserved to store function pointers for + // the opcode. 
/// Forward iterator over a packed instruction stream.
///
/// Each instruction is laid out as a `TInstHeader` immediately followed by
/// `TInstHeader::operandCount` records of type `TOperand`. Dereferencing yields a pointer
/// to the header at the current position; incrementing skips the header and its operands.
template<typename TOperand, typename TInstHeader>
struct VMInstIterator
{
    uint8_t* codePtr = nullptr; // Pointer to the current instruction.

    /// Advances `codePtr` past the current instruction (header plus operand list).
    void moveNext()
    {
        // Copy the header out instead of dereferencing in place, so the read does not
        // depend on `codePtr` being suitably aligned for TInstHeader.
        TInstHeader header;
        memcpy(&header, codePtr, sizeof(header));
        codePtr += sizeof(header) + header.operandCount * sizeof(TOperand);
    }

    /// Pre-increment: advances and returns this iterator.
    VMInstIterator& operator++()
    {
        moveNext();
        return *this;
    }

    /// Post-increment: advances this iterator and returns its previous position.
    /// (The earlier version advanced only the returned copy and left `*this` unchanged.)
    VMInstIterator operator++(int)
    {
        VMInstIterator previous = *this;
        moveNext();
        return previous;
    }

    bool operator!=(const VMInstIterator& iter) const { return codePtr != iter.codePtr; }
    bool operator==(const VMInstIterator& iter) const { return codePtr == iter.codePtr; }

    /// Returns the current position reinterpreted as an instruction header pointer.
    TInstHeader* operator*() const { return reinterpret_cast<TInstHeader*>(codePtr); }
};
+ uint8_t* kernelBlob; + uint32_t kernelBlobSize; + + List<VMFunctionView> functionViews; + + VMFunctionView getFunction(Index index) const; + + template<typename T> + SlangResult getConstant(VMOperand operand, T& outValue) const + { + if (operand.sectionId != kSlangByteCodeSectionConstants) + { + return SLANG_FAIL; // Invalid section + } + if (operand.offset + sizeof(T) > constantBlobSize) + { + return SLANG_FAIL; // Out of bounds + } + memcpy(&outValue, constants + operand.offset, sizeof(T)); + return SLANG_OK; + } +}; + +SlangResult initVMModule(uint8_t* code, uint32_t codeSize, VMModuleView* moduleView); + +StringBuilder& operator<<(StringBuilder& sb, VMOp op); +StringBuilder& operator<<(StringBuilder& sb, VMModuleView& module); +void printVMInst(StringBuilder& sb, VMModuleView* moduleView, VMInstHeader* inst); + +} // namespace Slang + + +#endif diff --git a/source/slang/slang-vm-inst-impl.cpp b/source/slang/slang-vm-inst-impl.cpp new file mode 100644 index 000000000..304ea09a8 --- /dev/null +++ b/source/slang/slang-vm-inst-impl.cpp @@ -0,0 +1,1066 @@ +#include "slang-vm-inst-impl.h" + +#include "slang-vm.h" + +using namespace slang; + +namespace Slang +{ +ByteCodeInterpreter* convert(IByteCodeRunner* runner) +{ + return static_cast<ByteCodeInterpreter*>(runner); +} + +#define SIMPLE_BINARY_SCALAR_FUNC(name, op) \ + struct name##ScalarFunc \ + { \ + template<typename TR, typename T1, typename T2> \ + static void run(TR* dst, const T1* src1, const T2* src2) \ + { \ + *dst = (*src1)op(*src2); \ + } \ + } + +SIMPLE_BINARY_SCALAR_FUNC(Add, +); +SIMPLE_BINARY_SCALAR_FUNC(Sub, -); +SIMPLE_BINARY_SCALAR_FUNC(Mul, *); +SIMPLE_BINARY_SCALAR_FUNC(Div, /); +SIMPLE_BINARY_SCALAR_FUNC(And, &&); +SIMPLE_BINARY_SCALAR_FUNC(Or, ||); +SIMPLE_BINARY_SCALAR_FUNC(BitAnd, &); +SIMPLE_BINARY_SCALAR_FUNC(BitOr, |); +SIMPLE_BINARY_SCALAR_FUNC(BitXor, ^); +SIMPLE_BINARY_SCALAR_FUNC(Shl, <<); +SIMPLE_BINARY_SCALAR_FUNC(Shr, >>); +SIMPLE_BINARY_SCALAR_FUNC(Less, <); 
/// Remainder of `*src1` by `*src2`, written to `*dst`.
/// The primary template uses the built-in `%` operator and therefore applies to integer
/// operand types; floating-point operands are handled by the specializations below.
template<typename TR, typename T1, typename T2>
void scalarMod(TR* dst, const T1* src1, const T2* src2)
{
    const T1& dividend = *src1;
    const T2& divisor = *src2;
    *dst = dividend % divisor;
}

/// Single-precision remainder, computed with `fmodf`.
template<>
void scalarMod<float, float, float>(float* dst, const float* src1, const float* src2)
{
    const float remainder = fmodf(*src1, *src2);
    *dst = remainder;
}

/// Double-precision remainder, computed with `fmod`.
template<>
void scalarMod<double, double, double>(double* dst, const double* src1, const double* src2)
{
    const double remainder = fmod(*src1, *src2);
    *dst = remainder;
}

/// Scalar functor for the remainder operation; forwards to the matching `scalarMod`.
struct ModScalarFunc
{
    template<typename TR, typename T1, typename T2>
    static void run(TR* dst, const T1* src1, const T2* src2)
    {
        scalarMod<TR, T1, T2>(dst, src1, src2);
    }
};
(T2*)inst->getOperand(2).getPtr(); + ArithmeticExtCode arithExtCode; + memcpy(&arithExtCode, &inst->opcodeExtension, sizeof(arithExtCode)); + for (uint32_t i = 0; i < arithExtCode.vectorSize; ++i) + { + ScalarFunc::template run<TR, T1, T2>(&dst[i], &src1[i], &src2[i]); + } + } +}; + +template<typename Func, typename TR, typename T1 = TR, typename T2 = TR> +VMExtFunction binaryArithmeticInstHandler(int elementCount) +{ + switch (elementCount) + { + case 0: + case 1: + return BinaryVectorFunc<Func, TR, T1, T2, 1>::run; + case 2: + return BinaryVectorFunc<Func, TR, T1, T2, 2>::run; + case 3: + return BinaryVectorFunc<Func, TR, T1, T2, 3>::run; + case 4: + return BinaryVectorFunc<Func, TR, T1, T2, 4>::run; + case 6: + return BinaryVectorFunc<Func, TR, T1, T2, 6>::run; + case 8: + return BinaryVectorFunc<Func, TR, T1, T2, 8>::run; + case 9: + return BinaryVectorFunc<Func, TR, T1, T2, 9>::run; + case 10: + return BinaryVectorFunc<Func, TR, T1, T2, 10>::run; + case 12: + return BinaryVectorFunc<Func, TR, T1, T2, 12>::run; + case 16: + return BinaryVectorFunc<Func, TR, T1, T2, 16>::run; + default: + return GeneralBinaryVectorFunc<Func, TR, T1, T2>::run; + } +} + +template<typename Func> +VMExtFunction binaryArithmeticInstHandler(uint32_t extCode) +{ + ArithmeticExtCode arithExtCode; + memcpy(&arithExtCode, &extCode, sizeof(arithExtCode)); + switch (arithExtCode.scalarType) + { + case kSlangByteCodeScalarTypeSignedInt: + switch (arithExtCode.scalarBitWidth) + { + case 0: + return binaryArithmeticInstHandler<Func, int8_t>(arithExtCode.vectorSize); + case 1: + return binaryArithmeticInstHandler<Func, int16_t>(arithExtCode.vectorSize); + case 2: + return binaryArithmeticInstHandler<Func, int32_t>(arithExtCode.vectorSize); + case 3: + return binaryArithmeticInstHandler<Func, int64_t>(arithExtCode.vectorSize); + } + case kSlangByteCodeScalarTypeUnsignedInt: + switch (arithExtCode.scalarBitWidth) + { + case 0: + return binaryArithmeticInstHandler<Func, 
uint8_t>(arithExtCode.vectorSize); + case 1: + return binaryArithmeticInstHandler<Func, uint16_t>(arithExtCode.vectorSize); + case 2: + return binaryArithmeticInstHandler<Func, uint32_t>(arithExtCode.vectorSize); + case 3: + return binaryArithmeticInstHandler<Func, uint64_t>(arithExtCode.vectorSize); + } + case kSlangByteCodeScalarTypeFloat: + switch (arithExtCode.scalarBitWidth) + { + case 2: + return binaryArithmeticInstHandler<Func, float>(arithExtCode.vectorSize); + case 3: + return binaryArithmeticInstHandler<Func, double>(arithExtCode.vectorSize); + default: + return nullptr; // Unsupported scalar bit width + } + } + return nullptr; +} + +template<typename Func> +VMExtFunction binaryArithmeticLogicalInstHandler(uint32_t extCode) +{ + ArithmeticExtCode arithExtCode; + memcpy(&arithExtCode, &extCode, sizeof(arithExtCode)); + switch (arithExtCode.scalarBitWidth) + { + case 0: + return binaryArithmeticInstHandler<Func, uint8_t>(arithExtCode.vectorSize); + case 1: + return binaryArithmeticInstHandler<Func, uint16_t>(arithExtCode.vectorSize); + case 2: + return binaryArithmeticInstHandler<Func, uint32_t>(arithExtCode.vectorSize); + case 3: + return binaryArithmeticInstHandler<Func, uint64_t>(arithExtCode.vectorSize); + } + return nullptr; +} + +template<typename Func> +VMExtFunction binaryArithmeticIntInstHandler(uint32_t extCode) +{ + ArithmeticExtCode arithExtCode; + memcpy(&arithExtCode, &extCode, sizeof(arithExtCode)); + switch (arithExtCode.scalarType) + { + case kSlangByteCodeScalarTypeSignedInt: + switch (arithExtCode.scalarBitWidth) + { + case 0: + return binaryArithmeticInstHandler<Func, int8_t>(arithExtCode.vectorSize); + case 1: + return binaryArithmeticInstHandler<Func, int16_t>(arithExtCode.vectorSize); + case 2: + return binaryArithmeticInstHandler<Func, int32_t>(arithExtCode.vectorSize); + case 3: + return binaryArithmeticInstHandler<Func, int64_t>(arithExtCode.vectorSize); + } + case kSlangByteCodeScalarTypeUnsignedInt: + switch 
(arithExtCode.scalarBitWidth) + { + case 0: + return binaryArithmeticInstHandler<Func, uint8_t>(arithExtCode.vectorSize); + case 1: + return binaryArithmeticInstHandler<Func, uint16_t>(arithExtCode.vectorSize); + case 2: + return binaryArithmeticInstHandler<Func, uint32_t>(arithExtCode.vectorSize); + case 3: + return binaryArithmeticInstHandler<Func, uint64_t>(arithExtCode.vectorSize); + } + } + return nullptr; +} + +template<typename Func> +VMExtFunction binaryArithmeticCompareInstHandler(uint32_t extCode) +{ + ArithmeticExtCode arithExtCode; + memcpy(&arithExtCode, &extCode, sizeof(arithExtCode)); + switch (arithExtCode.scalarType) + { + case kSlangByteCodeScalarTypeSignedInt: + switch (arithExtCode.scalarBitWidth) + { + case 0: + return binaryArithmeticInstHandler<Func, uint32_t, int8_t, int8_t>( + arithExtCode.vectorSize); + case 1: + return binaryArithmeticInstHandler<Func, uint32_t, int16_t, int16_t>( + arithExtCode.vectorSize); + case 2: + return binaryArithmeticInstHandler<Func, uint32_t, int32_t, int32_t>( + arithExtCode.vectorSize); + case 3: + return binaryArithmeticInstHandler<Func, uint32_t, int64_t, int64_t>( + arithExtCode.vectorSize); + } + case kSlangByteCodeScalarTypeUnsignedInt: + switch (arithExtCode.scalarBitWidth) + { + case 0: + return binaryArithmeticInstHandler<Func, uint32_t, uint8_t, uint8_t>( + arithExtCode.vectorSize); + case 1: + return binaryArithmeticInstHandler<Func, uint32_t, uint16_t, uint16_t>( + arithExtCode.vectorSize); + case 2: + return binaryArithmeticInstHandler<Func, uint32_t, uint32_t, uint32_t>( + arithExtCode.vectorSize); + case 3: + return binaryArithmeticInstHandler<Func, uint32_t, uint64_t, uint64_t>( + arithExtCode.vectorSize); + } + case kSlangByteCodeScalarTypeFloat: + switch (arithExtCode.scalarBitWidth) + { + case 2: + return binaryArithmeticInstHandler<Func, uint32_t, float, float>( + arithExtCode.vectorSize); + case 3: + return binaryArithmeticInstHandler<Func, uint32_t, double, double>( + 
arithExtCode.vectorSize); + default: + return nullptr; // Unsupported scalar bit width + } + } + return nullptr; +} + +//////// +template<typename ScalarFunc, typename TR, typename T1, int elementCount> +struct UnaryVectorFunc +{ + static void run(IByteCodeRunner* context, VMExecInstHeader* inst, void* userData) + { + SLANG_UNUSED(context); + SLANG_UNUSED(userData); + TR* dst = (TR*)inst->getOperand(0).getPtr(); + T1* src1 = (T1*)inst->getOperand(1).getPtr(); + for (int i = 0; i < elementCount; ++i) + { + ScalarFunc::template run<TR, T1>(&dst[i], &src1[i]); + } + } +}; + +template<typename ScalarFunc, typename TR, typename T1> +struct GeneralUnaryVectorFunc +{ + static void run(IByteCodeRunner* context, VMExecInstHeader* inst, void* userData) + { + SLANG_UNUSED(context); + SLANG_UNUSED(userData); + TR* dst = (TR*)inst->getOperand(0).getPtr(); + T1* src1 = (T1*)inst->getOperand(1).getPtr(); + ArithmeticExtCode arithExtCode; + memcpy(&arithExtCode, &inst->opcodeExtension, sizeof(arithExtCode)); + for (uint32_t i = 0; i < arithExtCode.vectorSize; ++i) + { + ScalarFunc::template run<TR, T1>(&dst[i], &src1[i]); + } + } +}; + +template<typename Func, typename TR, typename T1 = TR> +VMExtFunction unaryArithmeticInstHandler(int elementCount) +{ + switch (elementCount) + { + case 0: + case 1: + return UnaryVectorFunc<Func, TR, T1, 1>::run; + case 2: + return UnaryVectorFunc<Func, TR, T1, 2>::run; + case 3: + return UnaryVectorFunc<Func, TR, T1, 3>::run; + case 4: + return UnaryVectorFunc<Func, TR, T1, 4>::run; + case 6: + return UnaryVectorFunc<Func, TR, T1, 6>::run; + case 8: + return UnaryVectorFunc<Func, TR, T1, 8>::run; + case 9: + return UnaryVectorFunc<Func, TR, T1, 9>::run; + case 10: + return UnaryVectorFunc<Func, TR, T1, 10>::run; + case 12: + return UnaryVectorFunc<Func, TR, T1, 12>::run; + case 16: + return UnaryVectorFunc<Func, TR, T1, 16>::run; + default: + return GeneralUnaryVectorFunc<Func, TR, T1>::run; + } +} + +template<typename Func> +VMExtFunction 
unaryArithmeticLogicalInstHandler(uint32_t extCode) +{ + ArithmeticExtCode arithExtCode; + memcpy(&arithExtCode, &extCode, sizeof(arithExtCode)); + switch (arithExtCode.scalarBitWidth) + { + case 0: + return unaryArithmeticInstHandler<Func, uint8_t>(arithExtCode.vectorSize); + case 1: + return unaryArithmeticInstHandler<Func, uint16_t>(arithExtCode.vectorSize); + case 2: + return unaryArithmeticInstHandler<Func, uint32_t>(arithExtCode.vectorSize); + case 3: + return unaryArithmeticInstHandler<Func, uint64_t>(arithExtCode.vectorSize); + } + return nullptr; +} + +template<typename Func> +VMExtFunction unaryArithmeticIntInstHandler(uint32_t extCode) +{ + ArithmeticExtCode arithExtCode; + memcpy(&arithExtCode, &extCode, sizeof(arithExtCode)); + switch (arithExtCode.scalarType) + { + case kSlangByteCodeScalarTypeSignedInt: + switch (arithExtCode.scalarBitWidth) + { + case 0: + return unaryArithmeticInstHandler<Func, int8_t>(arithExtCode.vectorSize); + case 1: + return unaryArithmeticInstHandler<Func, int16_t>(arithExtCode.vectorSize); + case 2: + return unaryArithmeticInstHandler<Func, int32_t>(arithExtCode.vectorSize); + case 3: + return unaryArithmeticInstHandler<Func, int64_t>(arithExtCode.vectorSize); + } + case kSlangByteCodeScalarTypeUnsignedInt: + switch (arithExtCode.scalarBitWidth) + { + case 0: + return unaryArithmeticInstHandler<Func, uint8_t>(arithExtCode.vectorSize); + case 1: + return unaryArithmeticInstHandler<Func, uint16_t>(arithExtCode.vectorSize); + case 2: + return unaryArithmeticInstHandler<Func, uint32_t>(arithExtCode.vectorSize); + case 3: + return unaryArithmeticInstHandler<Func, uint64_t>(arithExtCode.vectorSize); + } + } + return nullptr; +} + +template<typename Func> +VMExtFunction negInstHandler(uint32_t extCode) +{ + ArithmeticExtCode arithExtCode; + memcpy(&arithExtCode, &extCode, sizeof(arithExtCode)); + switch (arithExtCode.scalarType) + { + case kSlangByteCodeScalarTypeSignedInt: + case kSlangByteCodeScalarTypeUnsignedInt: + switch 
// Handler for a call instruction: pushes a frame for the callee, copies the call arguments
// into the callee's working set, and redirects execution to the start of the callee's code.
//
// Operand layout as used here: operand 1's `offset` is the index of the callee in
// `ctx->m_functions`; operands 2.. are the arguments, one per callee parameter.
// (Operand 0 is not touched here; retHandler reads its offset as the result location.)
void callHandler(IByteCodeRunner* inCtx, VMExecInstHeader* inst, void*)
{
    auto ctx = convert(inCtx);
    auto funcId = inst->getOperand(1).offset;
    // NOTE(review): funcId is used to index m_functions without a range check here.
    auto& func = ctx->m_functions[funcId];
    auto funcHeader = func.m_header;

    // Alloc working set.
    ctx->pushFrame(funcHeader->workingSetSizeInBytes);

    // Save current instruction pointer.
    // The frame just pushed records the call site so execution can resume after the call.
    auto& stackFrame = ctx->m_stack.getLast();
    stackFrame.m_currentInst = inst;
    stackFrame.m_currentFuncCode = ctx->m_currentFuncCode;
    // NOTE(review): presumably pushFrame has already pointed m_currentWorkingSet at the
    // callee's freshly allocated working set - confirm against pushFrame.
    auto newWorkingSetPtr = (uint8_t*)ctx->m_currentWorkingSet;
    auto callerWorkingSetPtr =
        (uint8_t*)(ctx->m_workingSetBuffer.getBuffer() + stackFrame.m_workingSetOffset);

    // Set working set pointer to the caller's working set.
    // NOTE(review): presumably so that the argument operands' getPtr() below resolves
    // against the caller's frame while arguments are copied - confirm.
    ctx->m_currentWorkingSet = callerWorkingSetPtr;

    // Copy arguments to the callee's working set.
    for (uint32_t i = 0; i < funcHeader->parameterCount; ++i)
    {
        auto dst = newWorkingSetPtr + func.m_parameterOffsets[i];
        auto src = (uint8_t*)inst->getOperand(i + 2).getPtr();

        // func.m_parameterOffsets should be initialized to contain parameterCount+1 elements,
        // where the last element is the total size of the parameters.
        // The size of parameter i is the distance to the next parameter's offset.
        auto nextParamOffset = func.m_parameterOffsets[i + 1];
        memcpy(dst, src, nextParamOffset - func.m_parameterOffsets[i]);
    }
    // Switch to the callee: its working set, its code buffer, and its first instruction.
    ctx->m_currentWorkingSet = newWorkingSetPtr;
    ctx->m_currentFuncCode = func.m_codeBuffer.getBuffer();
    ctx->m_currentInst = (VMExecInstHeader*)func.m_codeBuffer.getBuffer();
}
+ ctx->popFrame(); +} + +static void jumpHandler(IByteCodeRunner* inCtx, VMExecInstHeader* inst, void*) +{ + auto ctx = convert(inCtx); + ctx->m_currentInst = (VMExecInstHeader*)inst->getOperand(0).getPtr(); +} + +static void jumpIfHandler(IByteCodeRunner* inCtx, VMExecInstHeader* inst, void*) +{ + auto ctx = convert(inCtx); + + auto cond = *(uint32_t*)inst->getOperand(0).getPtr(); + if (cond) + { + ctx->m_currentInst = (VMExecInstHeader*)inst->getOperand(1).getPtr(); + } + else + { + ctx->m_currentInst = (VMExecInstHeader*)inst->getOperand(2).getPtr(); + } +} + +static void getWorkingSetPtrHandler(IByteCodeRunner* inCtx, VMExecInstHeader* inst, void*) +{ + auto ctx = convert(inCtx); + auto dst = (void**)inst->getOperand(0).getPtr(); + auto ptr = (uint8_t*)ctx->m_currentWorkingSet + inst->opcodeExtension; + *dst = ptr; +} + +static void getElementPtrHandler(IByteCodeRunner* ctx, VMExecInstHeader* inst, void*) +{ + SLANG_UNUSED(ctx); + auto dst = (void**)inst->getOperand(0).getPtr(); + auto basePtr = *(uint8_t**)inst->getOperand(1).getPtr(); + auto elementIndex = *(uint32_t*)inst->getOperand(2).getPtr(); + *dst = (uint8_t*)basePtr + elementIndex * inst->opcodeExtension; +} + +static void getElementHandler(IByteCodeRunner* ctx, VMExecInstHeader* inst, void*) +{ + SLANG_UNUSED(ctx); + auto dst = (void*)inst->getOperand(0).getPtr(); + auto basePtr = (uint8_t*)inst->getOperand(1).getPtr(); + auto elementIndex = *(uint32_t*)inst->getOperand(2).getPtr(); + memcpy(dst, basePtr + elementIndex * inst->opcodeExtension, inst->opcodeExtension); +} + +static void offsetPtrHandler(IByteCodeRunner* ctx, VMExecInstHeader* inst, void*) +{ + SLANG_UNUSED(ctx); + auto dst = (void**)inst->getOperand(0).getPtr(); + auto basePtr = *(uint8_t**)inst->getOperand(1).getPtr(); + auto offset = *(int32_t*)inst->getOperand(2).getPtr(); + *dst = basePtr + offset * inst->opcodeExtension; +} + +void loadHandler8(IByteCodeRunner* ctx, VMExecInstHeader* inst, void*) +{ + SLANG_UNUSED(ctx); + auto dst 
= (uint8_t*)inst->getOperand(0).getPtr(); + auto src = *(uint8_t**)inst->getOperand(1).getPtr(); + *dst = *src; +} +void loadHandler16(IByteCodeRunner* ctx, VMExecInstHeader* inst, void*) +{ + SLANG_UNUSED(ctx); + auto dst = (uint16_t*)inst->getOperand(0).getPtr(); + auto src = *(uint16_t**)inst->getOperand(1).getPtr(); + *dst = *src; +} +void loadHandler32(IByteCodeRunner* ctx, VMExecInstHeader* inst, void*) +{ + SLANG_UNUSED(ctx); + auto dst = (uint32_t*)inst->getOperand(0).getPtr(); + auto src = *(uint32_t**)inst->getOperand(1).getPtr(); + *dst = *src; +} +void loadHandler64(IByteCodeRunner* ctx, VMExecInstHeader* inst, void*) +{ + SLANG_UNUSED(ctx); + auto dst = (uint64_t*)inst->getOperand(0).getPtr(); + auto src = *(uint64_t**)inst->getOperand(1).getPtr(); + *dst = *src; +} + +void generalLoadHandler(IByteCodeRunner* ctx, VMExecInstHeader* inst, void*) +{ + SLANG_UNUSED(ctx); + auto dst = (uint8_t*)inst->getOperand(0).getPtr(); + auto src = *(uint8_t**)inst->getOperand(1).getPtr(); + memcpy(dst, src, inst->opcodeExtension); +} + +VMExtFunction getLoadHandler(uint32_t extCode) +{ + switch (extCode) + { + case 1: + return loadHandler8; + case 2: + return loadHandler16; + case 4: + return loadHandler32; + case 8: + return loadHandler64; + default: + return generalLoadHandler; + } +} + +void storeHandler8(IByteCodeRunner* ctx, VMExecInstHeader* inst, void*) +{ + SLANG_UNUSED(ctx); + auto dst = *(uint8_t**)inst->getOperand(0).getPtr(); + auto src = (uint8_t*)inst->getOperand(1).getPtr(); + *dst = *src; +} + +void storeHandler16(IByteCodeRunner* ctx, VMExecInstHeader* inst, void*) +{ + SLANG_UNUSED(ctx); + auto dst = *(uint16_t**)inst->getOperand(0).getPtr(); + auto src = (uint16_t*)inst->getOperand(1).getPtr(); + *dst = *src; +} + +void storeHandler32(IByteCodeRunner* ctx, VMExecInstHeader* inst, void*) +{ + SLANG_UNUSED(ctx); + auto dst = *(uint32_t**)inst->getOperand(0).getPtr(); + auto src = (uint32_t*)inst->getOperand(1).getPtr(); + *dst = *src; +} + +void 
storeHandler64(IByteCodeRunner* ctx, VMExecInstHeader* inst, void*) +{ + SLANG_UNUSED(ctx); + auto dst = *(uint64_t**)inst->getOperand(0).getPtr(); + auto src = (uint64_t*)inst->getOperand(1).getPtr(); + *dst = *src; +} + +void generalStoreHandler(IByteCodeRunner* ctx, VMExecInstHeader* inst, void*) +{ + SLANG_UNUSED(ctx); + auto dst = *(uint8_t**)inst->getOperand(0).getPtr(); + auto src = (uint8_t*)inst->getOperand(1).getPtr(); + memcpy(dst, src, inst->opcodeExtension); +} + +VMExtFunction getStoreHandler(uint32_t extCode) +{ + switch (extCode) + { + case 1: + return storeHandler8; + case 2: + return storeHandler16; + case 4: + return storeHandler32; + case 8: + return storeHandler64; + default: + return generalStoreHandler; + } +} + +void copyHandler8(IByteCodeRunner* ctx, VMExecInstHeader* inst, void*) +{ + SLANG_UNUSED(ctx); + auto dst = (uint8_t*)inst->getOperand(0).getPtr(); + auto src = (uint8_t*)inst->getOperand(1).getPtr(); + *dst = *src; +} + +void copyHandler16(IByteCodeRunner* ctx, VMExecInstHeader* inst, void*) +{ + SLANG_UNUSED(ctx); + auto dst = (uint16_t*)inst->getOperand(0).getPtr(); + auto src = (uint16_t*)inst->getOperand(1).getPtr(); + *dst = *src; +} + +void copyHandler32(IByteCodeRunner* ctx, VMExecInstHeader* inst, void*) +{ + SLANG_UNUSED(ctx); + auto dst = (uint32_t*)inst->getOperand(0).getPtr(); + auto src = (uint32_t*)inst->getOperand(1).getPtr(); + *dst = *src; +} + +void copyHandler64(IByteCodeRunner* ctx, VMExecInstHeader* inst, void*) +{ + SLANG_UNUSED(ctx); + auto dst = (uint64_t*)inst->getOperand(0).getPtr(); + auto src = (uint64_t*)inst->getOperand(1).getPtr(); + *dst = *src; +} + +void generalCopyHandler(IByteCodeRunner* ctx, VMExecInstHeader* inst, void*) +{ + SLANG_UNUSED(ctx); + auto dst = (uint8_t*)inst->getOperand(0).getPtr(); + auto src = (uint8_t*)inst->getOperand(1).getPtr(); + memcpy(dst, src, inst->opcodeExtension); +} + +VMExtFunction getCopyHandler(uint32_t extCode) +{ + switch (extCode) + { + case 1: + return 
copyHandler8; + case 2: + return copyHandler16; + case 4: + return copyHandler32; + case 8: + return copyHandler64; + default: + return generalCopyHandler; + } +} + +template<typename T> +void swizzleHandler(IByteCodeRunner* ctx, VMExecInstHeader* inst, void* userData) +{ + SLANG_UNUSED(ctx); + SLANG_UNUSED(userData); + auto dst = (T*)inst->getOperand(0).getPtr(); + auto src = (T*)inst->getOperand(1).getPtr(); + for (uint32_t i = 2; i < inst->operandCount; ++i) + { + dst[i - 2] = src[inst->getOperand(i).offset]; + } +} + +VMExtFunction getSwizzleHandler(uint32_t extCode) +{ + ArithmeticExtCode arithExtCode; + memcpy(&arithExtCode, &extCode, sizeof(arithExtCode)); + switch (arithExtCode.scalarType) + { + case kSlangByteCodeScalarTypeSignedInt: + switch (arithExtCode.scalarBitWidth) + { + case 0: + return swizzleHandler<int8_t>; + case 1: + return swizzleHandler<int16_t>; + case 2: + return swizzleHandler<int32_t>; + case 3: + return swizzleHandler<int64_t>; + } + case kSlangByteCodeScalarTypeUnsignedInt: + switch (arithExtCode.scalarBitWidth) + { + case 0: + return swizzleHandler<uint8_t>; + case 1: + return swizzleHandler<uint16_t>; + case 2: + return swizzleHandler<uint32_t>; + case 3: + return swizzleHandler<uint64_t>; + } + case kSlangByteCodeScalarTypeFloat: + switch (arithExtCode.scalarBitWidth) + { + case 2: + return swizzleHandler<float>; + case 3: + return swizzleHandler<double>; + default: + return nullptr; // Unsupported scalar bit width + } + } + return nullptr; +} + +template<typename To, typename From, int vectorSize> +void castHandler(IByteCodeRunner* ctx, VMExecInstHeader* inst, void*) +{ + SLANG_UNUSED(ctx); + To* dst = (To*)inst->getOperand(0).getPtr(); + From* src = (From*)inst->getOperand(1).getPtr(); + for (int i = 0; i < vectorSize; ++i) + { + dst[i] = static_cast<To>(src[i]); + } +} + +template<typename From, int vectorSize> +VMExtFunction getCastHandler(uint32_t extCode) +{ + ArithmeticExtCode arithExtCode; + memcpy(&arithExtCode, &extCode, 
sizeof(arithExtCode)); + switch (arithExtCode.scalarType) + { + case kSlangByteCodeScalarTypeSignedInt: + switch (arithExtCode.scalarBitWidth) + { + case 0: + return castHandler<uint8_t, From, vectorSize>; + case 1: + return castHandler<uint16_t, From, vectorSize>; + case 2: + return castHandler<uint32_t, From, vectorSize>; + case 3: + return castHandler<uint64_t, From, vectorSize>; + } + case kSlangByteCodeScalarTypeUnsignedInt: + switch (arithExtCode.scalarBitWidth) + { + case 0: + return castHandler<uint8_t, From, vectorSize>; + case 1: + return castHandler<uint16_t, From, vectorSize>; + case 2: + return castHandler<uint32_t, From, vectorSize>; + case 3: + return castHandler<uint64_t, From, vectorSize>; + } + case kSlangByteCodeScalarTypeFloat: + switch (arithExtCode.scalarBitWidth) + { + case 2: + return castHandler<float, From, vectorSize>; + case 3: + return castHandler<double, From, vectorSize>; + default: + return nullptr; // Unsupported scalar bit width + } + } + return nullptr; +} + +template<int vectorSize> +VMExtFunction getCastHandler(uint32_t extCode) +{ + uint32_t fromExtCode = extCode >> 16; + ArithmeticExtCode arithExtCode; + memcpy(&arithExtCode, &fromExtCode, sizeof(arithExtCode)); + switch (arithExtCode.scalarType) + { + case kSlangByteCodeScalarTypeSignedInt: + switch (arithExtCode.scalarBitWidth) + { + case 0: + return getCastHandler<uint8_t, vectorSize>(extCode); + case 1: + return getCastHandler<uint16_t, vectorSize>(extCode); + case 2: + return getCastHandler<uint32_t, vectorSize>(extCode); + case 3: + return getCastHandler<uint64_t, vectorSize>(extCode); + } + case kSlangByteCodeScalarTypeUnsignedInt: + switch (arithExtCode.scalarBitWidth) + { + case 0: + return getCastHandler<uint8_t, vectorSize>(extCode); + case 1: + return getCastHandler<uint16_t, vectorSize>(extCode); + case 2: + return getCastHandler<uint32_t, vectorSize>(extCode); + case 3: + return getCastHandler<uint64_t, vectorSize>(extCode); + } + case 
kSlangByteCodeScalarTypeFloat: + switch (arithExtCode.scalarBitWidth) + { + case 2: + return getCastHandler<float, vectorSize>(extCode); + case 3: + return getCastHandler<double, vectorSize>(extCode); + default: + return nullptr; // Unsupported scalar bit width + } + } + return nullptr; +} + +VMExtFunction getCastHandler(uint32_t extCode) +{ + uint32_t fromExtCode = extCode >> 16; + ArithmeticExtCode arithExtCode; + memcpy(&arithExtCode, &fromExtCode, sizeof(arithExtCode)); + switch (arithExtCode.vectorSize) + { + case 0: + case 1: + return getCastHandler<1>(extCode); + case 2: + return getCastHandler<2>(extCode); + case 3: + return getCastHandler<3>(extCode); + case 4: + return getCastHandler<4>(extCode); + case 6: + return getCastHandler<6>(extCode); + case 8: + return getCastHandler<8>(extCode); + case 9: + return getCastHandler<9>(extCode); + case 12: + return getCastHandler<12>(extCode); + case 16: + return getCastHandler<16>(extCode); + } + return nullptr; +} + +void printHandler(IByteCodeRunner* inCtx, VMExecInstHeader* inst, void* userData) +{ + auto ctx = convert(inCtx); + SLANG_UNUSED(userData); + const char* formatString = nullptr; + formatString = *(const char**)inst->getOperand(0).getPtr(); + + List<List<uint8_t>> args; + List<const void*> argPtrs; + for (uint32_t i = 1; i < inst->operandCount; ++i) + { + auto& arg = inst->getOperand(i); + List<uint8_t> data; + data.setCount(arg.size); + memcpy(data.getBuffer(), arg.getPtr(), arg.size); + args.add(data); + } + for (auto& arg : args) + { + argPtrs.add(arg.getBuffer()); + } + auto result = + StringUtil::makeStringWithFormatFromArgArray(formatString, argPtrs.getArrayView()); + ctx->m_printCallback(result.getBuffer(), ctx->m_printCallbackUserData); +} + + +VMExtFunction mapInstToFunction( + VMInstHeader* instHeader, + VMModuleView* module, + Dictionary<String, slang::VMExtFunction>& extInstHandlers) +{ + switch (instHeader->opcode) + { + case VMOp::Nop: + return nopHandler; + case VMOp::Add: + return 
binaryArithmeticInstHandler<AddScalarFunc>(instHeader->opcodeExtension); + case VMOp::Sub: + return binaryArithmeticInstHandler<SubScalarFunc>(instHeader->opcodeExtension); + case VMOp::Mul: + return binaryArithmeticInstHandler<MulScalarFunc>(instHeader->opcodeExtension); + case VMOp::Div: + return binaryArithmeticInstHandler<DivScalarFunc>(instHeader->opcodeExtension); + case VMOp::Rem: + return binaryArithmeticInstHandler<ModScalarFunc>(instHeader->opcodeExtension); + case VMOp::And: + return binaryArithmeticLogicalInstHandler<AndScalarFunc>(instHeader->opcodeExtension); + case VMOp::Or: + return binaryArithmeticLogicalInstHandler<OrScalarFunc>(instHeader->opcodeExtension); + case VMOp::BitAnd: + return binaryArithmeticLogicalInstHandler<BitAndScalarFunc>(instHeader->opcodeExtension); + case VMOp::BitOr: + return binaryArithmeticLogicalInstHandler<BitOrScalarFunc>(instHeader->opcodeExtension); + case VMOp::BitXor: + return binaryArithmeticLogicalInstHandler<BitXorScalarFunc>(instHeader->opcodeExtension); + case VMOp::Shl: + return binaryArithmeticIntInstHandler<ShlScalarFunc>(instHeader->opcodeExtension); + case VMOp::Shr: + return binaryArithmeticIntInstHandler<ShrScalarFunc>(instHeader->opcodeExtension); + case VMOp::Less: + return binaryArithmeticCompareInstHandler<LessScalarFunc>(instHeader->opcodeExtension); + case VMOp::Leq: + return binaryArithmeticCompareInstHandler<LeqScalarFunc>(instHeader->opcodeExtension); + case VMOp::Greater: + return binaryArithmeticCompareInstHandler<GreaterScalarFunc>(instHeader->opcodeExtension); + case VMOp::Geq: + return binaryArithmeticCompareInstHandler<GeqScalarFunc>(instHeader->opcodeExtension); + case VMOp::Equal: + return binaryArithmeticCompareInstHandler<EqualScalarFunc>(instHeader->opcodeExtension); + case VMOp::Neq: + return binaryArithmeticCompareInstHandler<NeqScalarFunc>(instHeader->opcodeExtension); + case VMOp::Neg: + return negInstHandler<NegScalarFunc>(instHeader->opcodeExtension); + case VMOp::Not: + return 
unaryArithmeticLogicalInstHandler<NotScalarFunc>(instHeader->opcodeExtension); + case VMOp::BitNot: + return unaryArithmeticIntInstHandler<BitNotScalarFunc>(instHeader->opcodeExtension); + case VMOp::Ret: + return retHandler; + case VMOp::Call: + return callHandler; + case VMOp::Jump: + return jumpHandler; + case VMOp::JumpIf: + return jumpIfHandler; + case VMOp::Load: + return getLoadHandler(instHeader->opcodeExtension); + case VMOp::Store: + return getStoreHandler(instHeader->opcodeExtension); + case VMOp::Copy: + return getCopyHandler(instHeader->opcodeExtension); + case VMOp::GetWorkingSetPtr: + return getWorkingSetPtrHandler; + case VMOp::GetElementPtr: + return getElementPtrHandler; + case VMOp::OffsetPtr: + return offsetPtrHandler; + case VMOp::GetElement: + return getElementHandler; + case VMOp::Swizzle: + return getSwizzleHandler(instHeader->opcodeExtension); + case VMOp::Cast: + return getCastHandler(instHeader->opcodeExtension); + case VMOp::CallExt: + { + if (instHeader->getOperand(0).offset >= module->stringCount) + return nullptr; + auto funcName = (const char*)module->constants + + module->stringOffsets[instHeader->getOperand(0).offset]; + VMExtFunction handler = nullptr; + if (!extInstHandlers.tryGetValue(funcName, handler)) + return nullptr; + return handler; + } + case VMOp::Print: + return printHandler; + } + return VMExtFunction(); +} + +} // namespace Slang diff --git a/source/slang/slang-vm-inst-impl.h b/source/slang/slang-vm-inst-impl.h new file mode 100644 index 000000000..4331a79ec --- /dev/null +++ b/source/slang/slang-vm-inst-impl.h @@ -0,0 +1,16 @@ +#ifndef SLANG_VM_INST_IMPL_H +#define SLANG_VM_INST_IMPL_H + +#include "slang-vm-bytecode.h" + +namespace Slang +{ + +slang::VMExtFunction mapInstToFunction( + VMInstHeader* instHeader, + VMModuleView* module, + Dictionary<String, slang::VMExtFunction>& extInstHandlers); + +} // namespace Slang + +#endif diff --git a/source/slang/slang-vm.cpp b/source/slang/slang-vm.cpp new file mode 100644 
index 000000000..05d53acb8 --- /dev/null +++ b/source/slang/slang-vm.cpp @@ -0,0 +1,254 @@ +#include "slang-vm.h" + +#include "core/slang-blob.h" +#include "slang-vm-inst-impl.h" + +namespace Slang +{ + +// Our VM insts need to be 8-byte aligned, so we can replace the opcode with function pointers and +// sectionId with data pointers. +static_assert(sizeof(VMOperand) % 8 == 0); +static_assert(sizeof(VMInstHeader) % 8 == 0); +static_assert(sizeof(VMOperand) == sizeof(VMExecOperand)); +static_assert(sizeof(VMInstHeader) == sizeof(VMExecInstHeader)); + +ISlangUnknown* ByteCodeInterpreter::getInterface(const Guid& guid) +{ + if (guid == ISlangUnknown::getTypeGuid() || guid == IByteCodeRunner::getTypeGuid()) + return static_cast<IByteCodeRunner*>(this); + + return nullptr; +} + +SlangResult ByteCodeInterpreter::prepareModuleForExecution() +{ + m_stringLits.clear(); + m_stringLits.setCount(m_moduleView.stringCount); + for (uint32_t i = 0; i < m_moduleView.stringCount; i++) + { + auto strOffset = m_moduleView.stringOffsets[i]; + const char* str = (const char*)m_moduleView.constants + strOffset; + m_stringLits[i] = str; + } + m_stringLitsPtr = m_stringLits.getBuffer(); + + m_functions.setCount(m_moduleView.functionCount); + for (uint32_t i = 0; i < m_moduleView.functionCount; i++) + { + auto func = m_moduleView.getFunction(i); + auto& exeFunc = m_functions[i]; + exeFunc.m_codeBuffer.setCount(func.header->codeSize / sizeof(uint64_t)); + exeFunc.m_header = func.header; + for (uint32_t j = 0; j < func.header->parameterCount; j++) + { + exeFunc.m_parameterOffsets.add(func.header->getParameterOffset(j)); + } + exeFunc.m_parameterOffsets.add(func.header->parameterSizeInBytes); + + // Copy the code into the executable function buffer + memcpy(exeFunc.m_codeBuffer.getBuffer(), func.functionCode, func.header->codeSize); + + // Replace the instruction headers with function pointers + for (auto inst : exeFunc) + { + VMInstHeader* instHeader = reinterpret_cast<VMInstHeader*>(inst); + 
auto handler = mapInstToFunction(instHeader, &m_moduleView, m_extInstHandlers); + if (!handler) + { + StringBuilder instStr; + printVMInst(instStr, &m_moduleView, instHeader); + reportError( + "Cannot find execution handler for instruction %s\n", + instStr.toString().getBuffer()); + return SLANG_FAIL; + } + inst->functionPtr = handler; + for (uint32_t operandIdx = 0; operandIdx < instHeader->operandCount; operandIdx++) + { + auto& operand = instHeader->getOperand(operandIdx); + auto& execOpernad = inst->getOperand(operandIdx); + switch (operand.sectionId) + { + case kSlangByteCodeSectionConstants: + execOpernad.section = &m_moduleView.constants; + break; + case kSlangByteCodeSectionInsts: + execOpernad.section = (uint8_t**)&m_currentFuncCode; + break; + case kSlangByteCodeSectionWorkingSet: + execOpernad.section = (uint8_t**)&m_currentWorkingSet; + break; + case kSlangByteCodeSectionStrings: + execOpernad.section = (uint8_t**)&m_stringLitsPtr; + execOpernad.offset *= sizeof(const char*); + break; + } + } + } + } + + return SLANG_OK; +} + +SLANG_NO_THROW SlangResult SLANG_MCALL ByteCodeInterpreter::loadModule(IBlob* moduleBlob) +{ + m_stack.reserve(128); + m_workingSetBuffer.reserve(1024 * 1024); // Reserve 1MB for working set + m_currentWorkingSet = m_workingSetBuffer.getBuffer(); + + m_errorBuilder.clear(); + m_code.addRange((uint8_t*)(moduleBlob->getBufferPointer()), moduleBlob->getBufferSize()); + SLANG_RETURN_ON_FAIL( + initVMModule(m_code.getBuffer(), (uint32_t)moduleBlob->getBufferSize(), &m_moduleView)); + SLANG_RETURN_ON_FAIL(prepareModuleForExecution()); + return SLANG_OK; +} + +SLANG_NO_THROW void SLANG_MCALL ByteCodeInterpreter::getErrorString(slang::IBlob** outBlob) +{ + *outBlob = StringBlob::moveCreate(m_errorBuilder.produceString()).detach(); + m_errorBuilder.clear(); +} + +SLANG_NO_THROW int SLANG_MCALL ByteCodeInterpreter::findFunctionByName(const char* name) +{ + for (uint32_t i = 0; i < m_moduleView.functionCount; i++) + { + auto func = 
m_moduleView.getFunction(i); + if (UnownedStringSlice(func.name) == name) + { + return (int)i; + } + } + return -1; // Function not found +} + +SLANG_NO_THROW SlangResult SLANG_MCALL +ByteCodeInterpreter::getFunctionInfo(uint32_t index, slang::ByteCodeFuncInfo* outInfo) +{ + if (index >= m_moduleView.functionCount) + { + return SLANG_FAIL; + } + auto func = m_moduleView.getFunction(index); + outInfo->parameterCount = func.header->parameterCount; + outInfo->returnValueSize = func.header->returnValueSizeInBytes; + return SLANG_OK; +} + +SLANG_NO_THROW SlangResult SLANG_MCALL +ByteCodeInterpreter::selectFunctionByIndex(uint32_t functionIndex) +{ + if (functionIndex >= m_moduleView.functionCount) + { + reportError( + "Function index %u out of range [0, %u)", + functionIndex, + m_moduleView.functionCount); + return SLANG_FAIL; + } + auto func = m_moduleView.getFunction(functionIndex); + m_currentFuncCode = m_functions[functionIndex].m_codeBuffer.getBuffer(); + m_currentInst = reinterpret_cast<VMExecInstHeader*>(m_currentFuncCode); + m_workingSetBuffer.setCount(func.header->workingSetSizeInBytes / sizeof(uint64_t)); + m_currentWorkingSet = m_workingSetBuffer.getBuffer(); + return SLANG_OK; +} + +SLANG_NO_THROW SlangResult SLANG_MCALL +ByteCodeInterpreter::execute(void* argumentData, size_t argumentSize) +{ + if (!m_currentInst) + { + reportError("No function selected for execution"); + return SLANG_FAIL; + } + if (!m_currentWorkingSet) + { + reportError("No working set allocated for execution"); + return SLANG_FAIL; + } + if ((uint8_t*)m_currentWorkingSet + argumentSize > + (uint8_t*)(m_workingSetBuffer.getBuffer() + m_workingSetBuffer.getCount())) + { + reportError("Argument size exceeds working set."); + return SLANG_FAIL; + } + // Copy the arguments into the working set + if (argumentData && argumentSize > 0) + { + memcpy(m_currentWorkingSet, argumentData, argumentSize); + } + m_returnValSize = 0; + while (m_currentInst) + { + auto nextInst = 
m_currentInst->getNextInst(); + auto currentInst = m_currentInst; + m_currentInst = nextInst; + currentInst->functionPtr(this, currentInst, m_extInstHandlerUserData); + } + return SLANG_OK; +} + +ByteCodeInterpreter::ByteCodeInterpreter() +{ + m_printCallback = defaultPrintCallback; + m_printCallbackUserData = this; +} + +SLANG_NO_THROW SlangResult SLANG_MCALL +ByteCodeInterpreter::setPrintCallback(slang::VMPrintFunc callback, void* userData) +{ + m_printCallback = callback; + m_printCallbackUserData = userData; + return SLANG_OK; +} + +void ByteCodeInterpreter::defaultPrintCallback(const char* str, void* userData) +{ + SLANG_UNUSED(userData); + printf("%s", str); +} + +ExecutableFunction::InstIterator ExecutableFunction::begin() +{ + ExecutableFunction::InstIterator iter; + iter.codePtr = (uint8_t*)m_codeBuffer.getBuffer(); + return iter; +} + +ExecutableFunction::InstIterator ExecutableFunction::end() +{ + ExecutableFunction::InstIterator iter; + iter.codePtr = (uint8_t*)(m_codeBuffer.getBuffer() + m_codeBuffer.getCount()); + return iter; +} + + +} // namespace Slang + + +SLANG_EXTERN_C SLANG_API SlangResult slang_createByteCodeRunner( + const slang::ByteCodeRunnerDesc* desc, + slang::IByteCodeRunner** outByteCodeRunner) +{ + SLANG_UNUSED(desc); + Slang::RefPtr<Slang::ByteCodeInterpreter> runner = new Slang::ByteCodeInterpreter(); + *outByteCodeRunner = static_cast<slang::IByteCodeRunner*>(runner.detach()); + return SLANG_OK; +} + +SLANG_EXTERN_C SLANG_API SlangResult +slang_disassembleByteCode(slang::IBlob* moduleBlob, slang::IBlob** outDisassemblyBlob) +{ + Slang::VMModuleView moduleView; + SLANG_RETURN_ON_FAIL(Slang::initVMModule( + (uint8_t*)moduleBlob->getBufferPointer(), + (uint32_t)moduleBlob->getBufferSize(), + &moduleView)); + Slang::StringBuilder sb; + sb << moduleView; + *outDisassemblyBlob = Slang::StringBlob::moveCreate(sb.produceString()).detach(); + return SLANG_OK; +} diff --git a/source/slang/slang-vm.h b/source/slang/slang-vm.h new file mode 
100644 index 000000000..7f85e398f --- /dev/null +++ b/source/slang/slang-vm.h @@ -0,0 +1,140 @@ +#ifndef SLANG_VM_H +#define SLANG_VM_H + +#include "core/slang-string-util.h" +#include "slang-vm-bytecode.h" + +using namespace slang; + +namespace Slang +{ + +struct ByteCodeExecutionContext +{ + void* currentWorkingSet; + uint32_t currentWorkingSetSizeInBytes; +}; + +class ByteCodeInterpreter; + +// Represents a relocated function code ready for execution. +// Relocated functions are VMInsts allocated in a 8-byte aligned buffer, and instruction headers +// Replaced with actual function pointers that can execute the instruction. +class ExecutableFunction +{ +public: + typedef VMInstIterator<VMExecOperand, VMExecInstHeader> InstIterator; + List<uint64_t> m_codeBuffer; + VMFuncHeader* m_header; + List<uint32_t> m_parameterOffsets; + + InstIterator begin(); + InstIterator end(); +}; + +struct StackFrame +{ + VMExecInstHeader* m_currentInst = nullptr; + void* m_currentFuncCode = nullptr; + size_t m_workingSetOffset = 0; +}; + +class ByteCodeInterpreter : public RefObject, public IByteCodeRunner +{ +public: + SLANG_REF_OBJECT_IUNKNOWN_ALL + ISlangUnknown* getInterface(const Guid& guid); + +public: + VMModuleView m_moduleView; + List<uint8_t> m_code; + StringBuilder m_errorBuilder; + List<ExecutableFunction> m_functions; + Dictionary<String, VMExtFunction> m_extInstHandlers; + SlangResult prepareModuleForExecution(); + void* m_extInstHandlerUserData = nullptr; + List<uint8_t> m_returnRegister; + List<uint64_t> m_workingSetBuffer; + List<StackFrame> m_stack; + List<const char*> m_stringLits; + const char** m_stringLitsPtr = nullptr; + + size_t m_returnValSize = 0; + + void pushFrame(uint32_t size) + { + StackFrame frame; + frame.m_workingSetOffset = + (uint32_t)((uint64_t*)m_currentWorkingSet - m_workingSetBuffer.getBuffer()); + m_stack.add(frame); + auto stackBufferCount = m_workingSetBuffer.getCount(); + m_workingSetBuffer.setCount(m_workingSetBuffer.getCount() + size / 
sizeof(uint64_t)); + m_currentWorkingSet = m_workingSetBuffer.getBuffer() + stackBufferCount; + } + void popFrame() + { + auto& stackFrame = m_stack.getLast(); + auto lastWorkingSetBufferCount = + (uint32_t)((uint64_t*)m_currentWorkingSet - m_workingSetBuffer.getBuffer()); + m_workingSetBuffer.setCount(lastWorkingSetBufferCount); + m_currentInst = stackFrame.m_currentInst->getNextInst(); + m_currentFuncCode = stackFrame.m_currentFuncCode; + m_currentWorkingSet = m_workingSetBuffer.getBuffer() + stackFrame.m_workingSetOffset; + m_stack.removeLast(); + } + + VMExecInstHeader* m_currentInst = nullptr; + void* m_currentFuncCode = nullptr; + void* m_currentWorkingSet = nullptr; + + VMPrintFunc m_printCallback = nullptr; + void* m_printCallbackUserData = nullptr; + + template<typename... Args> + void reportError(const char* format, Args... args) + { + m_errorBuilder.append(StringUtil::makeStringWithFormat(format, args...)); + m_errorBuilder.append("\n"); + } + + static void defaultPrintCallback(const char* message, void* userData); + ByteCodeInterpreter(); + +public: + virtual SLANG_NO_THROW SlangResult SLANG_MCALL loadModule(IBlob* moduleBlob) override; + virtual SLANG_NO_THROW SlangResult SLANG_MCALL + selectFunctionByIndex(uint32_t functionIndex) override; + virtual SLANG_NO_THROW int SLANG_MCALL findFunctionByName(const char* name) override; + + virtual SLANG_NO_THROW SlangResult SLANG_MCALL + getFunctionInfo(uint32_t index, ByteCodeFuncInfo* outInfo) override; + virtual SLANG_NO_THROW void* SLANG_MCALL getCurrentWorkingSet() override + { + return m_currentWorkingSet; + } + virtual SLANG_NO_THROW SlangResult SLANG_MCALL + execute(void* argumentData, size_t argumentSize) override; + virtual SLANG_NO_THROW void SLANG_MCALL getErrorString(slang::IBlob** outBlob) override; + virtual SLANG_NO_THROW void* SLANG_MCALL getReturnValue(size_t* outValueSize) override + { + *outValueSize = m_returnValSize; + return m_returnRegister.getBuffer(); + } + virtual SLANG_NO_THROW void 
SLANG_MCALL setExtInstHandlerUserData(void* userData) override + { + m_extInstHandlerUserData = userData; + } + virtual SLANG_NO_THROW SlangResult SLANG_MCALL + registerExtCall(const char* name, VMExtFunction functionPtr) override + { + m_extInstHandlers[name] = functionPtr; + return SLANG_OK; + } + + virtual SLANG_NO_THROW SlangResult SLANG_MCALL + setPrintCallback(VMPrintFunc callback, void* userData) override; +}; + +} // namespace Slang + +#endif diff --git a/source/slang/slang.cpp b/source/slang/slang.cpp index df07a7637..613fb7090 100644 --- a/source/slang/slang.cpp +++ b/source/slang/slang.cpp @@ -5099,6 +5099,8 @@ RefPtr<EntryPoint> Module::findAndCheckEntryPoint( if (auto existingEntryPoint = findEntryPointByName(name)) return existingEntryPoint; + SLANG_AST_BUILDER_RAII(m_astBuilder); + // If the function hasn't been marked as [shader], then it won't be discovered // by findEntryPointByName. We need to route this to the `findAndValidateEntryPoint` // function. To do that we need to setup a FrontEndCompileRequest and a diff --git a/tests/byte-code/composite.slang b/tests/byte-code/composite.slang new file mode 100644 index 000000000..41729b6c2 --- /dev/null +++ b/tests/byte-code/composite.slang @@ -0,0 +1,34 @@ +//TEST:INTERPRET(filecheck=CHECK): +struct Inner +{ + int header; + int array[3]; +} +struct MyType +{ + float3 value1; + float3 value2; + Inner inner; +} +MyType getValue() +{ + MyType t = { float3(1, 2, 3), float3(-1,-2, 3), { 4, { 5, 6, 7 } } }; + return t; +} +int main() +{ + var t = getValue(); + int sum = 0; + for (int i = 0; i < 3; ++i) + { + sum += t.inner.array[i]; + } + t.value1 += t.value2; + for (int i = 0; i < 3; i++) + { + sum += (int)t.value1[i]; + } + //CHECK: 24 + printf("%d\n", sum); + return 0; +} diff --git a/tests/byte-code/hello.slang b/tests/byte-code/hello.slang new file mode 100644 index 000000000..f3c13cbab --- /dev/null +++ b/tests/byte-code/hello.slang @@ -0,0 +1,17 @@ +//TEST:INTERPRET(filecheck=BC): -disasm 
+//TEST:INTERPRET(filecheck=CHECK): + +int main(int argc, NativeString* argv) +{ + printf("hello world\n"); + for (int i = 0; i < argc; ++i) + { + printf("%s\n", argv[i]); + } + return 100; +} + +// CHECK: hello world +// CHECK: {{.*}}hello.slang +// BC: func main +// BC: ret
\ No newline at end of file diff --git a/tools/CMakeLists.txt b/tools/CMakeLists.txt index 18a50d357..53698726b 100644 --- a/tools/CMakeLists.txt +++ b/tools/CMakeLists.txt @@ -96,6 +96,19 @@ if(SLANG_ENABLE_SLANGD) ) endif() +# +# Slang Interpreter +# +if(SLANG_ENABLE_SLANGI) + slang_add_target( + slangi + EXECUTABLE + LINK_WITH_PRIVATE core compiler-core slang + INSTALL + EXPORT_SET_NAME SlangTargets + ) +endif() + if(SLANG_ENABLE_GFX) # # `platform` contains all the platform abstractions for a GUI application. diff --git a/tools/slang-test/slang-test-main.cpp b/tools/slang-test/slang-test-main.cpp index 3f7e41cf6..c0697b4a4 100644 --- a/tools/slang-test/slang-test-main.cpp +++ b/tools/slang-test/slang-test-main.cpp @@ -28,6 +28,7 @@ #include "options.h" #include "parse-diagnostic-util.h" #include "slangc-tool.h" +#include "slangi-tool.h" #include "test-context.h" #include "test-reporter.h" @@ -860,7 +861,7 @@ Result spawnAndWaitSharedLibrary( stdWriters.setWriter(SLANG_WRITER_CHANNEL_STD_ERROR, &stdError); stdWriters.setWriter(SLANG_WRITER_CHANNEL_STD_OUTPUT, &stdOut); - if (exeName == "slangc") + if (exeName == "slangc" || exeName == "slangi") { stdWriters.setWriter(SLANG_WRITER_CHANNEL_DIAGNOSTIC, &stdError); } @@ -902,7 +903,7 @@ Result spawnAndWaitProxy( // Get the name of the thing to execute String exeName = Path::getFileNameWithoutExt(inCmdLine.m_executableLocation.m_pathOrName); - if (exeName == "slangc") + if (exeName == "slangc" || exeName == "slangi") { // If the test is slangc there is a command line version we can just directly use // return spawnAndWaitExe(context, testPath, inCmdLine, outRes); @@ -1065,6 +1066,7 @@ static PassThroughFlags _getPassThroughFlagsForTarget(SlangCompileTarget target) case SLANG_CUDA_SOURCE: case SLANG_METAL: case SLANG_WGSL: + case SLANG_HOST_VM: { return 0; } @@ -1303,6 +1305,10 @@ static SlangResult _extractTestRequirements(const CommandLine& cmdLine, TestRequ { return _extractSlangCTestRequirements(cmdLine, ioInfo); 
} + else if (exeName == "slangi") + { + return SLANG_OK; + } else if (exeName == "slang-reflection-test") { return _extractReflectionTestRequirements(cmdLine, ioInfo); @@ -1562,6 +1568,12 @@ String findExpectedPath(const TestInput& input, const char* postFix) return ""; } +static SlangResult _initSlangInterpreter(TestContext* context, CommandLine& ioCmdLine) +{ + ioCmdLine.setExecutableLocation(ExecutableLocation(context->options.binDir, "slangi")); + return SLANG_OK; +} + static SlangResult _initSlangCompiler(TestContext* context, CommandLine& ioCmdLine) { ioCmdLine.setExecutableLocation(ExecutableLocation(context->options.binDir, "slangc")); @@ -2373,6 +2385,67 @@ TestResult runSimpleLineTest(TestContext* context, TestInput& input) return _validateOutput(context, input, actualOutput, false); } +TestResult runInterpreterTest(TestContext* context, TestInput& input) +{ + // need to execute the stand-alone Slang compiler on the file, and compare its output to what we + // expect + auto outputStem = input.outputStem; + + CommandLine cmdLine; + + List<String> args; + + for (Index i = 0; i < input.testOptions->args.getCount(); i++) + { + auto& arg = input.testOptions->args[i]; + if (arg == "-disasm") + cmdLine.addArg(arg); + else if (arg == "-entry") + { + cmdLine.addArg(arg); + i++; + if (i < input.testOptions->args.getCount()) + { + cmdLine.addArg(input.testOptions->args[i]); + } + } + else + { + args.add(arg); + } + } + + cmdLine.addArg(input.filePath); + + for (auto arg : args) + { + cmdLine.addArg(arg); + } + + if (SLANG_FAILED(_initSlangInterpreter(context, cmdLine))) + { + return TestResult::Ignored; + } + + ExecuteResult exeRes; + TEST_RETURN_ON_DONE(spawnAndWait(context, outputStem, input.spawnType, cmdLine, exeRes)); + + if (context->isCollectingRequirements()) + { + return TestResult::Pass; + } + + String actualOutput = getOutput(exeRes); + + return _validateOutput( + context, + input, + actualOutput, + false, + "result code = 0\nstandard error = 
{\n}\nstandard output = {\n}\n", + [&input](auto e, auto a) { return _areResultsEqual(input.testOptions->type, e, a); }); +} + TestResult runCompile(TestContext* context, TestInput& input) { auto outputStem = input.outputStem; @@ -3952,6 +4025,7 @@ static const TestCommandInfo s_testCommandInfos[] = { {"SIMPLE", &runSimpleTest, 0}, {"SIMPLE_EX", &runSimpleTest, 0}, {"SIMPLE_LINE", &runSimpleLineTest, 0}, + {"INTERPRET", &runInterpreterTest, 0}, {"REFLECTION", &runReflectionTest, 0}, {"CPU_REFLECTION", &runReflectionTest, 0}, {"COMMAND_LINE_SIMPLE", &runSimpleCompareCommandLineTest, 0}, @@ -4851,6 +4925,11 @@ SlangResult innerMain(int argc, char** argv) context.setInnerMainFunc("slangc", &SlangCTool::innerMain); } + { + // We can set the slangc command line tool, to just use the function defined here + context.setInnerMainFunc("slangi", &SlangITool::innerMain); + } + SLANG_RETURN_ON_FAIL( Options::parse(argc, argv, &categorySet, StdWriters::getError(), &context.options)); diff --git a/tools/slang-test/slangi-tool-impl.h b/tools/slang-test/slangi-tool-impl.h new file mode 100644 index 000000000..b442800f2 --- /dev/null +++ b/tools/slang-test/slangi-tool-impl.h @@ -0,0 +1,234 @@ +namespace SlangITool +{ +static void printCallback(const char* message, void* userData) +{ + auto stdWriters = (StdWriters*)userData; + if (stdWriters) + { + stdWriters->getOut().print("%s", message); + } +} + +static SlangResult compileAndInterpret( + slang::IGlobalSession* sharedSession, + StdWriters* stdWriters, + UnownedStringSlice fileName, + const char* entryPointName, + bool disasm, + int argc, + const char* const* argv) +{ + auto maybePrintDiagnostic = [&](const ComPtr<slang::IBlob>& diagnosticBlob) + { + if (diagnosticBlob) + { + const char* diagText = (const char*)diagnosticBlob->getBufferPointer(); + stdWriters->getError().print("%s\n", diagText); + } + }; + + ComPtr<slang::IGlobalSession> globalSession; + SLANG_RETURN_ON_FAIL(slang_createGlobalSession(SLANG_API_VERSION, 
globalSession.writeRef())); + slang::TargetDesc targetDesc = {}; + targetDesc.format = SLANG_HOST_VM; + slang::SessionDesc sessionDesc = {}; + sessionDesc.targetCount = 1; + sessionDesc.targets = &targetDesc; + sessionDesc.compilerOptionEntryCount = 0; + String pathName = Path::getParentDirectory(fileName); + String moduleName = Path::getFileNameWithoutExt(fileName); + const char* searchPaths[] = {pathName.getBuffer()}; + if (pathName.getLength()) + { + sessionDesc.searchPathCount = 1; + sessionDesc.searchPaths = searchPaths; + } + ComPtr<slang::ISession> session; + SLANG_RETURN_ON_FAIL(globalSession->createSession(sessionDesc, session.writeRef())); + + ComPtr<slang::IBlob> diagnosticBlob; + auto module = session->loadModule(moduleName.getBuffer(), diagnosticBlob.writeRef()); + if (!module) + { + maybePrintDiagnostic(diagnosticBlob); + return SLANG_FAIL; + } + ComPtr<slang::IEntryPoint> entryPoint; + if (SLANG_FAILED(module->findAndCheckEntryPoint( + entryPointName, + SLANG_STAGE_DISPATCH, + entryPoint.writeRef(), + diagnosticBlob.writeRef()))) + { + maybePrintDiagnostic(diagnosticBlob); + return SLANG_FAIL; + } + + ComPtr<slang::IComponentType> compositeComponent; + slang::IComponentType* components[] = {module, entryPoint.get()}; + if (SLANG_FAILED(session->createCompositeComponentType( + components, + 2, + compositeComponent.writeRef(), + diagnosticBlob.writeRef()))) + { + maybePrintDiagnostic(diagnosticBlob); + return SLANG_FAIL; + } + + ComPtr<slang::IComponentType> linkedProgram; + if (SLANG_FAILED(compositeComponent->link(linkedProgram.writeRef(), diagnosticBlob.writeRef()))) + { + maybePrintDiagnostic(diagnosticBlob); + return SLANG_FAIL; + } + ComPtr<slang::IBlob> code; + + if (SLANG_FAILED(linkedProgram->getTargetCode(0, code.writeRef(), diagnosticBlob.writeRef()))) + { + maybePrintDiagnostic(diagnosticBlob); + return SLANG_FAIL; + } + + if (code->getBufferSize() == 0) + { + return SLANG_FAIL; + } + + if (disasm) + { + ComPtr<slang::IBlob> disasmBlob; + 
if (SLANG_FAILED(slang_disassembleByteCode(code, disasmBlob.writeRef()))) + { + maybePrintDiagnostic(diagnosticBlob); + return SLANG_FAIL; + } + const char* disasmText = (const char*)disasmBlob->getBufferPointer(); + stdWriters->getOut().print("%s\n", disasmText); + return SLANG_OK; + } + + // Create a byte code runner and interpret the code. + ComPtr<slang::IByteCodeRunner> runner; + slang::ByteCodeRunnerDesc runnerDesc = {}; + SLANG_RETURN_ON_FAIL(slang_createByteCodeRunner(&runnerDesc, runner.writeRef())); + runner->setPrintCallback(printCallback, stdWriters); + + if (SLANG_FAILED(runner->loadModule(code))) + { + runner->getErrorString(diagnosticBlob.writeRef()); + maybePrintDiagnostic(diagnosticBlob); + } + auto funcIndex = runner->findFunctionByName(entryPointName); + if (funcIndex < 0) + { + stdWriters->getError().print("Function '%s' not found in byte code.\n", entryPointName); + return SLANG_FAIL; + } + + if (SLANG_FAILED(runner->selectFunctionByIndex((uint32_t)funcIndex))) + { + runner->getErrorString(diagnosticBlob.writeRef()); + maybePrintDiagnostic(diagnosticBlob); + return SLANG_FAIL; + } + + struct Arguments + { + uint32_t argc; + const char* const* argv; + }; + Arguments args; + args.argc = argc; + args.argv = argv; + void* arguments = nullptr; + size_t argSize = 0; + slang::ByteCodeFuncInfo funcInfo; + if (SLANG_FAILED(runner->getFunctionInfo((uint32_t)funcIndex, &funcInfo))) + { + runner->getErrorString(diagnosticBlob.writeRef()); + maybePrintDiagnostic(diagnosticBlob); + return SLANG_FAIL; + } + if (funcInfo.parameterCount == 2) + { + arguments = &args; + argSize = sizeof(Arguments); + } + if (SLANG_FAILED(runner->execute(arguments, argSize))) + { + runner->getErrorString(diagnosticBlob.writeRef()); + maybePrintDiagnostic(diagnosticBlob); + return SLANG_FAIL; + } + return SLANG_OK; +} + +SlangResult innerMain( + StdWriters* stdWriters, + slang::IGlobalSession* sharedSession, + int argc, + const char* const* argv) +{ + 
StdWriters::setSingleton(stdWriters); + + // Assume we will used the shared session + ComPtr<slang::IGlobalSession> session(sharedSession); + + // The sharedSession always has a pre-loaded core module. + // This differed test checks if the command line has an option to setup the core module. + // If so we *don't* use the sharedSession, and create a new session without the core module just + // for this compilation. + if (TestToolUtil::hasDeferredCoreModule(Index(argc - 1), argv + 1)) + { + SLANG_RETURN_ON_FAIL( + slang_createGlobalSessionWithoutCoreModule(SLANG_API_VERSION, session.writeRef())); + } + + String entryPointName = toSlice("main"); + UnownedStringSlice fileName; + bool disasm = false; + int innerArgIndex = 0; + if (argc < 2) + { + return SLANG_FAIL; + } + for (auto i = 1; i < argc; i++) + { + auto arg = UnownedStringSlice(argv[i]); + if (arg == "-entry") + { + entryPointName = UnownedStringSlice(argv[++i]); + } + else if (arg == "-disasm") + { + disasm = true; + } + else if (arg.startsWith("-")) + { + return SLANG_FAIL; + } + else + { + fileName = arg; + innerArgIndex = i; + break; + } + } + if (!fileName.getLength()) + { + return SLANG_FAIL; + } + + auto result = compileAndInterpret( + session, + stdWriters, + fileName, + entryPointName.getBuffer(), + disasm, + argc - innerArgIndex, + argv + innerArgIndex); + + return result; +} +} // namespace SlangITool diff --git a/tools/slang-test/slangi-tool.cpp b/tools/slang-test/slangi-tool.cpp new file mode 100644 index 000000000..ef9d311f7 --- /dev/null +++ b/tools/slang-test/slangi-tool.cpp @@ -0,0 +1,10 @@ +// test-context.cpp +#include "slangi-tool.h" + +#include "../../source/core/slang-exception.h" +#include "../../source/core/slang-io.h" +#include "../../source/core/slang-test-tool-util.h" + +using namespace Slang; + +#include "slangi-tool-impl.h" diff --git a/tools/slang-test/slangi-tool.h b/tools/slang-test/slangi-tool.h new file mode 100644 index 000000000..1cba47018 --- /dev/null +++ 
b/tools/slang-test/slangi-tool.h @@ -0,0 +1,19 @@ +// slangi-tool.h + +#ifndef SLANGI_TOOL_H_INCLUDED +#define SLANGI_TOOL_H_INCLUDED + +#include "../../source/core/slang-std-writers.h" + +/* The slangi 'tool' interface, such that slangc like functionality is available directly without +invoking slangc command line tool, or need for a dll/shared library. */ +namespace SlangITool +{ +SlangResult innerMain( + Slang::StdWriters* stdWriters, + SlangSession* session, + int argc, + const char* const* argv); +}; + +#endif // SLANGI_TOOL_H_INCLUDED diff --git a/tools/slang-unit-test/unit-test-slang-vm.cpp b/tools/slang-unit-test/unit-test-slang-vm.cpp new file mode 100644 index 000000000..0f5e0f9f3 --- /dev/null +++ b/tools/slang-unit-test/unit-test-slang-vm.cpp @@ -0,0 +1,102 @@ +// unit-test-slang-vm.cpp + +#include "core/slang-memory-file-system.h" +#include "slang-com-ptr.h" +#include "slang.h" +#include "unit-test/slang-unit-test.h" + +#include <stdio.h> +#include <stdlib.h> + +using namespace Slang; + +SLANG_UNIT_TEST(slangVM) +{ + const char* testSource = R"( + int one() { return 1; } + int sum(int x) + { + int result = 0; + for (int i = 0; i <= x; i++) + { + result += i; + } + return result + one(); + } + [shader("dispatch")] + int dispatchMain(uniform int2 v, out int c) + { + int a = v.x; + int b = v.y; + int tmp = 0; + if (a > 0) + tmp = a + b; + else + tmp = b - a; + tmp += sum(b); + c = tmp; + return 100; + } + )"; + + // Create Slang session and compile code. 
+ ComPtr<slang::IBlob> code; + String disasmText; + { + ComPtr<slang::IGlobalSession> globalSession; + SLANG_CHECK( + slang_createGlobalSession(SLANG_API_VERSION, globalSession.writeRef()) == SLANG_OK); + slang::TargetDesc targetDesc = {}; + targetDesc.format = SLANG_HOST_VM; + slang::SessionDesc sessionDesc = {}; + sessionDesc.targetCount = 1; + sessionDesc.targets = &targetDesc; + sessionDesc.compilerOptionEntryCount = 0; + + ComPtr<slang::ISession> session; + SLANG_CHECK(globalSession->createSession(sessionDesc, session.writeRef()) == SLANG_OK); + + ComPtr<slang::IBlob> diagnosticBlob; + auto module = session->loadModuleFromSourceString( + "test", + "test.slang", + testSource, + diagnosticBlob.writeRef()); + SLANG_CHECK(module != nullptr); + + ComPtr<slang::IComponentType> linkedProgram; + module->link(linkedProgram.writeRef()); + + + linkedProgram->getTargetCode(0, code.writeRef(), diagnosticBlob.writeRef()); + + SLANG_CHECK(code->getBufferSize() > 0); + + ComPtr<slang::IBlob> disasmBlob; + SLANG_CHECK(slang_disassembleByteCode(code, disasmBlob.writeRef()) == SLANG_OK); + disasmText = (const char*)disasmBlob->getBufferPointer(); + SLANG_CHECK(disasmText.indexOf("ret") != -1); + } + + // Create a byte code runner and interpret the code. 
+ ComPtr<slang::IByteCodeRunner> runner; + slang::ByteCodeRunnerDesc runnerDesc = {}; + SLANG_CHECK(slang_createByteCodeRunner(&runnerDesc, runner.writeRef()) == SLANG_OK); + SLANG_CHECK(runner->loadModule(code) == SLANG_OK); + SLANG_CHECK(runner->selectFunctionByIndex(0) == SLANG_OK); + struct Params + { + int a; + int b; + int* result; + }; + int result = 0; + Params params = {1, 2, &result}; + SLANG_CHECK(runner->execute(¶ms, sizeof(params)) == SLANG_OK); + SLANG_CHECK(result == 7); + + size_t returnValSize = 0; + int* returnVal = (int*)runner->getReturnValue(&returnValSize); + SLANG_CHECK(returnValSize == sizeof(int)); + SLANG_CHECK(*returnVal == 100); +} diff --git a/tools/slangi/main.cpp b/tools/slangi/main.cpp new file mode 100644 index 000000000..2ef680324 --- /dev/null +++ b/tools/slangi/main.cpp @@ -0,0 +1,232 @@ +// main.cpp + +// This file implements the entry point for `slangi`, an interpreter for the Slang language. + +#include "../../source/core/slang-basic.h" +#include "core/slang-io.h" +#include "slang-com-ptr.h" +#include "slang.h" + +using namespace Slang; +using namespace slang; + +void printUsage() +{ + printf("Slang Interpreter (Experimental)\n"); + printf("Compile and interpret Slang code.\n"); + printf("Usage: slangi [options] <filename>\n"); + printf("Options:\n"); + printf(" -entry <name> Specify the entry point function name to run. 
(default: main)\n"); + printf(" -disasm Disassemble the bytecode after compilation.\n"); + printf(" -help Show this help message\n"); +} + +void maybePrintDiagnostic(const ComPtr<slang::IBlob>& diagnosticBlob) +{ + if (diagnosticBlob) + { + const char* diagText = (const char*)diagnosticBlob->getBufferPointer(); + fprintf(stderr, "%s\n", diagText); + } +} + +SlangResult compileAndInterpret( + UnownedStringSlice fileName, + const char* entryPointName, + bool disasm, + int argc, + const char* const* argv) +{ + ComPtr<slang::IGlobalSession> globalSession; + SLANG_RETURN_ON_FAIL(slang_createGlobalSession(SLANG_API_VERSION, globalSession.writeRef())); + slang::TargetDesc targetDesc = {}; + targetDesc.format = SLANG_HOST_VM; + slang::SessionDesc sessionDesc = {}; + sessionDesc.targetCount = 1; + sessionDesc.targets = &targetDesc; + sessionDesc.compilerOptionEntryCount = 0; + String pathName = Path::getParentDirectory(fileName); + String moduleName = Path::getFileNameWithoutExt(fileName); + const char* searchPaths[] = {pathName.getBuffer()}; + if (pathName.getLength()) + { + sessionDesc.searchPathCount = 1; + sessionDesc.searchPaths = searchPaths; + } + ComPtr<slang::ISession> session; + SLANG_RETURN_ON_FAIL(globalSession->createSession(sessionDesc, session.writeRef())); + + ComPtr<slang::IBlob> diagnosticBlob; + auto module = session->loadModule(moduleName.getBuffer(), diagnosticBlob.writeRef()); + if (!module) + { + maybePrintDiagnostic(diagnosticBlob); + return SLANG_FAIL; + } + ComPtr<slang::IEntryPoint> entryPoint; + if (SLANG_FAILED(module->findAndCheckEntryPoint( + entryPointName, + SLANG_STAGE_DISPATCH, + entryPoint.writeRef(), + diagnosticBlob.writeRef()))) + { + maybePrintDiagnostic(diagnosticBlob); + return SLANG_FAIL; + } + + ComPtr<slang::IComponentType> compositeComponent; + slang::IComponentType* components[] = {module, entryPoint.get()}; + if (SLANG_FAILED(session->createCompositeComponentType( + components, + 2, + compositeComponent.writeRef(), + 
diagnosticBlob.writeRef()))) + { + maybePrintDiagnostic(diagnosticBlob); + return SLANG_FAIL; + } + + ComPtr<slang::IComponentType> linkedProgram; + if (SLANG_FAILED(compositeComponent->link(linkedProgram.writeRef(), diagnosticBlob.writeRef()))) + { + maybePrintDiagnostic(diagnosticBlob); + return SLANG_FAIL; + } + ComPtr<slang::IBlob> code; + + if (SLANG_FAILED(linkedProgram->getTargetCode(0, code.writeRef(), diagnosticBlob.writeRef()))) + { + maybePrintDiagnostic(diagnosticBlob); + return SLANG_FAIL; + } + + if (code->getBufferSize() == 0) + { + return SLANG_FAIL; + } + + if (disasm) + { + ComPtr<slang::IBlob> disasmBlob; + if (SLANG_FAILED(slang_disassembleByteCode(code, disasmBlob.writeRef()))) + { + maybePrintDiagnostic(diagnosticBlob); + return SLANG_FAIL; + } + const char* disasmText = (const char*)disasmBlob->getBufferPointer(); + printf("%s\n", disasmText); + } + + // Create a byte code runner and interpret the code. + ComPtr<slang::IByteCodeRunner> runner; + slang::ByteCodeRunnerDesc runnerDesc = {}; + SLANG_RETURN_ON_FAIL(slang_createByteCodeRunner(&runnerDesc, runner.writeRef())); + if (SLANG_FAILED(runner->loadModule(code))) + { + runner->getErrorString(diagnosticBlob.writeRef()); + maybePrintDiagnostic(diagnosticBlob); + } + auto funcIndex = runner->findFunctionByName(entryPointName); + if (funcIndex < 0) + { + printf("Function '%s' not found in byte code.\n", entryPointName); + return SLANG_FAIL; + } + + if (SLANG_FAILED(runner->selectFunctionByIndex((uint32_t)funcIndex))) + { + runner->getErrorString(diagnosticBlob.writeRef()); + maybePrintDiagnostic(diagnosticBlob); + return SLANG_FAIL; + } + + struct Arguments + { + uint32_t argc; + const char* const* argv; + }; + Arguments args; + args.argc = argc; + args.argv = argv; + void* arguments = nullptr; + size_t argSize = 0; + slang::ByteCodeFuncInfo funcInfo; + if (SLANG_FAILED(runner->getFunctionInfo((uint32_t)funcIndex, &funcInfo))) + { + runner->getErrorString(diagnosticBlob.writeRef()); + 
maybePrintDiagnostic(diagnosticBlob); + return SLANG_FAIL; + } + if (funcInfo.parameterCount == 2) + { + arguments = &args; + argSize = sizeof(Arguments); + } + if (SLANG_FAILED(runner->execute(arguments, argSize))) + { + runner->getErrorString(diagnosticBlob.writeRef()); + maybePrintDiagnostic(diagnosticBlob); + return SLANG_FAIL; + } + size_t returnValueSize = 0; + void* returnVal = runner->getReturnValue(&returnValueSize); + SlangResult result = SLANG_OK; + memcpy(&result, returnVal, returnValueSize); + return result; +} + +int main(int argc, const char* const* argv) +{ + String entryPointName = toSlice("main"); + UnownedStringSlice fileName; + bool disasm = false; + int innerArgIndex = 0; + if (argc < 2) + { + printUsage(); + return 0; + } + for (auto i = 1; i < argc; i++) + { + auto arg = UnownedStringSlice(argv[i]); + if (arg == "-entry") + { + entryPointName = UnownedStringSlice(argv[++i]); + } + else if (arg == "-help" || arg == "--help") + { + printUsage(); + return 0; + } + else if (arg == "-disasm") + { + disasm = true; + } + else if (arg.startsWith("-")) + { + fprintf(stderr, "Unknown option: %s\n", arg.begin()); + printUsage(); + return -1; + } + else + { + fileName = arg; + innerArgIndex = i; + break; + } + } + if (!fileName.getLength()) + { + printUsage(); + return 0; + } + + auto result = compileAndInterpret( + fileName, + entryPointName.getBuffer(), + disasm, + argc - innerArgIndex, + argv + innerArgIndex); + slang::shutdown(); + return result; +} diff --git a/tools/test-server/test-server-main.cpp b/tools/test-server/test-server-main.cpp index 633a23d7e..63535ec92 100644 --- a/tools/test-server/test-server-main.cpp +++ b/tools/test-server/test-server-main.cpp @@ -188,6 +188,9 @@ SlangResult innerMain( } // namespace SlangCTool +// SlangITool +#include "../slang-test/slangi-tool-impl.h" + /* !!!!!!!!!!!!!!!!!!!!!!!!!!!! TestServer !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! 
*/ SlangResult TestServer::init(int argc, const char* const* argv) @@ -308,6 +311,10 @@ TestServer::InnerMainFunc TestServer::getToolFunction(const String& name, Diagno { return &SlangCTool::innerMain; } + else if (name == "slangi") + { + return &SlangITool::innerMain; + } StringBuilder sharedLibToolBuilder; sharedLibToolBuilder.append(name); |
