diff options
Diffstat (limited to 'source/slang/slang-emit-spirv.cpp')
| -rw-r--r-- | source/slang/slang-emit-spirv.cpp | 1195 |
1 files changed, 653 insertions, 542 deletions
diff --git a/source/slang/slang-emit-spirv.cpp b/source/slang/slang-emit-spirv.cpp index bafbb79f1..3947e8468 100644 --- a/source/slang/slang-emit-spirv.cpp +++ b/source/slang/slang-emit-spirv.cpp @@ -12,6 +12,7 @@ #include "slang-spirv-val.h" #include "spirv/unified1/spirv.h" #include "../core/slang-memory-arena.h" +#include <type_traits> namespace Slang { @@ -275,6 +276,80 @@ struct SpvSnippetEmitContext List<SpvWord> argumentIds; }; +// A structure which can hold an integer literal, either one word or several +struct SpvLiteralInteger +{ + static SpvLiteralInteger from32(int32_t value) { return from32(uint32_t(value)); } + static SpvLiteralInteger from32(uint32_t value) { return SpvLiteralInteger{{value}}; } + static SpvLiteralInteger from64(int64_t value) { return from64(uint64_t(value)); } + static SpvLiteralInteger from64(uint64_t value) { return SpvLiteralInteger{{SpvWord(value), SpvWord(value >> 32)}}; } + List<SpvWord> value; // Words, stored low words to high (TODO, SmallArray or something here) +}; + +// A structure which can hold bitwise literal, either one word or several +struct SpvLiteralBits +{ + static SpvLiteralBits from32(uint32_t value) { return SpvLiteralBits{{value}}; } + static SpvLiteralBits from64(uint64_t value) { return SpvLiteralBits{{SpvWord(value), SpvWord(value >> 32)}}; } + List<SpvWord> value; // Words, stored low words to high (TODO, SmallArray or something here) +}; + +// As a convenience, there are often cases where +// we will want to emit all of the operands of some +// IR instruction as <id> operands of a SPIR-V +// instruction. This is handy in cases where the +// Slang IR and SPIR-V instructions agree on the +// number, order, and meaning of their operands. +/// Helper type for emitting all the operands of the current IR instruction +struct OperandsOf +{ + OperandsOf(IRInst* irInst) + : irInst(irInst) + {} + + IRInst* irInst = nullptr; +}; + +/// Helper type for not emitting an operand in this position +struct SkipThisOptionalOperand {}; + +template<typename T> +struct OptionalOperand +{ + static_assert(std::is_trivial_v<T>); + OptionalOperand(SkipThisOptionalOperand) : present(false) {} + OptionalOperand(T value) : present(true), value(value) {} + bool present; + T value; +}; + +template<typename T> +OptionalOperand<T> nullOptionOperand() +{ + return OptionalOperand<T>{false}; +} + +template<typename T> +OptionalOperand<T> someOptionOperand(T t) +{ + return OptionalOperand<T>{true, t}; +} + +template<typename T> +constexpr bool isPlural = false; +template<typename T> +constexpr bool isPlural<List<T>> = true; +template<typename T> +constexpr bool isPlural<IROperandList<T>> = true; +template<typename T, Index N> +constexpr bool isPlural<Array<T, N>> = true; +template<> +constexpr bool isPlural<OperandsOf> = true; +template<> +constexpr bool isPlural<IRUse*> = true; +template<typename T> +constexpr bool isSingular = !isPlural<T>; + // Now that we've defined the intermediate data structures we will // use to represent SPIR-V code during emission, we will move on // to defining the main context type that will drive SPIR-V @@ -471,10 +546,12 @@ struct SPIRVEmitContext // Holds a stack of instructions operands *BEFORE* they added to the instruction. List<SpvWord> m_operandStack; - // The current instruction being constructed. Cannot add operands unless it is set. + // The current instruction being constructed. Cannot add operands unless it + // is set, or we are peeking at some operands to see if we have them memoized SpvInst* m_currentInst = nullptr; + bool m_peekingOperands = false; - // Operands can only be added when inside of a InstConstructScope + // Operands can only be added when inside of a InstConstructScope or... struct InstConstructScope { SLANG_FORCE_INLINE operator SpvInst*() const { return m_inst; } @@ -495,6 +572,28 @@ struct SPIRVEmitContext Index m_operandsStartIndex; ///< The start index for operands of m_inst }; + // ...If we're speculatively adding them to see if we have a memoized results + struct OperandMemoizeScope + { + OperandMemoizeScope(SPIRVEmitContext* context) : m_context(context) + { + m_tmpOperandStack.swapWith(m_context->m_operandStack); + std::swap(m_tmpPeeking, m_context->m_peekingOperands); + std::swap(m_tmpInst, m_context->m_currentInst); + } + ~OperandMemoizeScope() + { + std::swap(m_tmpInst, m_context->m_currentInst); + std::swap(m_tmpPeeking, m_context->m_peekingOperands); + m_tmpOperandStack.swapWith(m_context->m_operandStack); + } + + SPIRVEmitContext* m_context; + List<SpvWord> m_tmpOperandStack; + bool m_tmpPeeking = true; + SpvInst* m_tmpInst = nullptr; + }; + /// Holds memory for instructions and operands. MemoryArena m_memoryArena; @@ -586,7 +685,7 @@ struct SPIRVEmitContext void emitOperand(SpvWord word) { // Can only add operands if we are constructing an instruction (ie in _beginInst/_endInst) - SLANG_ASSERT(m_currentInst); + SLANG_ASSERT(m_currentInst || m_peekingOperands); m_operandStack.add(word); } @@ -623,7 +722,7 @@ struct SPIRVEmitContext void emitOperand(UnownedStringSlice const& text) { // Can only emitOperands if we are in an instruction - SLANG_ASSERT(m_currentInst); + SLANG_ASSERT(m_currentInst || m_peekingOperands); SLANG_COMPILE_TIME_ASSERT(sizeof(SpvWord) == 4); // Assert that `text` doesn't contain any embedded nul bytes, since they @@ -670,16 +769,53 @@ struct SPIRVEmitContext void emitOperand(ResultIDToken) { + // This is the one case we shouldn't be peeking at operands, as it + // depends on having an instruction under construction SLANG_ASSERT(m_currentInst); // A result <id> operand uses the <id> of the instruction itself (which is m_currentInst) emitOperand(getID(m_currentInst)); } - void emitOperand(SpvDecoration decoration) { emitOperand((SpvWord)decoration); } + void emitOperand(const SpvLiteralBits& bits) + { + for(const auto v : bits.value) + emitOperand(v); + } - void emitOperand(SpvBuiltIn builtin) { emitOperand((SpvWord)builtin); } - void emitOperand(SpvStorageClass val) { emitOperand((SpvWord)val); } + void emitOperand(const SpvLiteralInteger& integer) + { + for(const auto v : integer.value) + emitOperand(v); + } + + template<typename T> + void emitOperand(const List<T>& os) + { + for(const auto& o : os) + emitOperand(o); + } + + template<typename T> + void emitOperand(const IROperandList<T>& os) + { + for(const auto& o : os) + emitOperand(o); + } + + template<typename T, Index N> + void emitOperand(const Array<T, N>& os) + { + for(const auto& o : os) + emitOperand(o); + } + + template<typename T> + void emitOperand(const ArrayView<T>& os) + { + for(const auto& o : os) + emitOperand(o); + } template<typename TConstant> struct ConstantValueKey @@ -697,7 +833,7 @@ struct SPIRVEmitContext }; Dictionary<ConstantValueKey<IRIntegerValue>, SpvInst*> m_spvIntConstants; Dictionary<ConstantValueKey<IRFloatingPointValue>, SpvInst*> m_spvFloatConstants; - SpvInst* emitIntConstant(IRIntegerValue val, IRType* type) + SpvInst* emitIntConstant(IRIntegerValue val, IRType* type, IRInst* inst = nullptr) { ConstantValueKey<IRIntegerValue> key; key.value = val; @@ -705,8 +841,6 @@ struct SPIRVEmitContext SpvInst* result = nullptr; if (m_spvIntConstants.tryGetValue(key, result)) return result; - SpvWord valWord; - memcpy(&valWord, &val, sizeof(SpvWord)); switch (type->getOp()) { case kIROp_Int64Type: @@ -716,34 +850,27 @@ struct SPIRVEmitContext case kIROp_UIntPtrType: #endif { - SpvWord valHighWord; - memcpy(&valHighWord, (char*)(&val) + 4, sizeof(SpvWord)); - result = emitInst( - getSection(SpvLogicalSectionID::ConstantsAndTypes), - nullptr, - SpvOpConstant, + result = emitOpConstant( + inst, type, - kResultID, - valWord, - valHighWord); + SpvLiteralBits::from64(uint64_t(val)) + ); break; } default: { - result = emitInst( - getSection(SpvLogicalSectionID::ConstantsAndTypes), - nullptr, - SpvOpConstant, + result = emitOpConstant( + inst, type, - kResultID, - valWord); + SpvLiteralBits::from32(uint32_t(val)) + ); break; } } m_spvIntConstants[key] = result; return result; } - SpvInst* emitFloatConstant(IRFloatingPointValue val, IRType* type) + SpvInst* emitFloatConstant(IRFloatingPointValue val, IRType* type, IRInst* inst = nullptr) { ConstantValueKey<IRFloatingPointValue> key; key.value = val; @@ -751,50 +878,34 @@ struct SPIRVEmitContext SpvInst* result = nullptr; if (m_spvFloatConstants.tryGetValue(key, result)) return result; - SpvWord valWord; - memcpy(&valWord, &val, sizeof(SpvWord)); if (type->getOp() == kIROp_DoubleType) { - SpvWord valHighWord; - memcpy(&valHighWord, (char*)(&val) + 4, sizeof(SpvWord)); - result = emitInst( - getSection(SpvLogicalSectionID::ConstantsAndTypes), - nullptr, - SpvOpConstant, + result = emitOpConstant( + inst, type, - kResultID, - valWord, - valHighWord); + SpvLiteralBits::from64(uint64_t(DoubleAsInt64(val)))); } - else + else if(type->getOp() == kIROp_FloatType) { - result = emitInst( - getSection(SpvLogicalSectionID::ConstantsAndTypes), - nullptr, - SpvOpConstant, + result = emitOpConstant( + inst, + type, + SpvLiteralBits::from32(uint32_t(FloatAsInt(float(val))))); + } + else if(type->getOp() == kIROp_HalfType) + { + result = emitOpConstant( + inst, type, - kResultID, - valWord); + SpvLiteralBits::from32(uint32_t(FloatToHalf(float(val))))); + } + else + { + SLANG_UNEXPECTED("missing case in SPIR-V emitFloatConstant"); } m_spvFloatConstants[key] = result; return result; } - // As another convenience, there are often cases where - // we will want to emit all of the operands of some - // IR instruction as <id> operands of a SPIR-V - // instruction. This is handy in cases where the - // Slang IR and SPIR-V instructions agree on the - // number, order, and meaning of their operands. - - /// Helper type for emitting all the operands of the current IR instruction - struct OperandsOf - { - OperandsOf(IRInst* irInst) - : irInst(irInst) - {} - - IRInst* irInst = nullptr; - }; /// Emit operand words for all the operands of a given IR instruction void emitOperand(OperandsOf const& other) @@ -807,101 +918,104 @@ struct SPIRVEmitContext } } + /// Do nothing + void emitOperand(SkipThisOptionalOperand) { } + + template<typename T> + void emitOperand(OptionalOperand<T> o) + { + if(o.present) + emitOperand(o.value); + } + // With the above routines, code can easily construct a SPIR-V // instruction with arbitrary operands over multiple lines of code. // - // In many cases, however, it is desirable to be able to emit - // an instruction more compactly, and for that we will introduce - // a number of `emitInst()` helpers that handle creating an - // instruction, filling in its operands, and adding it to a parent. + // The safe way to call these routines is encoded in the below `emitInst` + // function. // - // These routines are overloaded on the number of operands, and - // also templates to work with any of the types for which - // `emitOperand()` works. + // This allows one to generically output a SPIR-V instruction with any + // desired operands. // - // In all of these cases, the caller takes responsibility for - // correctly matching the SPIR-V encoding rules for the chosen - // opcode, including whether a type <id> or result <id> is - // required. - - SpvInst* emitInst(SpvInstParent* parent, IRInst* irInst, SpvOp opcode) + // This function performs no checks that it is actually being used + // correctly with respect to the SPIR-V rules for each opcode. As such, a + // more type safe function for each opcode is included in + // 'slang-emit-spirv-ops.h', and available in this class. You are + // encouraged to use these instead. + // + template<typename... Operands> + SpvInst* emitInst(SpvInstParent* parent, IRInst* irInst, SpvOp opcode, const Operands& ...ops) { InstConstructScope scopeInst(this, opcode, irInst); SpvInst* spvInst = scopeInst; + (emitOperand(ops), ...); parent->addInst(spvInst); return spvInst; } - template<typename A> - SpvInst* emitInst(SpvInstParent* parent, IRInst* irInst, SpvOp opcode, A const& a) + template<typename OperandEmitFunc> + SpvInst* emitInstCustomOperandFunc(SpvInstParent* parent, IRInst* irInst, SpvOp opcode, const OperandEmitFunc& f) { InstConstructScope scopeInst(this, opcode, irInst); SpvInst* spvInst = scopeInst; - emitOperand(a); + f(); parent->addInst(spvInst); return spvInst; } - template<typename A, typename B> - SpvInst* emitInst(SpvInstParent* parent, IRInst* irInst, SpvOp opcode, A const& a, B const& b) + // Emits a SPV Inst with deduplication + // This is used where our IR doesn't guarantee uniqueness but SPIR-V + // requires it + template<typename... Operands> + SpvInst* emitInstMemoized( + SpvInstParent* parent, + IRInst* irInst, + SpvOp opcode, + // We take the resultId here explicitly here to make sure we don't try + // and memoize its value. + ResultIDToken resultId, + const Operands& ...ops + ) { - InstConstructScope scopeInst(this, opcode, irInst); - SpvInst* spvInst = scopeInst; - emitOperand(a); - emitOperand(b); - parent->addInst(spvInst); - return spvInst; - } + List<SpvWord> ourOperands; + { + auto scopePeek = OperandMemoizeScope(this); + (emitOperand(ops), ...); + // Steal our operands back, so we don't have to calculate them + // again + ourOperands = std::move(m_operandStack); + } - template<typename A, typename B, typename C> - SpvInst* emitInst(SpvInstParent* parent, IRInst* irInst, SpvOp opcode, A const& a, B const& b, C const& c) - { - InstConstructScope scopeInst(this, opcode, irInst); - SpvInst* spvInst = scopeInst; - emitOperand(a); - emitOperand(b); - emitOperand(c); - parent->addInst(spvInst); - return spvInst; - } + // Hash the whole global stack and opcode + SpvTypeInstKey key; + key.words.add(opcode); + key.words.addRange(ourOperands); - template<typename A, typename B, typename C, typename D> - SpvInst* emitInst(SpvInstParent* parent, IRInst* irInst, SpvOp opcode, A const& a, B const& b, C const& c, D const& d) - { - InstConstructScope scopeInst(this, opcode, irInst); - SpvInst* spvInst = scopeInst; - emitOperand(a); - emitOperand(b); - emitOperand(c); - emitOperand(d); - parent->addInst(spvInst); - return spvInst; - } + // If we have seen this before, return the memoized instruction + if (SpvInst** memoized = m_spvTypeInsts.tryGetValue(key)) + return *memoized; - template<typename A, typename B, typename C, typename D, typename E> - SpvInst* emitInst(SpvInstParent* parent, IRInst* irInst, SpvOp opcode, A const& a, B const& b, C const& c, D const& d, E const& e) - { + // Otherwise, we can construct our instruction and record the result InstConstructScope scopeInst(this, opcode, irInst); SpvInst* spvInst = scopeInst; - emitOperand(a); - emitOperand(b); - emitOperand(c); - emitOperand(d); - emitOperand(e); - parent->addInst(spvInst); - return spvInst; - } + m_spvTypeInsts[key] = spvInst; + + // Emit our operands, this time with the resultId too + emitOperand(resultId); + m_operandStack.addRange(ourOperands); - template<typename OperandEmitFunc> - SpvInst* emitInstCustomOperandFunc(SpvInstParent* parent, IRInst* irInst, SpvOp opcode, const OperandEmitFunc& f) - { - InstConstructScope scopeInst(this, opcode, irInst); - SpvInst* spvInst = scopeInst; - f(); parent->addInst(spvInst); return spvInst; } + // + // Specific emit funcs + // + +# define SLANG_IN_SPIRV_EMIT_CONTEXT +# include "slang-emit-spirv-ops.h" +# undef SLANG_IN_SPIRV_EMIT_CONTEXT + /// The SPIRV OpExtInstImport inst that represents the GLSL450 /// extended instruction set. SpvInst* m_glsl450ExtInst = nullptr; @@ -910,11 +1024,9 @@ struct SPIRVEmitContext { if (m_glsl450ExtInst) return m_glsl450ExtInst; - m_glsl450ExtInst = emitInst( + m_glsl450ExtInst = emitOpExtInstImport( getSection(SpvLogicalSectionID::ExtIntInstImports), nullptr, - SpvOpExtInstImport, - kResultID, UnownedStringSlice("GLSL.std.450")); return m_glsl450ExtInst; } @@ -937,7 +1049,11 @@ struct SPIRVEmitContext // For now we will always emit the `Shader` capability, // since every Vulkan shader module will use it. // - emitInst(getSection(SpvLogicalSectionID::Capabilities), nullptr, SpvOpCapability, SpvCapabilityShader); + emitOpCapability( + getSection(SpvLogicalSectionID::Capabilities), + nullptr, + SpvCapabilityShader + ); // [2.4: Logical Layout of a Module] // @@ -953,7 +1069,12 @@ struct SPIRVEmitContext // a requirement, but it is what glslang produces, // so we will use it for now. // - emitInst(getSection(SpvLogicalSectionID::MemoryModel), nullptr, SpvOpMemoryModel, SpvAddressingModelLogical, SpvMemoryModelGLSL450); + emitOpMemoryModel( + getSection(SpvLogicalSectionID::MemoryModel), + nullptr, + SpvAddressingModelLogical, + SpvMemoryModelGLSL450 + ); } Dictionary<UnownedStringSlice, SpvInst*> m_extensionInsts; @@ -962,8 +1083,11 @@ struct SPIRVEmitContext SpvInst* result = nullptr; if (m_extensionInsts.tryGetValue(name, result)) return result; - result = - emitInst(getSection(SpvLogicalSectionID::Extensions), nullptr, SpvOpExtension, name); + result = emitOpExtension( + getSection(SpvLogicalSectionID::Extensions), + nullptr, + name + ); m_extensionInsts[name] = result; return result; } @@ -983,31 +1107,6 @@ struct SPIRVEmitContext Dictionary<SpvTypeInstKey, SpvInst*> m_spvTypeInsts; - // Emits a SPV Inst that represents a type, with deduplications since - // our IR doesn't currently guarantee types are unique in generated SPV. - SpvInst* emitTypeInst(IRInst* typeInst, SpvOp opcode, ArrayView<SpvWord> operands) - { - SpvTypeInstKey key; - key.words.add((SpvWord)opcode); - for (auto op : operands) - key.words.add(op); - SpvInst* result = nullptr; - if (m_spvTypeInsts.tryGetValue(key, result)) - { - return result; - } - result = emitInstCustomOperandFunc( - getSection(SpvLogicalSectionID::ConstantsAndTypes), typeInst, opcode, [&]() { - emitOperand(kResultID); - for (auto op : operands) - { - emitOperand(op); - } - }); - m_spvTypeInsts[key] = result; - return result; - } - // Next, let's look at emitting some of the instructions // that can occur at global scope. @@ -1017,21 +1116,13 @@ struct SPIRVEmitContext /// SpvInst* emitGlobalInst(IRInst* inst) { - switch( inst->getOp() ) + switch( inst->getOp() & kIROpMask_OpMask ) { // [3.32.6: Type-Declaration Instructions] // -#define CASE(IROP, SPVOP) \ - case IROP: return emitTypeInst(inst, SPVOP, ArrayView<SpvWord>()); - - // > OpTypeVoid - CASE(kIROp_VoidType, SpvOpTypeVoid); - - // > OpTypeBool - CASE(kIROp_BoolType, SpvOpTypeBool); - -#undef CASE + case kIROp_VoidType: return emitOpTypeVoid(inst); + case kIROp_BoolType: return emitOpTypeBool(inst); // > OpTypeInt @@ -1045,10 +1136,11 @@ struct SPIRVEmitContext case kIROp_Int64Type: { const IntInfo i = getIntTypeInfo(as<IRType>(inst)); - return emitTypeInst( + return emitOpTypeInt( inst, - SpvOpTypeInt, - makeArray(static_cast<SpvWord>(i.width), SpvWord{i.isSigned}).getView()); + SpvLiteralInteger::from32(int32_t(i.width)), + SpvLiteralInteger::from32(i.isSigned) + ); } // > OpTypeFloat @@ -1058,7 +1150,7 @@ struct SPIRVEmitContext case kIROp_DoubleType: { const FloatInfo i = getFloatingTypeInfo(as<IRType>(inst)); - return emitTypeInst(inst, SpvOpTypeFloat, makeArray(static_cast<SpvWord>(i.width)).getView()); + return emitOpTypeFloat(inst, SpvLiteralInteger::from32(int32_t(i.width))); } case kIROp_PtrType: @@ -1073,22 +1165,24 @@ struct SPIRVEmitContext storageClass = (SpvStorageClass)ptrType->getAddressSpace(); if (storageClass == SpvStorageClassStorageBuffer) ensureExtensionDeclaration(UnownedStringSlice("SPV_KHR_storage_buffer_storage_class")); - auto operands = makeArray<SpvWord>( - (SpvWord)storageClass, getID(ensureInst(inst->getOperand(0)))); - return emitTypeInst( - inst, SpvOpTypePointer, operands.getView()); + return emitOpTypePointer( + inst, + storageClass, + inst->getOperand(0) + ); } + case kIROp_ConstantBufferType: + SLANG_UNEXPECTED("Constant buffer type remaining in spirv emit"); case kIROp_StructType: { - auto spvStructType = emitInstCustomOperandFunc( - getSection(SpvLogicalSectionID::ConstantsAndTypes), inst, SpvOpTypeStruct, [&]() { - emitOperand(kResultID); - for (auto field : static_cast<IRStructType*>(inst)->getFields()) - { - emitOperand(field->getFieldType()); - // TODO: decorate offset - } - }); + List<IRType*> types; + // TODO: decorate offset + for (auto field : static_cast<IRStructType*>(inst)->getFields()) + types.add(field->getFieldType()); + auto spvStructType = emitOpTypeStruct( + inst, + types + ); emitDecorations(inst, getID(spvStructType)); return spvStructType; } @@ -1107,15 +1201,13 @@ struct SPIRVEmitContext static_cast<IRBasicType*>(matrixType->getElementType())->getBaseType(), static_cast<IRIntLit*>(matrixType->getRowCount())->getValue(), nullptr); - auto matrixSPVType = emitInst( - getSection(SpvLogicalSectionID::ConstantsAndTypes), + const auto columnCount = static_cast<IRIntLit*>(matrixType->getColumnCount())->getValue(); + auto matrixSPVType = emitOpTypeMatrix( inst, - SpvOpTypeMatrix, - kResultID, vectorSpvType, - (SpvWord)static_cast<IRIntLit*>(matrixType->getColumnCount())->getValue()); + SpvLiteralInteger::from32(int32_t(columnCount)) + ); // TODO: properly compute matrix stride. - auto columnCount = static_cast<IRIntLit*>(matrixType->getRowCount())->getValue(); uint32_t stride = 0; switch (columnCount) { @@ -1132,39 +1224,85 @@ struct SPIRVEmitContext default: break; } - emitInst( + // TODO: This decoration is not legal here. It must be placed + // on a struct member (which may entail wrapping matrices) + emitOpDecorate( + getSection(SpvLogicalSectionID::Annotations), + nullptr, + matrixSPVType, + SpvDecorationRowMajor); + emitOpDecorateMatrixStride( getSection(SpvLogicalSectionID::Annotations), nullptr, - SpvOpDecorate, matrixSPVType, - SpvDecorationRowMajor, - SpvDecorationMatrixStride, - stride); + SpvLiteralInteger::from32(stride)); return matrixSPVType; } + case kIROp_ArrayType: case kIROp_UnsizedArrayType: { - auto elementType = static_cast<IRUnsizedArrayType*>(inst)->getElementType(); - auto runtimeArrayType = emitInst( - getSection(SpvLogicalSectionID::ConstantsAndTypes), - nullptr, - SpvOpTypeRuntimeArray, - kResultID, - elementType); + const auto elementType = static_cast<IRArrayTypeBase*>(inst)->getElementType(); + const auto arrayType = inst->getOp() == kIROp_ArrayType + ? emitOpTypeArray(inst, elementType, static_cast<IRArrayTypeBase*>(inst)->getElementCount()) + : emitOpTypeRuntimeArray(inst, elementType); // TODO: properly decorate stride. + // TODO: don't do this more than once IRSizeAndAlignment sizeAndAlignment; getNaturalSizeAndAlignment(this->m_targetRequest, elementType, &sizeAndAlignment); - emitInst( + emitOpDecorateArrayStride( getSection(SpvLogicalSectionID::Annotations), nullptr, - SpvOpDecorate, - runtimeArrayType, - SpvDecorationArrayStride, - (SpvWord)sizeAndAlignment.getStride()); - return runtimeArrayType; + arrayType, + SpvLiteralInteger::from32(int32_t(sizeAndAlignment.getStride()))); + return arrayType; + } + + case kIROp_TextureType: + { + const auto texTypeInst = as<IRTextureType>(inst); + const auto sampledType = texTypeInst->getElementType(); + SpvDim dim = SpvDim1D; // Silence uninitialized warnings from msvc... + switch(texTypeInst->GetBaseShape()) + { + case TextureFlavor::Shape1D: + case TextureFlavor::Shape1DArray: + dim = SpvDim1D; + break; + case TextureFlavor::Shape2D: + case TextureFlavor::Shape2DArray: + dim = SpvDim2D; + break; + case TextureFlavor::Shape3D: + dim = SpvDim3D; + break; + case TextureFlavor::ShapeCube: + case TextureFlavor::ShapeCubeArray: + dim = SpvDimCube; + break; + case TextureFlavor::ShapeBuffer: + dim = SpvDimBuffer; + break; + } + bool arrayed = texTypeInst->isArray(); + SpvWord depth = 2; // No knowledge of if this is a depth image + bool ms = texTypeInst->isMultisample(); + // TODO: can we do better here? + SpvWord sampled = 0; // Only known at run time + // TODO: can we do better? + SpvImageFormat format = SpvImageFormatUnknown; + return emitOpTypeImage( + inst, + sampledType, + dim, + SpvLiteralInteger::from32(depth), + SpvLiteralInteger::from32(arrayed), + SpvLiteralInteger::from32(ms), + SpvLiteralInteger::from32(sampled), + format + ); } - // > OpTypeImage - // > OpTypeSampler + case kIROp_SamplerStateType: + return emitOpTypeSampler(inst); // > OpTypeArray // > OpTypeRuntimeArray // > OpTypeOpaque @@ -1177,7 +1315,11 @@ struct SPIRVEmitContext // with the result-type operand coming first, // followed by operand sfor all the parameter types. // - return emitInst(getSection(SpvLogicalSectionID::ConstantsAndTypes), inst, SpvOpTypeFunction, kResultID, OperandsOf(inst)); + return emitOpTypeFunction( + inst, + static_cast<IRFuncType*>(inst)->getResultType(), + static_cast<IRFuncType*>(inst)->getParamTypes() + ); case kIROp_RateQualifiedType: { @@ -1208,10 +1350,21 @@ struct SPIRVEmitContext return emitGlobalVar(as<IRGlobalVar>(inst)); // ... + case kIROp_Specialize: + { + const auto s = as<IRSpecialize>(inst); + const auto g = s->getBase(); + const auto e = + "Specialize instruction remains in IR for SPIR-V emit, is something undefined?\n" + + dumpIRToString(g); + SLANG_UNEXPECTED(e.getBuffer()); + } default: - String e = "Unhandled global inst in spirv-emit: " - + dumpIRToString(inst, {IRDumpOptions::Mode::Detailed, 0}); - SLANG_UNIMPLEMENTED_X(e.begin()); + { + String e = "Unhandled global inst in spirv-emit:\n" + + dumpIRToString(inst, {IRDumpOptions::Mode::Detailed, 0}); + SLANG_UNIMPLEMENTED_X(e.begin()); + } } } @@ -1228,9 +1381,11 @@ struct SPIRVEmitContext builder.getBasicType(baseType), builder.getIntValue(builder.getIntType(), elementCount)); } - auto operands = - makeArray<SpvWord>(getID(ensureInst(inst->getElementType())), (SpvWord)elementCount); - auto result = emitTypeInst(inst, SpvOpTypeVector, operands.getView()); + auto result = emitOpTypeVector( + inst, + inst->getElementType(), + SpvLiteralInteger::from32(int32_t(elementCount)) + ); return result; } @@ -1246,49 +1401,44 @@ struct SPIRVEmitContext break; case LayoutResourceKind::VaryingInput: - emitInst( + emitOpDecorateLocation( getSection(SpvLogicalSectionID::Annotations), nullptr, - SpvOpDecorate, varInst, - SpvDecorationLocation, - (SpvWord)index); - emitInst( + SpvLiteralInteger::from32(int32_t(index)) + ); + emitOpDecorateIndex( getSection(SpvLogicalSectionID::Annotations), nullptr, - SpvOpDecorate, varInst, - SpvDecorationIndex, - (SpvWord)space); + SpvLiteralInteger::from32(int32_t(space)) + ); break; case LayoutResourceKind::VaryingOutput: - emitInst( + emitOpDecorateLocation( getSection(SpvLogicalSectionID::Annotations), nullptr, - SpvOpDecorate, varInst, - SpvDecorationLocation, - (SpvWord)index); + SpvLiteralInteger::from32(int32_t(index)) + ); if (space) { - emitInst( + emitOpDecorateIndex( getSection(SpvLogicalSectionID::Annotations), nullptr, - SpvOpDecorate, varInst, - SpvDecorationIndex, - (SpvWord)space); + SpvLiteralInteger::from32(int32_t(space)) + ); } break; case LayoutResourceKind::SpecializationConstant: - emitInst( + emitOpDecorateSpecId( getSection(SpvLogicalSectionID::Annotations), nullptr, - SpvOpDecorate, varInst, - SpvDecorationSpecId, - (SpvWord)index); + SpvLiteralInteger::from32(int32_t(index)) + ); break; case LayoutResourceKind::ConstantBuffer: @@ -1296,20 +1446,18 @@ struct SPIRVEmitContext case LayoutResourceKind::UnorderedAccess: case LayoutResourceKind::SamplerState: case LayoutResourceKind::DescriptorTableSlot: - emitInst( + emitOpDecorateBinding( getSection(SpvLogicalSectionID::Annotations), nullptr, - SpvOpDecorate, varInst, - SpvDecorationBinding, - (SpvWord)index); - emitInst( + SpvLiteralInteger::from32(int32_t(index)) + ); + emitOpDecorateDescriptorSet( getSection(SpvLogicalSectionID::Annotations), nullptr, - SpvOpDecorate, varInst, - SpvDecorationDescriptorSet, - (SpvWord)space); + SpvLiteralInteger::from32(int32_t(space)) + ); break; default: break; @@ -1331,13 +1479,12 @@ struct SPIRVEmitContext registerInst(param, systemValInst); return systemValInst; } - auto varInst = emitInst( + auto varInst = emitOpVariable( getSection(SpvLogicalSectionID::GlobalVariables), param, - SpvOpVariable, param->getDataType(), - kResultID, - storageClass); + storageClass + ); emitVarLayout(varInst, layout); return varInst; } @@ -1346,21 +1493,20 @@ struct SPIRVEmitContext SpvInst* emitGlobalVar(IRGlobalVar* globalVar) { auto layout = getVarLayout(globalVar); - SLANG_ASSERT(layout); auto storageClass = SpvStorageClassUniform; if (auto ptrType = as<IRPtrTypeBase>(globalVar->getDataType())) { if (ptrType->hasAddressSpace()) storageClass = (SpvStorageClass)ptrType->getAddressSpace(); } - auto varInst = emitInst( + auto varInst = emitOpVariable( getSection(SpvLogicalSectionID::GlobalVariables), globalVar, - SpvOpVariable, globalVar->getDataType(), - kResultID, - storageClass); - emitVarLayout(varInst, layout); + storageClass + ); + if(layout) + emitVarLayout(varInst, layout); return varInst; } @@ -1438,11 +1584,13 @@ struct SPIRVEmitContext // type is given as a later operand. Slan IR instead uses // the type of a function instruction store, you know, its *type*. // - SpvInst* spvFunc = emitInst(section, irFunc, SpvOpFunction, + SpvInst* spvFunc = emitOpFunction( + section, + irFunc, irFunc->getDataType()->getResultType(), - kResultID, spvFunctionControl, - irFunc->getDataType()); + irFunc->getDataType() + ); // > OpFunctionParameter // @@ -1473,7 +1621,7 @@ struct SPIRVEmitContext // for( auto irBlock : irFunc->getBlocks() ) { - auto spvBlock = emitInst(spvFunc, irBlock, SpvOpLabel, kResultID); + auto spvBlock = emitOpLabel(spvFunc, irBlock); if (irBlock == irFunc->getFirstBlock()) { // OpVariable @@ -1494,7 +1642,7 @@ struct SPIRVEmitContext { if (irInst->getOp() == kIROp_loop) { - emitInst(spvFunc, irInst, SpvOpLabel, kResultID); + emitOpLabel(spvFunc, irInst); } } } @@ -1557,7 +1705,7 @@ struct SPIRVEmitContext // structure we will make the `OpFunctionEnd` be the last child of // the `OpFunction`. // - emitInst(spvFunc, nullptr, SpvOpFunctionEnd); + emitOpFunctionEnd(spvFunc, nullptr); // We will emit any decorations pertinent to the function to the // appropriate section of the module. @@ -1597,7 +1745,7 @@ struct SPIRVEmitContext { default: { - String e = "Unhandled local inst in spirv-emit: " + String e = "Unhandled local inst in spirv-emit:\n" + dumpIRToString(inst, {IRDumpOptions::Mode::Detailed, 0}); SLANG_UNIMPLEMENTED_X(e.getBuffer()); } @@ -1606,7 +1754,7 @@ struct SPIRVEmitContext case kIROp_Var: return emitVar(parent, inst); case kIROp_Call: - return emitCall(parent, inst); + return emitCall(parent, static_cast<IRCall*>(inst)); case kIROp_FieldAddress: return emitFieldAddress(parent, as<IRFieldAddress>(inst)); case kIROp_FieldExtract: @@ -1615,6 +1763,8 @@ struct SPIRVEmitContext return emitGetElementPtr(parent, as<IRGetElementPtr>(inst)); case kIROp_GetElement: return emitGetElement(parent, as<IRGetElement>(inst)); + case kIROp_MakeStruct: + return emitCompositeConstruct(parent, inst); case kIROp_Load: return emitLoad(parent, as<IRLoad>(inst)); case kIROp_Store: @@ -1643,8 +1793,12 @@ struct SPIRVEmitContext // TODO: break emitConstruct into separate functions for each opcode. return emitConstruct(parent, inst); case kIROp_BitCast: - return emitInst( - parent, inst, SpvOpBitcast, inst->getDataType(), kResultID, inst->getOperand(0)); + return emitOpBitcast( + parent, + inst, + inst->getDataType(), + inst->getOperand(0) + ); case kIROp_Add: case kIROp_Sub: case kIROp_Mul: @@ -1670,16 +1824,11 @@ struct SPIRVEmitContext return emitArithmetic(parent, inst); case kIROp_Return: if (as<IRReturn>(inst)->getVal()->getOp() == kIROp_VoidLit) - { - return emitInst(parent, inst, SpvOpReturn); - } + return emitOpReturn(parent, inst); else - { - return emitInst( - parent, inst, SpvOpReturnValue, as<IRReturn>(inst)->getVal()); - } + return emitOpReturnValue(parent, inst, as<IRReturn>(inst)->getVal()); case kIROp_discard: - return emitInst(parent, inst, SpvOpKill); + return emitOpKill(parent, inst); case kIROp_unconditionalBranch: { // If we are jumping to the main block of a loop, @@ -1688,15 +1837,9 @@ struct SPIRVEmitContext auto targetBlock = as<IRUnconditionalBranch>(inst)->getTargetBlock(); IRInst* loopInst = nullptr; if (isLoopTargetBlock(targetBlock, loopInst)) - { - return emitInst(parent, inst, SpvOpBranch, getIRInstSpvID(loopInst)); - } + return emitOpBranch(parent, inst, getIRInstSpvID(loopInst)); // Otherwise, emit a normal branch inst into the target block. - return emitInst( - parent, - inst, - SpvOpBranch, - getIRInstSpvID(targetBlock)); + return emitOpBranch(parent, inst, getIRInstSpvID(targetBlock)); } case kIROp_loop: { @@ -1710,7 +1853,7 @@ struct SPIRVEmitContext // Note: the body of the loop header block is emitted // after everything else to ensure Phi instructions (which come // from the actual loop target block) are emitted first. - emitInst(parent, nullptr, SpvOpBranch, blockId); + emitOpBranch(parent, nullptr, blockId); return block; } @@ -1718,26 +1861,22 @@ struct SPIRVEmitContext { auto ifelseInst = as<IRIfElse>(inst); auto afterBlockID = getIRInstSpvID(ifelseInst->getAfterBlock()); - emitInst( - parent, - nullptr, - SpvOpSelectionMerge, - afterBlockID, - 0); + emitOpSelectionMerge(parent, nullptr, afterBlockID, SpvSelectionControlMaskNone); auto falseLabel = ifelseInst->getFalseBlock(); - return emitInst( + return emitOpBranchConditional( parent, inst, - SpvOpBranchConditional, ifelseInst->getCondition(), ifelseInst->getTrueBlock(), - falseLabel ? getID(ensureInst(falseLabel)) : afterBlockID); + falseLabel ? getID(ensureInst(falseLabel)) : afterBlockID, + makeArray<SpvLiteralInteger>() + ); } case kIROp_Switch: { auto switchInst = as<IRSwitch>(inst); auto mergeBlockID = getIRInstSpvID(switchInst->getBreakLabel()); - emitInst(parent, nullptr, SpvOpSelectionMerge, mergeBlockID, 0); + emitOpSelectionMerge(parent, nullptr, mergeBlockID, SpvSelectionControlMaskNone); return emitInstCustomOperandFunc(parent, inst, SpvOpSwitch, [&]() { emitOperand(switchInst->getCondition()); auto defaultLabel = switchInst->getDefaultLabel(); @@ -1754,9 +1893,22 @@ struct SPIRVEmitContext }); } case kIROp_Unreachable: - return emitInst(parent, inst, SpvOpUnreachable); + return emitOpUnreachable(parent, inst); case kIROp_conditionalBranch: SLANG_UNEXPECTED("Unstructured branching is not supported by SPIRV."); + case kIROp_MakeVector: + return emitConstruct(parent, inst); + case kIROp_MakeVectorFromScalar: + { + const auto scalar = inst->getOperand(0); + const auto vecTy = as<IRVectorType>(inst->getDataType()); + SLANG_ASSERT(vecTy); + const auto numElems = as<IRIntLit>(vecTy->getElementCount()); + SLANG_ASSERT(numElems); + return emitSplat(parent, inst, scalar, numElems->getValue()); + } + case kIROp_MakeArray: + return emitConstruct(parent, inst); } } @@ -1767,86 +1919,29 @@ struct SPIRVEmitContext case kIROp_IntLit: { auto value = as<IRIntLit>(inst)->getValue(); - switch (as<IRBasicType>(inst->getDataType())->getBaseType()) - { - case BaseType::Int64: - case BaseType::UInt64: - case BaseType::IntPtr: - case BaseType::UIntPtr: - return emitInst( - getSection(SpvLogicalSectionID::ConstantsAndTypes), - inst, - SpvOpConstant, - inst->getDataType(), - kResultID, - (SpvWord)(value & 0xFFFFFFFF), - (SpvWord)((value >> 32) & 0xFFFFFFFF)); - default: - return emitInst( - getSection(SpvLogicalSectionID::ConstantsAndTypes), - inst, - SpvOpConstant, - inst->getDataType(), - kResultID, - (SpvWord)value); - } + return emitIntConstant(value, inst->getDataType(), inst); } case kIROp_FloatLit: { - auto value = as<IRConstant>(inst)->value.floatVal; - switch (as<IRBasicType>(inst->getDataType())->getBaseType()) - { - case BaseType::Half: - return emitInst( - getSection(SpvLogicalSectionID::ConstantsAndTypes), - inst, - SpvOpConstant, - inst->getDataType(), - kResultID, - (SpvWord)(FloatToHalf((float)value))); - case BaseType::Float: - return emitInst( - getSection(SpvLogicalSectionID::ConstantsAndTypes), - inst, - SpvOpConstant, - inst->getDataType(), - kResultID, - (SpvWord)(FloatAsInt((float)value))); - case BaseType::Double: - { - auto ival = DoubleAsInt64(value); - return emitInst( - getSection(SpvLogicalSectionID::ConstantsAndTypes), - inst, - SpvOpConstant, - inst->getDataType(), - kResultID, - (SpvWord)(ival&0xFFFFFFFF), - (SpvWord)(ival>>32)); - } - default: - return nullptr; - } + const auto value = as<IRConstant>(inst)->value.floatVal; + const auto type = inst->getDataType(); + return emitFloatConstant(value, type, inst); } case kIROp_BoolLit: { - if (as<IRBoolLit>(inst)->getValue()) + if (cast<IRBoolLit>(inst)->getValue()) { - return emitInst( - getSection(SpvLogicalSectionID::ConstantsAndTypes), + return emitOpConstantTrue( inst, - SpvOpConstantTrue, - inst->getDataType(), - kResultID); + inst->getDataType() + ); } else { - return emitInst( - getSection(SpvLogicalSectionID::ConstantsAndTypes), + return emitOpConstantFalse( inst, - SpvOpConstantFalse, - inst->getDataType(), - kResultID); + inst->getDataType() + ); } } default: @@ -1897,6 +1992,39 @@ struct SPIRVEmitContext default: break; + case kIROp_LayoutDecoration: + { + // Basic offsets for structs used in buffers + if(const auto typeLayout = as<IRTypeLayout>(as<IRLayoutDecoration>(decoration)->getLayout())) + { + if(const auto structTypeLayout = as<IRStructTypeLayout>(typeLayout)) + { + auto section = getSection(SpvLogicalSectionID::Annotations); + SpvWord i = 0; + for(const auto fieldLayoutAttr : structTypeLayout->getFieldLayoutAttrs()) + { + if(const auto structFieldLayoutAttr = as<IRStructFieldLayoutAttr>(fieldLayoutAttr)) + { + const auto varLayout = structFieldLayoutAttr->getLayout(); + if(const auto varOffsetAttr = varLayout->findOffsetAttr(LayoutResourceKind::Uniform)) + { + const auto offset = static_cast<SpvWord>(varOffsetAttr->getOffset()); + emitOpMemberDecorateOffset( + section, + fieldLayoutAttr, + dstID, + SpvLiteralInteger::from32(i), + SpvLiteralInteger::from32(offset) + ); + } + } + ++i; + } + } + } + } + break; + // [3.32.2. Debug Instructions] // // > OpName @@ -1905,7 +2033,11 @@ struct SPIRVEmitContext { auto section = getSection(SpvLogicalSectionID::DebugNames); auto nameHint = cast<IRNameHintDecoration>(decoration); - emitInst(section, decoration, SpvOpName, dstID, nameHint->getName()); + // We can't associate this spirv instruction with our + // irInstruction, our instruction may be a hint on several + // values, however this decoration is specific to a single + // dstID. + emitOpName(section, nullptr, dstID, nameHint->getName()); } break; @@ -1932,23 +2064,27 @@ struct SPIRVEmitContext auto entryPointDecor = cast<IREntryPointDecoration>(decoration); auto spvStage = mapStageToExecutionModel(entryPointDecor->getProfile().getStage()); auto name = entryPointDecor->getName()->getStringSlice(); - emitInstCustomOperandFunc(section, decoration, SpvOpEntryPoint, [&]() { - emitOperand(spvStage); - emitOperand(dstID); - emitOperand(name); - // `interface` part: reference all global variables that are used by this entrypoint. - // TODO: we may want to perform more accurate tracking. - for (auto globalInst : m_irModule->getModuleInst()->getChildren()) + List<IRInst*> params; + // `interface` part: reference all global variables that are used by this entrypoint. + // TODO: we may want to perform more accurate tracking. + for (auto globalInst : m_irModule->getModuleInst()->getChildren()) + { + switch (globalInst->getOp()) { - switch (globalInst->getOp()) - { - case kIROp_GlobalVar: - case kIROp_GlobalParam: - emitOperand(getIRInstSpvID(globalInst)); - break; - } + case kIROp_GlobalVar: + case kIROp_GlobalParam: + params.add(globalInst); + break; } - }); + } + emitOpEntryPoint( + section, + decoration, + spvStage, + dstID, + name, + params + ); } break; @@ -1969,29 +2105,25 @@ struct SPIRVEmitContext // in those positions in the Slang IR). // auto numThreads = cast<IRNumThreadsDecoration>(decoration); - emitInst(section, decoration, SpvOpExecutionMode, dstID, SpvExecutionModeLocalSize, - SpvWord(numThreads->getX()->getValue()), - SpvWord(numThreads->getY()->getValue()), - SpvWord(numThreads->getZ()->getValue())); + emitOpExecutionModeLocalSize( + section, + decoration, + dstID, + SpvLiteralInteger::from32(int32_t(numThreads->getX()->getValue())), + SpvLiteralInteger::from32(int32_t(numThreads->getY()->getValue())), + SpvLiteralInteger::from32(int32_t(numThreads->getZ()->getValue())) + ); } break; case kIROp_SPIRVBufferBlockDecoration: { - emitInst( + emitOpDecorate( getSection(SpvLogicalSectionID::Annotations), decoration, - SpvOpDecorate, - dstID, - SpvDecorationBlock); - emitInst( - getSection(SpvLogicalSectionID::Annotations), - nullptr, - SpvOpMemberDecorate, dstID, - 0, - SpvDecorationOffset, - 0); + SpvDecorationBlock + ); } break; // ... @@ -2035,20 +2167,18 @@ struct SPIRVEmitContext builder.setInsertBefore(type); auto ptrType = as<IRPtrTypeBase>(type); SLANG_ASSERT(ptrType && "`getBuiltinGlobalVar`: `type` must be ptr type."); - auto varInst = emitInst( + auto varInst = emitOpVariable( getSection(SpvLogicalSectionID::GlobalVariables), nullptr, - SpvOpVariable, type, - kResultID, - (SpvStorageClass)ptrType->getAddressSpace()); - emitInst( + static_cast<SpvStorageClass>(ptrType->getAddressSpace()) + ); + emitOpDecorateBuiltIn( getSection(SpvLogicalSectionID::Annotations), nullptr, - SpvOpDecorate, varInst, - SpvDecorationBuiltIn, - builtinVal); + builtinVal + ); m_builtinGlobalVars[builtinVal] = varInst; return varInst; } @@ -2078,7 +2208,7 @@ struct SPIRVEmitContext SpvInst* emitParam(SpvInstParent* parent, IRInst* inst) { - return emitInst(parent, inst, SpvOpFunctionParameter, inst->getFullType(), kResultID); + return emitOpFunctionParameter(parent, inst, inst->getFullType()); } SpvInst* emitVar(SpvInstParent* parent, IRInst* inst) @@ -2090,7 +2220,7 @@ struct SPIRVEmitContext { storageClass = (SpvStorageClass)ptrType->getAddressSpace(); } - return emitInst(parent, inst, SpvOpVariable, inst->getFullType(), kResultID, storageClass); + return emitOpVariable(parent, inst, inst->getFullType(), storageClass); } /// Cached `IRParam` indices in an `IRBlock`. For use in `getParamIndexInBlock`. @@ -2143,29 +2273,29 @@ struct SPIRVEmitContext void emitLoopHeaderBlock(IRLoop* loopInst, SpvInst* loopHeaderBlock) { - SpvWord loopControl = 0; + SpvLoopControlMask loopControl = SpvLoopControlMaskNone; if (auto loopControlDecoration = loopInst->findDecoration<IRLoopControlDecoration>()) { switch (loopControlDecoration->getMode()) { case IRLoopControl::kIRLoopControl_Unroll: - loopControl = 0x1; + loopControl = SpvLoopControlUnrollMask; break; case IRLoopControl::kIRLoopControl_Loop: - loopControl = 0x2; + loopControl = SpvLoopControlDontUnrollMask; break; default: break; } } - emitInst( + emitOpLoopMerge( loopHeaderBlock, nullptr, - SpvOpLoopMerge, getIRInstSpvID(loopInst->getBreakBlock()), getIRInstSpvID(loopInst->getContinueBlock()), - loopControl); - emitInst(loopHeaderBlock, nullptr, SpvOpBranch, loopInst->getTargetBlock()); + loopControl + ); + emitOpBranch(loopHeaderBlock, nullptr, loopInst->getTargetBlock()); } SpvInst* emitPhi(SpvInstParent* parent, IRParam* inst) @@ -2225,9 +2355,9 @@ struct SPIRVEmitContext }); } - SpvInst* emitCall(SpvInstParent* parent, IRInst* inst) + SpvInst* emitCall(SpvInstParent* parent, IRCall* inst) { - auto funcValue = inst->getOperand(0); + auto funcValue = inst->getCallee(); // Does this function declare any requirements. handleRequiredCapabilities(funcValue); @@ -2241,8 +2371,13 @@ struct SPIRVEmitContext } else { - return emitInst( - parent, inst, SpvOpFunctionCall, inst->getFullType(), kResultID, OperandsOf(inst)); + return emitOpFunctionCall( + parent, + inst, + inst->getFullType(), + funcValue, + inst->getArgsList() + ); } } @@ -2274,14 +2409,16 @@ struct SPIRVEmitContext // different storage-class qualifier. We need to pre-create these // storage-class-qualified result pointer types so they can be used // during inlining of the snippet. - if (auto oldPtrType = as<IRPtrTypeBase>(inst->getDataType())) { - for (auto storageClass : snippet->usedResultTypeStorageClasses) + IRBuilder builder(m_irModule); + builder.setInsertBefore(inst); + for (auto storageClass : snippet->usedPtrResultTypeStorageClasses) { - IRBuilder builder(m_irModule); - builder.setInsertBefore(inst); auto newPtrType = builder.getPtrType( - oldPtrType->getOp(), oldPtrType->getValueType(), storageClass); + kIROp_PtrType, + inst->getDataType(), + storageClass + ); context.qualifiedResultTypes[storageClass] = newPtrType; } } @@ -2309,14 +2446,11 @@ struct SPIRVEmitContext auto floatType = builder.getType(kIROp_FloatType); auto element1 = emitFloatConstant(constant.floatValues[0], floatType); auto element2 = emitFloatConstant(constant.floatValues[1], floatType); - result = emitInst( - getSection(SpvLogicalSectionID::ConstantsAndTypes), + result = emitOpConstantComposite( nullptr, - SpvOpConstantComposite, builder.getVectorType(floatType, builder.getIntValue(builder.getIntType(), 2)), - kResultID, - element1, - element2); + makeArray(element1, element2) + ); } break; case SpvSnippet::ASMType::Int: @@ -2327,14 +2461,11 @@ struct SPIRVEmitContext auto uintType = builder.getType(kIROp_UIntType); auto element1 = emitIntConstant((IRIntegerValue)constant.intValues[0], uintType); auto element2 = emitIntConstant((IRIntegerValue)constant.intValues[1], uintType); - result = emitInst( - getSection(SpvLogicalSectionID::ConstantsAndTypes), + result = emitOpConstantComposite( nullptr, - SpvOpConstantComposite, builder.getVectorType(uintType, builder.getIntValue(builder.getIntType(), 2)), - kResultID, - element1, - element2); + makeArray(element1, element2) + ); } break; } @@ -2356,6 +2487,9 @@ struct SPIRVEmitContext case SpvSnippet::ASMType::Int: irType = builder.getIntType(); break; + case SpvSnippet::ASMType::UInt: + irType = builder.getUIntType(); + break; case SpvSnippet::ASMType::Float2: irType = builder.getVectorType( builder.getType(kIROp_FloatType), builder.getIntValue(builder.getIntType(), 2)); @@ -2533,23 +2667,27 @@ struct SPIRVEmitContext baseStructType = as<IRStructType>(base->getDataType()); auto structPtrType = builder.getPtrType(baseStructType); - auto varInst = emitInst( - parent, nullptr, SpvOpVariable, structPtrType, kResultID, SpvStorageClassFunction); - emitInst(parent, nullptr, SpvOpStore, varInst, base); + auto varInst = emitOpVariable( + parent, + nullptr, + structPtrType, + SpvStorageClassFunction + ); + emitOpStore(parent, nullptr, varInst, base); baseId = getID(varInst); } - SLANG_ASSERT(baseStructType && "field_address require base to be a struct."); + SLANG_ASSERT(baseStructType && "field_address requires base to be a struct."); auto fieldId = emitIntConstant( getStructFieldId(baseStructType, as<IRStructKey>(fieldAddress->getField())), builder.getIntType()); - return emitInst( + SLANG_ASSERT(as<IRPtrTypeBase>(fieldAddress->getFullType())); + return emitOpAccessChain( parent, fieldAddress, - SpvOpAccessChain, fieldAddress->getFullType(), - kResultID, baseId, - fieldId); + makeArray(fieldId) + ); } SpvInst* emitFieldExtract(SpvInstParent* parent, IRFieldExtract* inst) @@ -2558,100 +2696,84 @@ struct SPIRVEmitContext builder.setInsertBefore(inst); IRStructType* baseStructType = as<IRStructType>(inst->getBase()->getDataType()); - SLANG_ASSERT(baseStructType && "field_extract require base to be a struct."); - auto fieldId = emitIntConstant( - getStructFieldId(baseStructType, as<IRStructKey>(inst->getField())), - builder.getIntType()); + SLANG_ASSERT(baseStructType && "field_extract requires base to be a struct."); + auto fieldId = static_cast<SpvWord>(getStructFieldId( + baseStructType, + as<IRStructKey>(inst->getField()))); - return emitInst( + return emitOpCompositeExtract( parent, inst, - SpvOpCompositeExtract, inst->getDataType(), - kResultID, inst->getBase(), - fieldId); + makeArray(SpvLiteralInteger::from32(fieldId)) + ); } SpvInst* emitGetElementPtr(SpvInstParent* parent, IRGetElementPtr* inst) { auto base = inst->getBase(); SpvWord baseId = 0; - IRArrayType* baseArrayType = nullptr; // Only used in debug build, but we don't want a warning/error for an unused initialized variable - SLANG_UNUSED(baseArrayType); if (auto ptrLikeType = as<IRPointerLikeType>(base->getDataType())) { - baseArrayType = as<IRArrayType>(ptrLikeType->getElementType()); baseId = getID(ensureInst(base)); } else if (auto ptrType = as<IRPtrTypeBase>(base->getDataType())) { - baseArrayType = as<IRArrayType>(ptrType->getValueType()); baseId = getID(ensureInst(base)); } else { SLANG_ASSERT(!"invalid IR: base of getElementPtr must be a pointer."); } - SLANG_ASSERT(baseArrayType && "getElementPtr require base to be an array."); - return emitInst( + SLANG_ASSERT(as<IRPtrTypeBase>(inst->getFullType())); + return emitOpAccessChain( parent, inst, - SpvOpAccessChain, inst->getFullType(), - kResultID, baseId, - inst->getIndex()); + makeArray(inst->getIndex()) + ); } SpvInst* emitGetElement(SpvInstParent* parent, IRGetElement* inst) { auto base = inst->getBase(); - SpvWord baseId = 0; - IRArrayType* baseArrayType = nullptr; - // Only used in debug build, but we don't want a warning/error for an unused initialized variable - SLANG_UNUSED(baseArrayType); - - if (auto ptrLikeType = as<IRPointerLikeType>(base->getDataType())) - { - baseArrayType = as<IRArrayType>(ptrLikeType->getElementType()); - baseId = getID(ensureInst(base)); - } - else if (auto ptrType = as<IRPtrTypeBase>(base->getDataType())) - { - baseArrayType = as<IRArrayType>(ptrType->getValueType()); - baseId = getID(ensureInst(base)); - } - else - { - SLANG_ASSERT(!"invalid IR: base of getElement must be a pointer."); - } - SLANG_ASSERT(baseArrayType && "getElement require base to be an array."); + const auto baseTy = base->getDataType(); + SLANG_ASSERT( + as<IRPointerLikeType>(baseTy) || + as<IRArrayType>(baseTy) || + as<IRVectorType>(baseTy) || + as<IRMatrixType>(baseTy)); IRBuilder builder(m_irModule); builder.setInsertBefore(inst); - auto ptr = emitInst( + auto ptr = emitOpAccessChain( parent, nullptr, - SpvOpAccessChain, builder.getPtrType(inst->getFullType()), - kResultID, - baseId, - inst->getIndex()); - return emitInst(parent, inst, SpvOpLoad, inst->getFullType(), kResultID, ptr); + inst->getBase(), + makeArray(inst->getIndex()) + ); + return emitOpLoad( + parent, + inst, + inst->getFullType(), + ptr + ); } SpvInst* emitLoad(SpvInstParent* parent, IRLoad* inst) { - return emitInst(parent, inst, SpvOpLoad, inst->getDataType(), kResultID, inst->getPtr()); + return emitOpLoad(parent, inst, inst->getDataType(), inst->getPtr()); } SpvInst* emitStore(SpvInstParent* parent, IRStore* inst) { - return emitInst(parent, inst, SpvOpStore, inst->getPtr(), inst->getVal()); + return emitOpStore(parent, inst, inst->getPtr(), inst->getVal()); } SpvInst* emitStructuredBufferLoad(SpvInstParent* parent, IRInst* inst) @@ -2682,14 +2804,14 @@ struct SPIRVEmitContext { if (inst->getElementCount() == 1) { - return emitInst( + const auto index = as<IRIntLit>(inst->getElementIndex(0))->getValue(); + return emitOpCompositeExtract( parent, inst, - SpvOpCompositeExtract, inst->getDataType(), - kResultID, inst->getBase(), - (SpvWord)as<IRIntLit>(inst->getElementIndex(0))->getValue()); + makeArray(SpvLiteralInteger::from32(int32_t(index))) + ); } else { @@ -2727,25 +2849,22 @@ struct SPIRVEmitContext const auto fromInfo = getIntTypeInfo(fromType); const auto toInfo = getIntTypeInfo(toType); - const auto convertWith = [&](auto op){ - return emitInst(parent, inst, op, toTypeV, kResultID, inst->getOperand(0)); - }; if(fromInfo == toInfo) - return convertWith(SpvOpCopyObject); + return emitOpCopyObject(parent, inst, toTypeV, inst->getOperand(0)); else if(fromInfo.width == toInfo.width) - return convertWith(SpvOpBitcast); + return emitOpBitcast(parent, inst, toTypeV, inst->getOperand(0)); else if(!fromInfo.isSigned && !toInfo.isSigned) // unsigned to unsigned, don't sign extend - return convertWith(SpvOpUConvert); + return emitOpUConvert(parent, inst, toTypeV, inst->getOperand(0)); else if(toInfo.isSigned) // unsigned to signed, sign extend - return convertWith(SpvOpSConvert); + return emitOpSConvert(parent, inst, toTypeV, inst->getOperand(0)); else if(fromInfo.isSigned) // signed to unsigned, sign extend - return convertWith(SpvOpSConvert); + return emitOpSConvert(parent, inst, toTypeV, inst->getOperand(0)); else if(fromInfo.isSigned && toInfo.isSigned) // signed to signed, sign extend - return convertWith(SpvOpSConvert); + return emitOpSConvert(parent, inst, toTypeV, inst->getOperand(0)); SLANG_UNREACHABLE(__func__); } @@ -2761,7 +2880,7 @@ struct SPIRVEmitContext SLANG_ASSERT(isFloatingType(toType)); SLANG_ASSERT(!isTypeEqual(fromType, toType)); - return emitInst(parent, inst, SpvOpFConvert, toTypeV, kResultID, inst->getOperand(0)); + return emitOpFConvert(parent, inst, toTypeV, inst->getOperand(0)); } SpvInst* emitIntToFloatCast(SpvInstParent* parent, IRCastIntToFloat* inst) @@ -2776,11 +2895,9 @@ struct SPIRVEmitContext const auto fromInfo = getIntTypeInfo(fromType); - const auto convertWith = [&](auto op){ - return emitInst(parent, inst, op, toTypeV, kResultID, inst->getOperand(0)); - }; - - return convertWith(fromInfo.isSigned ? SpvOpConvertSToF : SpvOpConvertUToF); + return fromInfo.isSigned + ? emitOpConvertSToF(parent, inst, toTypeV, inst->getOperand(0)) + : emitOpConvertUToF(parent, inst, toTypeV, inst->getOperand(0)); } SpvInst* emitFloatToIntCast(SpvInstParent* parent, IRCastFloatToInt* inst) @@ -2795,11 +2912,14 @@ struct SPIRVEmitContext const auto toInfo = getIntTypeInfo(toType); - const auto convertWith = [&](auto op){ - return emitInst(parent, inst, op, toTypeV, kResultID, inst->getOperand(0)); - }; + return toInfo.isSigned + ? emitOpConvertFToS(parent, inst, toTypeV, inst->getOperand(0)) + : emitOpConvertFToU(parent, inst, toTypeV, inst->getOperand(0)); + } - return convertWith(toInfo.isSigned ? SpvOpConvertFToS : SpvOpConvertFToU); + SpvInst* emitCompositeConstruct(SpvInstParent* parent, IRInst* inst) + { + return emitOpCompositeConstruct(parent, inst, inst->getDataType(), OperandsOf(inst)); } SpvInst* emitConstruct(SpvInstParent* parent, IRInst* inst) @@ -2809,20 +2929,16 @@ struct SPIRVEmitContext if (inst->getOperandCount() == 1) { if (inst->getDataType() == inst->getOperand(0)->getDataType()) - return emitInst( + return emitOpCopyObject( parent, inst, - SpvOpCopyObject, inst->getFullType(), - kResultID, inst->getOperand(0)); else - return emitInst( + return emitOpBitcast( parent, inst, - SpvOpBitcast, inst->getFullType(), - kResultID, inst->getOperand(0)); } else @@ -2833,33 +2949,24 @@ struct SPIRVEmitContext } else { - return emitInst( - parent, - inst, - SpvOpCompositeConstruct, - inst->getDataType(), - kResultID, - OperandsOf(inst)); + return emitCompositeConstruct(parent, inst); } } - SpvInst* emitSplat(SpvInstParent* parent, IRInst* scalar, IRIntegerValue numElems) + SpvInst* emitSplat(SpvInstParent* parent, IRInst* inst, IRInst* scalar, IRIntegerValue numElems) { const auto scalarTy = as<IRBasicType>(scalar->getDataType()); + SLANG_ASSERT(scalarTy); const auto spvVecTy = ensureVectorType( scalarTy->getBaseType(), numElems, nullptr); - return emitInstCustomOperandFunc( + return emitOpCompositeConstruct( parent, - nullptr, - SpvOpCompositeConstruct, - [&](){ - emitOperand(spvVecTy); - emitOperand(kResultID); - for(Int i = 0; i < numElems; ++i) - emitOperand(scalar); - }); + inst, + spvVecTy, + List<IRInst*>::makeRepeated(scalar, Index(numElems)) + ); } bool isSignedType(IRType* type) @@ -3019,13 +3126,13 @@ struct SPIRVEmitContext { const auto len = as<IRIntLit>(lVec->getElementCount()); SLANG_ASSERT(len); - return go(l, emitSplat(parent, r, len->getValue())); + return go(l, emitSplat(parent, nullptr, r, len->getValue())); } else if (!lVec && rVec) { const auto len = as<IRIntLit>(rVec->getElementCount()); SLANG_ASSERT(len); - return go(emitSplat(parent, l, len->getValue()), r); + return go(emitSplat(parent, nullptr, l, len->getValue()), r); } return go(l, r); } @@ -3038,11 +3145,11 @@ struct SPIRVEmitContext { if (m_capabilities.add(capability)) { - emitInst( + emitOpCapability( getSection(SpvLogicalSectionID::Capabilities), nullptr, - SpvOpCapability, - capability); + capability + ); } } @@ -3073,12 +3180,6 @@ struct SPIRVEmitContext } } - void diagnoseUnhandledInst(IRInst* inst) - { - m_sink->diagnose( - inst, Diagnostics::unimplemented, "unexpected IR opcode during code emit"); - } - SPIRVEmitContext(IRModule* module, TargetRequest* target, DiagnosticSink* sink) : SPIRVEmitSharedContext(module, target, sink) , m_irModule(module) @@ -3101,6 +3202,16 @@ SlangResult emitSPIRVFromIR( SPIRVEmitContext context(irModule, targetRequest, sink); legalizeIRForSPIRV(&context, irModule, irEntryPoints, codeGenContext); +#if 0 + DiagnosticSinkWriter writer(codeGenContext->getSink()); + dumpIR( + irModule, + {IRDumpOptions::Mode::Simplified, 0}, + "BEFORE SPIR-V EMIT", + codeGenContext->getSourceManager(), + &writer); +#endif + context.emitFrontMatter(); for (auto irEntryPoint : irEntryPoints) { @@ -3110,7 +3221,7 @@ SlangResult emitSPIRVFromIR( spirvOut.addRange( (uint8_t const*) context.m_words.getBuffer(), - context.m_words.getCount() * sizeof(context.m_words[0])); + context.m_words.getCount() * Index(sizeof(context.m_words[0]))); const auto validationResult = debugValidateSPIRV(spirvOut); // If validation isn't available, don't say it failed, it's just a debug |
