diff options
| author | Yong He <yonghe@outlook.com> | 2021-08-12 13:14:15 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2021-08-12 13:14:15 -0700 |
| commit | 6406523511037987d8b8ab881aea41389afd57eb (patch) | |
| tree | 79f24b6cba377340c2f4d3dcf9fed78fc586f3e0 /source/slang/slang-emit-spirv.cpp | |
| parent | 389d21d982da34815b65b10cae63088c397eecc8 (diff) | |
Further implementation of SPIRV direct emit. (#1920)
* Further implementation of SPIRV direct emit.
This change implements:
- Struct, Vector, Matrix and Unsized Array types.
- Basic arithmetic opcodes, vector construct, swizzle etc.
- getElementPtr, getElement, fieldAddress, extractField.
- SPIRV target intrinsics with SPIRV asm code in stdlib.
- RWStructuredBuffer and StructuredBuffer.
- Pointer storage class propagation.
- Control flow.
* Fix.
Diffstat (limited to 'source/slang/slang-emit-spirv.cpp')
| -rw-r--r-- | source/slang/slang-emit-spirv.cpp | 1261 |
1 files changed, 1225 insertions, 36 deletions
diff --git a/source/slang/slang-emit-spirv.cpp b/source/slang/slang-emit-spirv.cpp index fe039feb0..37fd673ed 100644 --- a/source/slang/slang-emit-spirv.cpp +++ b/source/slang/slang-emit-spirv.cpp @@ -2,11 +2,14 @@ #include "slang-emit.h" #include "slang-compiler.h" +#include "slang-emit-base.h" + #include "slang-ir.h" #include "slang-ir-insts.h" - +#include "slang-ir-layout.h" +#include "slang-ir-spirv-snippet.h" +#include "slang-ir-spirv-legalize.h" #include "spirv/unified1/spirv.h" - #include "../core/slang-memory-arena.h" namespace Slang @@ -36,16 +39,7 @@ namespace Slang // [2.3: Physical Layout of a SPIR-V Module and Instruction] // // > A SPIR-V module is a single linear stream of words. -// -// [2.2: Terms] -// -// > Word: 32 bits. -// -// Despite the importance to SPIR-V, the `spirv.h` header doesn't -// define a type for words, so we'll do it here. - /// A SPIR-V word. -typedef uint32_t SpvWord; // [2.3: Physical Layout of a SPIR-V Module and Instruction] // @@ -268,6 +262,14 @@ void SpvInstParent::dumpTo(List<SpvWord>& ioWords) } } +/// The context for inlining a SPV assembly snippet. +struct SpvSnippetEmitContext +{ + SpvInst* resultType; + Dictionary<SpvStorageClass, IRInst*> qualifiedResultTypes; + List<SpvWord> argumentIds; +}; + // Now that we've defined the intermediate data structures we will // use to represent SPIR-V code during emission, we will move on // to defining the main context type that will drive SPIR-V @@ -275,10 +277,14 @@ void SpvInstParent::dumpTo(List<SpvWord>& ioWords) /// Context used for translating a Slang IR module to SPIR-V struct SPIRVEmitContext + : public SourceEmitterBase + , public SPIRVEmitSharedContext { /// The Slang IR module being translated IRModule* m_irModule; + DiagnosticSink* m_sink; + // [2.2: Terms] // // > <id>: A numerical name; the name used to refer to an object, a type, @@ -385,12 +391,36 @@ struct SPIRVEmitContext /// Map a Slang IR instruction to the corresponding SPIR-V instruction Dictionary<IRInst*, SpvInst*> m_mapIRInstToSpvInst; + // Sometimes we need to reserve an ID for an `IRInst` without actually + // emitting it. We use `m_mapIRInstToSpvID` to hold all reserved SpvIDs. + // Use `getIRInstSpvID` to obtain an SpvID for an `IRInst` if the + // `IRInst` may not have been emitted. + Dictionary<IRInst*, SpvWord> m_mapIRInstToSpvID; + /// Register that `irInst` maps to `spvInst` void registerInst(IRInst* irInst, SpvInst* spvInst) { m_mapIRInstToSpvInst.Add(irInst, spvInst); } + /// Get or reserve a SpvID for an IR value. + SpvWord getIRInstSpvID(IRInst* inst) + { + // If we have already emitted an SpvInst for `inst`, return its ID. + SpvInst* spvInst = nullptr; + if (m_mapIRInstToSpvInst.TryGetValue(inst, spvInst)) + return getID(spvInst); + // Check if we have reserved an ID for `inst`. + SpvWord result = 0; + if (m_mapIRInstToSpvID.TryGetValue(inst, result)) + return result; + // Otherwise, reserve a new ID for inst, and register it in `m_mapIRInstToSpvID`. + result = m_nextID; + ++m_nextID; + m_mapIRInstToSpvID[inst] = result; + return result; + } + // When we are emitting an instruction that can produce // a result, we will allocate an <id> to it so that other // instructions can refer to it. @@ -409,6 +439,18 @@ struct SPIRVEmitContext return id; } + struct VectorTypeKey + { + BaseType baseType; + IRIntegerValue elementCount; + HashCode getHashCode() { return combineHash((int)baseType, (HashCode)elementCount); } + bool operator==(const VectorTypeKey& other) + { + return baseType == other.baseType && elementCount == other.elementCount; + } + }; + Dictionary<VectorTypeKey, SpvInst*> m_vectorTypes; + // We will build up `SpvInst`s in a stateful fashion, // mostly for convenience. We could in theory compute // the number of words each instruction needs, then allocate @@ -467,6 +509,8 @@ struct SPIRVEmitContext if(irInst) { registerInst(irInst, spvInst); + // If we have reserved an SpvID for `irInst`, make sure to use it. + m_mapIRInstToSpvID.TryGetValue(irInst, spvInst->id); } // Set up the scope @@ -561,9 +605,6 @@ struct SPIRVEmitContext /// Emit an operand to the current instruction, which references `src` by its <id> void emitOperand(IRInst* src) { - // We first ensure that the `src` instruction has been emitted, - // and then handle it as for any other <id> operand. - // SpvInst* spvSrc = ensureInst(src); emitOperand(getID(spvSrc)); } @@ -629,6 +670,25 @@ struct SPIRVEmitContext emitOperand(getID(m_currentInst)); } + void emitOperand(SpvDecoration decoration) { emitOperand((SpvWord)decoration); } + + void emitOperand(SpvBuiltIn builtin) { emitOperand((SpvWord)builtin); } + void emitOperand(SpvStorageClass val) { emitOperand((SpvWord)val); } + + Dictionary<IRIntegerValue, SpvInst*> m_spvIntConstants; + SpvInst* emitConstant(IRIntegerValue val, IRType* type) + { + SpvInst* result = nullptr; + if (m_spvIntConstants.TryGetValue(val, result)) + return result; + return emitInst( + getSection(SpvLogicalSectionID::Constants), + nullptr, + SpvOpConstant, + type, + kResultID, + (SpvWord)val); + } // As another convenience, there are often cases where // we will want to emit all of the operands of some // IR instruction as <id> operands of a SPIR-V @@ -742,6 +802,16 @@ struct SPIRVEmitContext return spvInst; } + template<typename OperandEmitFunc> + SpvInst* emitInstCustomOperandFunc(SpvInstParent* parent, IRInst* irInst, SpvOp opcode, const OperandEmitFunc& f) + { + InstConstructScope scopeInst(this, opcode, irInst); + SpvInst* spvInst = scopeInst; + f(); + parent->addInst(spvInst); + return spvInst; + } + // Now that we've gotten the core infrastructure out of the way, // let's start looking at emitting some instructions that make // up a SPIR-V module. @@ -826,14 +896,110 @@ struct SPIRVEmitContext CASE(kIROp_DoubleType, 64); #undef CASE - - // > OpTypeVector - // > OpTypeMatrix + case kIROp_PtrType: + case kIROp_RefType: + case kIROp_OutType: + case kIROp_InOutType: + { + SpvStorageClass storageClass = SpvStorageClassFunction; + auto ptrType = as<IRPtrTypeBase>(inst); + if (ptrType->hasAddressSpace()) + storageClass = (SpvStorageClass)ptrType->getAddressSpace(); + return emitInst( + getSection(SpvLogicalSectionID::Types), + inst, + SpvOpTypePointer, + kResultID, + storageClass, + inst->getOperand(0)); + } + case kIROp_StructType: + { + return emitInstCustomOperandFunc( + getSection(SpvLogicalSectionID::Types), inst, SpvOpTypeStruct, [&]() { + emitOperand(kResultID); + for (auto field : static_cast<IRStructType*>(inst)->getFields()) + { + emitOperand(field->getFieldType()); + // TODO: decorate offset + } + }); + } + case kIROp_VectorType: + { + auto vectorType = static_cast<IRVectorType*>(inst); + return ensureVectorType( + static_cast<IRBasicType*>(vectorType->getElementType())->getBaseType(), + static_cast<IRIntLit*>(vectorType->getElementCount())->getValue(), + vectorType); + } + case kIROp_MatrixType: + { + auto matrixType = static_cast<IRMatrixType*>(inst); + auto vectorSpvType = ensureVectorType( + static_cast<IRBasicType*>(matrixType->getElementType())->getBaseType(), + static_cast<IRIntLit*>(matrixType->getRowCount())->getValue(), + nullptr); + auto matrixSPVType = emitInst( + getSection(SpvLogicalSectionID::Types), + inst, + SpvOpTypeMatrix, + kResultID, + vectorSpvType, + (SpvWord)static_cast<IRIntLit*>(matrixType->getColumnCount())->getValue()); + // TODO: properly compute matrix stride. + auto columnCount = static_cast<IRIntLit*>(matrixType->getRowCount())->getValue(); + uint32_t stride = 0; + switch (columnCount) + { + case 1: + stride = 4; + break; + case 2: + stride = 8; + break; + case 3: + case 4: + stride = 16; + break; + default: + break; + } + emitInst( + getSection(SpvLogicalSectionID::Annotations), + nullptr, + SpvOpDecorate, + matrixSPVType, + SpvDecorationRowMajor, + SpvDecorationMatrixStride, + stride); + return matrixSPVType; + } + case kIROp_UnsizedArrayType: + { + auto elementType = static_cast<IRUnsizedArrayType*>(inst)->getElementType(); + auto runtimeArrayType = emitInst( + getSection(SpvLogicalSectionID::Types), + nullptr, + SpvOpTypeRuntimeArray, + kResultID, + elementType); + // TODO: properly decorate stride. + IRSizeAndAlignment sizeAndAlignment; + getNaturalSizeAndAlignment(this->m_targetRequest, elementType, &sizeAndAlignment); + emitInst( + getSection(SpvLogicalSectionID::Annotations), + nullptr, + SpvOpDecorate, + runtimeArrayType, + SpvDecorationArrayStride, + (SpvWord)sizeAndAlignment.getStride()); + return runtimeArrayType; + } // > OpTypeImage // > OpTypeSampler // > OpTypeArray // > OpTypeRuntimeArray - // > OpTypeStruct // > OpTypeOpaque // > OpTypePointer @@ -858,6 +1024,15 @@ struct SPIRVEmitContext // return emitFunc(as<IRFunc>(inst)); + case kIROp_BoolLit: + case kIROp_IntLit: + case kIROp_FloatLit: + return emitLit(inst); + + case kIROp_GlobalParam: + return emitGlobalParam(as<IRGlobalParam>(inst)); + case kIROp_GlobalVar: + return emitGlobalVar(as<IRGlobalVar>(inst)); // ... default: @@ -866,6 +1041,162 @@ struct SPIRVEmitContext } } + // Ensures an SpvInst for the specified vector type is emitted. + // `inst` represents an optional `IRVectorType` inst representing the vector type, if + // it is nullptr, this function will create one. + SpvInst* ensureVectorType(BaseType baseType, IRIntegerValue elementCount, IRVectorType* inst) + { + VectorTypeKey key = {baseType, elementCount}; + SpvInst* result = nullptr; + if (m_vectorTypes.TryGetValue(key, result)) + return result; + if (!inst) + { + IRBuilder builder; + builder.sharedBuilder = &m_sharedIRBuilder; + builder.setInsertInto(m_irModule->getModuleInst()); + inst = builder.getVectorType( + builder.getBasicType(baseType), + builder.getIntValue(builder.getIntType(), elementCount)); + } + result = emitInst( + getSection(SpvLogicalSectionID::Types), + inst, + SpvOpTypeVector, + kResultID, + inst->getElementType(), + (SpvWord)elementCount); + m_vectorTypes[key] = result; + return result; + } + + void emitVarLayout(SpvInst* varInst, IRVarLayout* layout) + { + for (auto rr : layout->getOffsetAttrs()) + { + UInt index = rr->getOffset(); + UInt space = rr->getSpace(); + switch (rr->getResourceKind()) + { + case LayoutResourceKind::Uniform: + break; + + case LayoutResourceKind::VaryingInput: + emitInst( + getSection(SpvLogicalSectionID::Annotations), + nullptr, + SpvOpDecorate, + varInst, + SpvDecorationLocation, + (SpvWord)index); + emitInst( + getSection(SpvLogicalSectionID::Annotations), + nullptr, + SpvOpDecorate, + varInst, + SpvDecorationIndex, + (SpvWord)space); + break; + case LayoutResourceKind::VaryingOutput: + emitInst( + getSection(SpvLogicalSectionID::Annotations), + nullptr, + SpvOpDecorate, + varInst, + SpvDecorationLocation, + (SpvWord)index); + if (space) + { + emitInst( + getSection(SpvLogicalSectionID::Annotations), + nullptr, + SpvOpDecorate, + varInst, + SpvDecorationIndex, + (SpvWord)space); + } + break; + + case LayoutResourceKind::SpecializationConstant: + emitInst( + getSection(SpvLogicalSectionID::Annotations), + nullptr, + SpvOpDecorate, + varInst, + SpvDecorationSpecId, + (SpvWord)index); + break; + + case LayoutResourceKind::ConstantBuffer: + case LayoutResourceKind::ShaderResource: + case LayoutResourceKind::UnorderedAccess: + case LayoutResourceKind::SamplerState: + case LayoutResourceKind::DescriptorTableSlot: + emitInst( + getSection(SpvLogicalSectionID::Annotations), + nullptr, + SpvOpDecorate, + varInst, + SpvDecorationBinding, + (SpvWord)index); + if (space) + { + emitInst( + getSection(SpvLogicalSectionID::Annotations), + nullptr, + SpvOpDecorate, + varInst, + SpvDecorationDescriptorSet, + (SpvWord)space); + } + break; + default: + break; + } + } + } + /// Emit a global parameter definition. + SpvInst* emitGlobalParam(IRGlobalParam* param) + { + auto layout = getVarLayout(param); + auto storageClass = SpvStorageClassUniform; + if (auto ptrType = as<IRPtrTypeBase>(param->getDataType())) + { + if (ptrType->hasAddressSpace()) + storageClass = (SpvStorageClass)ptrType->getAddressSpace(); + } + auto varInst = emitInst( + getSection(SpvLogicalSectionID::GlobalVariables), + param, + SpvOpVariable, + param->getDataType(), + kResultID, + storageClass); + emitVarLayout(varInst, layout); + return varInst; + } + + /// Emit a global variable definition. + SpvInst* emitGlobalVar(IRGlobalVar* globalVar) + { + auto layout = getVarLayout(globalVar); + auto storageClass = SpvStorageClassUniform; + if (auto ptrType = as<IRPtrTypeBase>(globalVar->getDataType())) + { + if (ptrType->hasAddressSpace()) + storageClass = (SpvStorageClass)ptrType->getAddressSpace(); + } + auto varInst = emitInst( + getSection(SpvLogicalSectionID::GlobalVariables), + globalVar, + SpvOpVariable, + globalVar->getDataType(), + kResultID, + storageClass); + emitVarLayout(varInst, layout); + return varInst; + } + /// Emit the given `irFunc` to SPIR-V SpvInst* emitFunc(IRFunc* irFunc) { @@ -951,9 +1282,7 @@ struct SPIRVEmitContext // for( auto irParam : irFunc->getParams() ) { - emitInst(spvFunc, irParam, SpvOpFunctionParameter, - irParam->getFullType(), - kResultID); + emitParam(spvFunc, irParam); } // [3.32.17. Control-Flow Instructions] @@ -992,11 +1321,13 @@ struct SPIRVEmitContext // [3.32.17. Control-Flow Instructions] // // > OpPhi - // - // TODO: We eventually need to emit `OpPhi` instructions corresponding - // to the parameters of any non-entry block, with operands representing - // the values passed along incoming edges from the predecessor blocks. - + if (irBlock != irFunc->getFirstBlock()) + { + for (auto irParam : irBlock->getParams()) + { + emitPhi(spvBlock, irParam); + } + } for( auto irInst : irBlock->getOrdinaryInsts() ) { // Any instructions local to the block will be emitted as children @@ -1036,16 +1367,243 @@ struct SPIRVEmitContext /// Emit an instruction that is local to the body of the given `parent`. SpvInst* emitLocalInst(SpvInstParent* parent, IRInst* inst) { + auto getBlockID = [=](IRBlock* block) + { + SpvInst* spvInst = nullptr; + m_mapIRInstToSpvInst.TryGetValue(block, spvInst); + SLANG_ASSERT(spvInst); + return getID(spvInst); + }; switch( inst->getOp() ) { default: SLANG_UNIMPLEMENTED_X("unhandled instruction opcode"); break; + case kIROp_Specialize: + return nullptr; + case kIROp_Var: + return emitVar(parent, inst); + case kIROp_Call: + return emitCall(parent, inst); + case kIROp_FieldAddress: + return emitFieldAddress(parent, as<IRFieldAddress>(inst)); + case kIROp_FieldExtract: + return emitFieldExtract(parent, as<IRFieldExtract>(inst)); + case kIROp_getElementPtr: + return emitGetElementPtr(parent, as<IRGetElementPtr>(inst)); + case kIROp_getElement: + return emitGetElement(parent, as<IRGetElement>(inst)); + case kIROp_Load: + return emitLoad(parent, as<IRLoad>(inst)); + case kIROp_Store: + return emitStore(parent, as<IRStore>(inst)); + case kIROp_swizzle: + return emitSwizzle(parent, as<IRSwizzle>(inst)); + case kIROp_Construct: + return emitConstruct(parent, inst); + case kIROp_Add: + case kIROp_Sub: + case kIROp_Mul: + case kIROp_Div: + case kIROp_IRem: + case kIROp_FRem: + case kIROp_Neg: + case kIROp_Not: + case kIROp_And: + case kIROp_Or: + case kIROp_BitNot: + case kIROp_BitAnd: + case kIROp_BitOr: + case kIROp_BitXor: + case kIROp_Less: + case kIROp_Leq: + case kIROp_Eql: + case kIROp_Neq: + case kIROp_Greater: + case kIROp_Geq: + case kIROp_Rsh: + case kIROp_Lsh: + return emitArithmetic(parent, inst); + case kIROp_ReturnVal: + return emitInst( + parent, inst, SpvOpReturnValue, as<IRReturnVal>(inst)->getVal()); + case kIROp_ReturnVoid: + return emitInst(parent, inst, SpvOpReturn); + case kIROp_discard: + return emitInst(parent, inst, SpvOpKill); + case kIROp_unconditionalBranch: + return emitInst( + parent, + inst, + SpvOpBranch, + getBlockID(as<IRUnconditionalBranch>(inst)->getTargetBlock())); + case kIROp_loop: + { + auto loopInst = as<IRLoop>(inst); + + SpvWord loopControl = 0; + if (auto loopControlDecoration = + loopInst->findDecoration<IRLoopControlDecoration>()) + { + switch (loopControlDecoration->getMode()) + { + case IRLoopControl::kIRLoopControl_Unroll: + loopControl = 0x1; + break; + case IRLoopControl::kIRLoopControl_Loop: + loopControl = 0x2; + break; + default: + break; + } + } + emitInst( + parent, + nullptr, + SpvOpLoopMerge, + getBlockID(loopInst->getBreakBlock()), + getBlockID(loopInst->getContinueBlock()), + loopControl); + + return emitInst(parent, inst, SpvOpBranch, loopInst->getTargetBlock()); + } + case kIROp_ifElse: + { + auto ifelseInst = as<IRIfElse>(inst); + auto afterBlockID = getBlockID(ifelseInst->getAfterBlock()); + emitInst( + parent, + nullptr, + SpvOpSelectionMerge, + afterBlockID); + auto falseLabel = ifelseInst->getFalseBlock(); + return emitInst( + parent, + inst, + SpvOpBranchConditional, + ifelseInst->getCondition(), + ifelseInst->getTrueBlock(), + falseLabel ? getID(ensureInst(falseLabel)) : afterBlockID); + } + case kIROp_Switch: + { + auto switchInst = as<IRSwitch>(inst); + auto mergeBlockID = getBlockID(switchInst->getBreakLabel()); + emitInst( + parent, + nullptr, + SpvOpSelectionMerge, mergeBlockID); + return emitInstCustomOperandFunc(parent, inst, SpvOpSwitch, [&]() { + emitOperand(switchInst->getCondition()); + auto defaultLabel = switchInst->getDefaultLabel(); + emitOperand(defaultLabel ? getID(ensureInst(defaultLabel)) : mergeBlockID); + for (UInt c = 0; c < switchInst->getCaseCount(); c++) + { + auto value = switchInst->getCaseValue(c); + auto intLit = as<IRIntLit>(value); + SLANG_ASSERT(intLit); + emitOperand((SpvWord)intLit->getValue()); + auto caseLabel = switchInst->getCaseLabel(c); + emitOperand(caseLabel ? getID(ensureInst(caseLabel)) : mergeBlockID); + } + }); + } + case kIROp_Unreachable: + return emitInst(parent, inst, SpvOpUnreachable); + case kIROp_conditionalBranch: + SLANG_UNEXPECTED("Unstructured branching is not supported by SPIRV."); + } + } - // [3.32.17. Control-Flow Instructions] - // - // > OpReturn - case kIROp_ReturnVoid: return emitInst(parent, inst, SpvOpReturn); + SpvInst* emitLit(IRInst* inst) + { + switch (inst->getOp()) + { + case kIROp_IntLit: + { + auto value = as<IRIntLit>(inst)->getValue(); + switch (as<IRBasicType>(inst->getDataType())->getBaseType()) + { + case BaseType::Int64: + case BaseType::UInt64: + return emitInst( + getSection(SpvLogicalSectionID::Constants), + inst, + SpvOpConstant, + inst->getDataType(), + kResultID, + (SpvWord)(value & 0xFFFFFFFF), + (SpvWord)((value >> 32) & 0xFFFFFFFF)); + default: + return emitInst( + getSection(SpvLogicalSectionID::Constants), + inst, + SpvOpConstant, + inst->getDataType(), + kResultID, + (SpvWord)value); + } + } + case kIROp_FloatLit: + { + auto value = as<IRConstant>(inst)->value.floatVal; + switch (as<IRBasicType>(inst->getDataType())->getBaseType()) + { + case BaseType::Half: + return emitInst( + getSection(SpvLogicalSectionID::Constants), + inst, + SpvOpConstant, + inst->getDataType(), + kResultID, + (SpvWord)(FloatToHalf((float)value))); + case BaseType::Float: + return emitInst( + getSection(SpvLogicalSectionID::Constants), + inst, + SpvOpConstant, + inst->getDataType(), + kResultID, + (SpvWord)(FloatAsInt((float)value))); + case BaseType::Double: + { + auto ival = DoubleAsInt64(value); + return emitInst( + getSection(SpvLogicalSectionID::Constants), + inst, + SpvOpConstant, + inst->getDataType(), + kResultID, + (SpvWord)(ival&0xFFFFFFFF), + (SpvWord)(ival>>32)); + } + default: + return nullptr; + } + } + case kIROp_BoolLit: + { + if (as<IRBoolLit>(inst)->getValue()) + { + return emitInst( + getSection(SpvLogicalSectionID::Constants), + inst, + SpvOpConstantTrue, + inst->getDataType(), + kResultID); + } + else + { + return emitInst( + getSection(SpvLogicalSectionID::Constants), + inst, + SpvOpConstantFalse, + inst->getDataType(), + kResultID); + } + } + default: + return nullptr; } } @@ -1184,24 +1742,655 @@ struct SPIRVEmitContext } } - SPIRVEmitContext(IRModule* module) : - m_irModule(module), - m_memoryArena(2048) + SpvInst* emitBuiltinSystemVal(SpvInstParent* parent, IRInst* inst, SpvBuiltIn builtinVal) + { + IRBuilder builder; + builder.sharedBuilder = &m_sharedIRBuilder; + builder.setInsertBefore(inst); + + auto ptrIRType = builder.getPtrType(inst->getDataType()); + auto varInst = emitInst(parent, inst, SpvOpVariable, ptrIRType, kResultID); + emitInst( + getSection(SpvLogicalSectionID::Annotations), + nullptr, + SpvOpDecorate, + varInst, + SpvDecorationBuiltIn, + builtinVal); + return varInst; + } + + SpvInst* emitParam(SpvInstParent* parent, IRInst* inst) + { + if (auto layout = getVarLayout(inst)) + { + if (auto systemValueAttr = layout->findAttr<IRSystemValueSemanticAttr>()) + { + String semanticName = systemValueAttr->getName(); + semanticName = semanticName.toLower(); + if (semanticName == "sv_dispatchthreadid") + { + return emitBuiltinSystemVal(parent, inst, SpvBuiltInGlobalInvocationId); + } + } + } + return emitInst(parent, inst, SpvOpFunctionParameter, inst->getFullType(), kResultID); + } + + SpvInst* emitVar(SpvInstParent* parent, IRInst* inst) + { + SpvWord storageClass = SpvStorageClassFunction; + auto rate = inst->getFullType()->getRate(); + if (rate) + { + switch (rate->getOp()) + { + case kIROp_GroupSharedRate: + storageClass = SpvStorageClassWorkgroup; + break; + default: + break; + } + } + return emitInst(parent, inst, SpvOpVariable, inst->getFullType(), kResultID, storageClass); + } + + /// Cached `IRParam` indices in an `IRBlock`. For use in `getParamIndexInBlock`. + struct BlockParamIndexInfo : public RefObject + { + Dictionary<IRParam*, int> mapParamToIndex; + }; + Dictionary<IRBlock*, RefPtr<BlockParamIndexInfo>> m_mapIRBlockToParamIndexInfo; + + /// Returns the index of an `IRParam` inside a `IRBlock`. + /// The results are cached in `m_mapIRBlockToParamIndexInfo` to avoid linear search. + int getParamIndexInBlock(IRBlock* block, IRParam* paramInst) + { + RefPtr<BlockParamIndexInfo> info; + int result = -1; + if (m_mapIRBlockToParamIndexInfo.TryGetValue(block, info)) + { + info->mapParamToIndex.TryGetValue(paramInst, result); + SLANG_ASSERT(result != -1); + return result; + } + info = new BlockParamIndexInfo(); + int paramIndex = 0; + for (auto param : block->getParams()) + { + info->mapParamToIndex[param] = paramIndex; + if (param == paramInst) + result = paramIndex; + paramIndex++; + } + m_mapIRBlockToParamIndexInfo[block] = info; + SLANG_ASSERT(result != -1); + return result; + } + + SpvInst* emitPhi(SpvInstParent* parent, IRParam* inst) + { + // An `IRParam` in an ordinary `IRBlock` represents a phi value. + // We can translate them directly to SPIRV's `Phi` instruction. + // In order to do that, we need to figure out the source values + // of this `IRParam`, which can be done by looking at the users + // of current `IRBlock`. + + // First, we find the index of this param. + IRBlock* block = as<IRBlock>(inst->getParent()); + SLANG_ASSERT(block); + int paramIndex = getParamIndexInBlock(block, inst); + + // Emit a Phi instruction. + return emitInstCustomOperandFunc(parent, inst, SpvOpPhi, [&]() { + emitOperand(inst->getFullType()); + emitOperand(kResultID); + // Find phi arguments from incoming branch instructions that target `block`. + for (auto use = block->firstUse; use; use = use->nextUse) + { + auto branchInst = use->getUser(); + UInt argStartIndex = 0; + switch (branchInst->getOp()) + { + case kIROp_unconditionalBranch: + argStartIndex = 1; + break; + case kIROp_loop: + argStartIndex = 3; + break; + default: + // A phi argument can only come from an unconditional branch inst. + // Other uses are not relavent so we should skip. + continue; + } + SLANG_ASSERT(argStartIndex + paramIndex < branchInst->getOperandCount()); + auto valueInst = branchInst->getOperand(argStartIndex + paramIndex); + emitOperand(valueInst); + auto sourceBlock = as<IRBlock>(branchInst->getParent()); + SLANG_ASSERT(sourceBlock); + emitOperand(getIRInstSpvID(sourceBlock)); + } + }); + } + + SpvInst* emitCall(SpvInstParent* parent, IRInst* inst) + { + auto funcValue = inst->getOperand(0); + + // Does this function declare any requirements. + handleRequiredCapabilities(funcValue); + + // We want to detect any call to an intrinsic operation, and inline + // the SPIRV snippet directly at the call site. + if (auto targetIntrinsic = Slang::findBestTargetIntrinsicDecoration( + funcValue, m_targetRequest->getTargetCaps())) + { + return emitIntrinsicCallExpr(parent, static_cast<IRCall*>(inst), targetIntrinsic); + } + else + { + return emitInst( + parent, inst, SpvOpFunctionCall, inst->getFullType(), kResultID, OperandsOf(inst)); + } + } + + SpvInst* emitIntrinsicCallExpr( + SpvInstParent* parent, + IRCall* inst, + IRTargetIntrinsicDecoration* intrinsic) + { + SpvSnippet* snippet = getParsedSpvSnippet(intrinsic); + SpvSnippetEmitContext context; + context.resultType = ensureInst(inst->getFullType()); + for (SlangUInt i = 0; i < inst->getArgCount(); i++) + { + auto argInst = ensureInst(inst->getArg(i)); + if (argInst) + { + context.argumentIds.add(getID(argInst)); + } + else + { + context.argumentIds.add(0xFFFFFFFF); + } + } + // A SPIRV snippet may refer to the result type of this inst with a + // different storage-class qualifier. We need to pre-create these + // storage-class-qualified result pointer types so they can be used + // during inlining of the snippet. + if (auto oldPtrType = as<IRPtrTypeBase>(inst->getDataType())) + { + for (auto storageClass : snippet->usedResultTypeStorageClasses) + { + IRBuilder builder; + builder.sharedBuilder = &m_sharedIRBuilder; + builder.setInsertBefore(inst); + auto newPtrType = builder.getPtrType( + oldPtrType->getOp(), oldPtrType->getValueType(), storageClass); + context.qualifiedResultTypes[storageClass] = newPtrType; + } + } + return emitSpvSnippet(parent, inst, context, snippet); + } + + SpvInst* emitSpvSnippet( + SpvInstParent* parent, + IRCall* inst, + const SpvSnippetEmitContext& context, + SpvSnippet* snippet) + { + ShortList<SpvInst*> emittedInsts; + for (Index i = 0; i < snippet->instructions.getCount(); i++) + { + auto& spvSnippetInst = snippet->instructions[i]; + InstConstructScope scopeInst(this, (SpvOp)spvSnippetInst.opCode, nullptr); + SpvInst* spvInst = scopeInst; + for (auto operand : spvSnippetInst.operands) + { + switch (operand.type) + { + case SpvSnippet::ASMOperandType::SpvWord: + emitOperand((SpvWord)operand.content); + break; + case SpvSnippet::ASMOperandType::ObjectReference: + SLANG_ASSERT( + operand.content >= 0 && operand.content < context.argumentIds.getCount()); + emitOperand(context.argumentIds[operand.content]); + break; + case SpvSnippet::ASMOperandType::ResultId: + emitOperand(kResultID); + break; + case SpvSnippet::ASMOperandType::ResultTypeId: + if (operand.content != -1) + { + emitOperand(context.qualifiedResultTypes[(SpvStorageClass)operand.content] + .GetValue()); + } + else + { + emitOperand(context.resultType); + } + break; + case SpvSnippet::ASMOperandType::InstReference: + SLANG_ASSERT(operand.content >= 0 && operand.content < emittedInsts.getCount()); + emitOperand(getID(emittedInsts[operand.content])); + break; + } + } + parent->addInst(spvInst); + emittedInsts.add(spvInst); + } + auto resultInst = emittedInsts.getLast(); + registerInst(inst, resultInst); + return resultInst; + } + + struct StructTypeInfo : public RefObject + { + Dictionary<IRStructKey*, Index> structFieldIndices; + }; + + Dictionary<IRStructType*, RefPtr<StructTypeInfo>> m_structTypeInfos; + + RefPtr<StructTypeInfo> createStructTypeInfo(IRStructType* structType) + { + RefPtr<StructTypeInfo> typeInfo = new StructTypeInfo(); + Index index = 0; + for (auto field : structType->getFields()) + { + typeInfo->structFieldIndices[field->getKey()] = index; + index++; + } + return typeInfo; + } + Index getStructFieldId(IRStructType* structType, IRStructKey* structFieldKey) + { + RefPtr<StructTypeInfo> info; + if (!m_structTypeInfos.TryGetValue(structType, info)) + { + info = createStructTypeInfo(structType); + m_structTypeInfos[structType] = info; + } + Index fieldIndex = -1; + info->structFieldIndices.TryGetValue(structFieldKey, fieldIndex); + SLANG_ASSERT(fieldIndex != -1); + return fieldIndex; + } + + SpvInst* emitFieldAddress(SpvInstParent* parent, IRFieldAddress* fieldAddress) + { + IRBuilder builder; + builder.sharedBuilder = &m_sharedIRBuilder; + builder.setInsertBefore(fieldAddress); + + auto base = fieldAddress->getBase(); + SpvWord baseId = 0; + IRStructType* baseStructType = nullptr; + + if (auto ptrLikeType = as<IRPointerLikeType>(base->getDataType())) + { + baseStructType = as<IRStructType>(ptrLikeType->getElementType()); + baseId = getID(ensureInst(base)); + } + else if (auto ptrType = as<IRPtrTypeBase>(base->getDataType())) + { + baseStructType = as<IRStructType>(ptrType->getValueType()); + baseId = getID(ensureInst(base)); + } + else + { + baseStructType = as<IRStructType>(base->getDataType()); + + auto structPtrType = builder.getPtrType(baseStructType); + auto varInst = emitInst( + parent, nullptr, SpvOpVariable, structPtrType, kResultID, SpvStorageClassFunction); + emitInst(parent, nullptr, SpvOpStore, varInst, base); + baseId = getID(varInst); + } + SLANG_ASSERT(baseStructType && "field_address require base to be a struct."); + auto fieldId = emitConstant( + getStructFieldId(baseStructType, as<IRStructKey>(fieldAddress->getField())), + builder.getIntType()); + return emitInst( + parent, + fieldAddress, + SpvOpAccessChain, + fieldAddress->getFullType(), + kResultID, + baseId, + fieldId); + } + + SpvInst* emitFieldExtract(SpvInstParent* parent, IRFieldExtract* inst) + { + IRBuilder builder; + builder.sharedBuilder = &m_sharedIRBuilder; + builder.setInsertBefore(inst); + + IRStructType* baseStructType = as<IRStructType>(inst->getBase()->getDataType()); + SLANG_ASSERT(baseStructType && "field_extract require base to be a struct."); + auto fieldId = emitConstant( + getStructFieldId(baseStructType, as<IRStructKey>(inst->getField())), + builder.getIntType()); + + return emitInst( + parent, + inst, + SpvOpCompositeExtract, + inst->getDataType(), + kResultID, + inst->getBase(), + fieldId); + } + + SpvInst* emitGetElementPtr(SpvInstParent* parent, IRGetElementPtr* inst) + { + auto base = inst->getBase(); + SpvWord baseId = 0; + IRArrayType* baseArrayType = nullptr; + + if (auto ptrLikeType = as<IRPointerLikeType>(base->getDataType())) + { + baseArrayType = as<IRArrayType>(ptrLikeType->getElementType()); + baseId = getID(ensureInst(base)); + } + else if (auto ptrType = as<IRPtrTypeBase>(base->getDataType())) + { + baseArrayType = as<IRArrayType>(ptrType->getValueType()); + baseId = getID(ensureInst(base)); + } + else + { + SLANG_ASSERT(!"invalid IR: base of getElementPtr must be a pointer."); + } + SLANG_ASSERT(baseArrayType && "getElementPtr require base to be an array."); + return emitInst( + parent, + inst, + SpvOpAccessChain, + inst->getFullType(), + kResultID, + baseId, + inst->getIndex()); + } + + SpvInst* emitGetElement(SpvInstParent* parent, IRGetElement* inst) + { + auto base = inst->getBase(); + SpvWord baseId = 0; + IRArrayType* baseArrayType = nullptr; + + if (auto ptrLikeType = as<IRPointerLikeType>(base->getDataType())) + { + baseArrayType = as<IRArrayType>(ptrLikeType->getElementType()); + baseId = getID(ensureInst(base)); + } + else if (auto ptrType = as<IRPtrTypeBase>(base->getDataType())) + { + baseArrayType = as<IRArrayType>(ptrType->getValueType()); + baseId = getID(ensureInst(base)); + } + else + { + SLANG_ASSERT(!"invalid IR: base of getElement must be a pointer."); + } + SLANG_ASSERT(baseArrayType && "getElement require base to be an array."); + + IRBuilder builder; + builder.sharedBuilder = &m_sharedIRBuilder; + builder.setInsertBefore(inst); + + auto ptr = emitInst( + parent, + nullptr, + SpvOpAccessChain, + builder.getPtrType(inst->getFullType()), + kResultID, + baseId, + inst->getIndex()); + return emitInst(parent, inst, SpvOpLoad, inst->getFullType(), kResultID, ptr); + } + + SpvInst* emitLoad(SpvInstParent* parent, IRLoad* inst) + { + return emitInst(parent, inst, SpvOpLoad, inst->getDataType(), kResultID, inst->getPtr()); + } + + SpvInst* emitStore(SpvInstParent* parent, IRStore* inst) + { + return emitInst(parent, inst, SpvOpStore, inst->getPtr(), inst->getVal()); + } + + SpvInst* emitSwizzle(SpvInstParent* parent, IRSwizzle* inst) + { + return emitInstCustomOperandFunc(parent, inst, SpvOpVectorShuffle, [&]() { + emitOperand(inst->getDataType()); + emitOperand(kResultID); + emitOperand(inst->getBase()); + emitOperand(inst->getBase()); + for (UInt i = 0; i < inst->getElementCount(); i++) + { + auto index = as<IRIntLit>(inst->getElementIndex(i)); + emitOperand((SpvWord)index->getValue()); + } + }); + } + + SpvInst* emitConstruct(SpvInstParent* parent, IRInst* inst) + { + if (as<IRBasicType>(inst->getDataType())) + { + if (inst->getOperandCount() == 1) + { + if (inst->getDataType() == inst->getOperand(0)->getDataType()) + return emitInst(parent, inst, SpvOpCopyObject, kResultID, inst->getOperand(0)); + else + return emitInst(parent, inst, SpvOpBitcast, inst->getDataType(), kResultID, inst->getOperand(0)); + } + else + { + SLANG_ASSERT(!"spirv emit: unsupported Construct inst."); + return nullptr; + } + } + else + { + return emitInst( + parent, + inst, + SpvOpCompositeConstruct, + inst->getDataType(), + kResultID, + OperandsOf(inst)); + } + } + + bool isSignedType(IRBasicType* basicType) + { + switch (basicType->getBaseType()) + { + case BaseType::Float: + case BaseType::Double: + return true; + case BaseType::Int: + case BaseType::Int16: + case BaseType::Int64: + case BaseType::Int8: + return true; + default: + return false; + } + } + + SpvInst* emitArithmetic(SpvInstParent* parent, IRInst* inst) + { + IRType* elementType = inst->getDataType(); + if (auto vectorType = as<IRVectorType>(inst->getDataType())) + { + elementType = vectorType->getElementType(); + } + else if (auto matrixType = as<IRMatrixType>(inst->getDataType())) + { + //TODO: implement. + SLANG_ASSERT(!"unimplemented: matrix arithemetic"); + } + IRBasicType* basicType = as<IRBasicType>(elementType); + bool isFloatingPoint = false; + bool isBool = false; + switch (basicType->getBaseType()) + { + case BaseType::Float: + case BaseType::Double: + isFloatingPoint = true; + break; + case BaseType::Bool: + isBool = true; + default: + break; + } + SpvOp opCode = SpvOpUndef; + bool isSigned = isSignedType(basicType); + switch (inst->getOp()) + { + case kIROp_Add: + opCode = isFloatingPoint ? SpvOpFAdd : SpvOpIAdd; + break; + case kIROp_Sub: + opCode = isFloatingPoint ? SpvOpFSub : SpvOpISub; + break; + case kIROp_Mul: + opCode = isFloatingPoint ? SpvOpFMul : SpvOpIMul; + break; + case kIROp_Div: + opCode = isFloatingPoint ? SpvOpFDiv : isSigned ? SpvOpSDiv : SpvOpUDiv; + break; + case kIROp_IRem: + opCode = isSigned ? SpvOpSRem : SpvOpUMod; + break; + case kIROp_FRem: + opCode = SpvOpFRem; + break; + case kIROp_Less: + opCode = isFloatingPoint ? SpvOpFOrdLessThan + : isSigned ? SpvOpSLessThan : SpvOpULessThan; + break; + case kIROp_Leq: + opCode = isFloatingPoint ? SpvOpFOrdLessThanEqual + : isSigned ? SpvOpSLessThanEqual : SpvOpULessThanEqual; + break; + case kIROp_Eql: + opCode = isFloatingPoint ? SpvOpFOrdEqual : isBool ? SpvOpLogicalEqual : SpvOpIEqual; + break; + case kIROp_Neq: + opCode = isFloatingPoint ? SpvOpFOrdNotEqual + : isBool ? SpvOpLogicalNotEqual : SpvOpINotEqual; + break; + case kIROp_Geq: + opCode = isFloatingPoint ? SpvOpFOrdGreaterThanEqual + : isSigned ? SpvOpSGreaterThanEqual : SpvOpUGreaterThanEqual; + break; + case kIROp_Greater: + opCode = isFloatingPoint ? SpvOpFOrdGreaterThan + : isSigned ? SpvOpSGreaterThan : SpvOpUGreaterThan; + break; + case kIROp_Neg: + opCode = isFloatingPoint ? SpvOpFNegate : SpvOpSNegate; + break; + case kIROp_And: + opCode = SpvOpLogicalAnd; + break; + case kIROp_Or: + opCode = SpvOpLogicalOr; + break; + case kIROp_Not: + opCode = SpvOpLogicalNot; + break; + case kIROp_BitAnd: + opCode = SpvOpBitwiseAnd; + break; + case kIROp_BitOr: + opCode = SpvOpBitwiseOr; + break; + case kIROp_BitXor: + opCode = SpvOpBitwiseXor; + break; + case kIROp_BitNot: + opCode = SpvOpBitReverse; + break; + case kIROp_Rsh: + opCode = isSigned ? SpvOpShiftRightArithmetic : SpvOpShiftRightLogical; + break; + case kIROp_Lsh: + opCode = SpvOpShiftLeftLogical; + break; + default: + SLANG_ASSERT(!"unknown arithmetic opcode"); + break; + } + return emitInst(parent, inst, opCode, inst->getDataType(), kResultID, OperandsOf(inst)); + } + + OrderedHashSet<SpvCapability> m_capabilities; + + void requireSPIRVCapability(SpvCapability capability) + { + if (m_capabilities.Add(capability)) + { + emitInst( + getSection(SpvLogicalSectionID::Capabilities), + nullptr, + SpvOpCapability, + capability); + } + } + + void handleRequiredCapabilitiesImpl(IRInst* inst) + { + // TODO: declare required SPV capabilities. + + for (auto decoration : inst->getDecorations()) + { + switch (decoration->getOp()) + { + default: + break; + + case kIROp_RequireGLSLExtensionDecoration: + { + break; + } + case kIROp_RequireGLSLVersionDecoration: + { + break; + } + case kIROp_RequireSPIRVVersionDecoration: + { + break; + } + } + } + } + + SPIRVEmitContext(IRModule* module, TargetRequest* target, DiagnosticSink* sink) + : SPIRVEmitSharedContext(module, target) + , m_irModule(module) + , m_sink(sink) + , m_memoryArena(2048) { } }; SlangResult emitSPIRVFromIR( BackEndCompileRequest* compileRequest, + TargetRequest* targetRequest, IRModule* irModule, const List<IRFunc*>& irEntryPoints, List<uint8_t>& spirvOut) { - SLANG_UNUSED(compileRequest); - spirvOut.clear(); - SPIRVEmitContext context(irModule); + SPIRVEmitContext context(irModule, targetRequest, compileRequest->getSink()); + legalizeIRForSPIRV(&context, irModule, compileRequest->getSink()); context.emitFrontMatter(); for (auto irEntryPoint : irEntryPoints) |
