diff options
| author | Yong He <yonghe@outlook.com> | 2021-08-17 09:39:02 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2021-08-17 09:39:02 -0700 |
| commit | 858c7c57b125afed9b5b2329d6b02477284e4803 (patch) | |
| tree | 49f67b342448dcfb19913d8ccc089d956de14462 /source/slang/slang-emit-spirv.cpp | |
| parent | 6406523511037987d8b8ab881aea41389afd57eb (diff) | |
Add GLSL450 intrinsics to SPIRV direct emit. (#1921)
* Add GLSL450 intrinsics to SPIRV direct emit.
* Fix.
* Fix compiler error.
* Fix.
* Fix compiler error.
* Make direct-spirv tests actually run.
Diffstat (limited to 'source/slang/slang-emit-spirv.cpp')
| -rw-r--r-- | source/slang/slang-emit-spirv.cpp | 786 |
1 files changed, 635 insertions, 151 deletions
diff --git a/source/slang/slang-emit-spirv.cpp b/source/slang/slang-emit-spirv.cpp index 37fd673ed..21f5c1bc8 100644 --- a/source/slang/slang-emit-spirv.cpp +++ b/source/slang/slang-emit-spirv.cpp @@ -266,6 +266,11 @@ void SpvInstParent::dumpTo(List<SpvWord>& ioWords) struct SpvSnippetEmitContext { SpvInst* resultType; + IRType* irResultType; + // True if resultType is float or vector of float. + bool isResultTypeFloat; + // True if resultType is signed. + bool isResultTypeSigned; Dictionary<SpvStorageClass, IRInst*> qualifiedResultTypes; List<SpvWord> argumentIds; }; @@ -401,6 +406,16 @@ struct SPIRVEmitContext void registerInst(IRInst* irInst, SpvInst* spvInst) { m_mapIRInstToSpvInst.Add(irInst, spvInst); + + // If we have reserved an SpvID for `irInst`, make sure to use it. + SpvWord reservedID = 0; + m_mapIRInstToSpvID.TryGetValue(irInst, reservedID); + + if (reservedID) + { + SLANG_ASSERT(spvInst->id == 0); + spvInst->id = reservedID; + } } /// Get or reserve a SpvID for an IR value. @@ -439,18 +454,6 @@ struct SPIRVEmitContext return id; } - struct VectorTypeKey - { - BaseType baseType; - IRIntegerValue elementCount; - HashCode getHashCode() { return combineHash((int)baseType, (HashCode)elementCount); } - bool operator==(const VectorTypeKey& other) - { - return baseType == other.baseType && elementCount == other.elementCount; - } - }; - Dictionary<VectorTypeKey, SpvInst*> m_vectorTypes; - // We will build up `SpvInst`s in a stateful fashion, // mostly for convenience. We could in theory compute // the number of words each instruction needs, then allocate @@ -509,8 +512,6 @@ struct SPIRVEmitContext if(irInst) { registerInst(irInst, spvInst); - // If we have reserved an SpvID for `irInst`, make sure to use it. - m_mapIRInstToSpvID.TryGetValue(irInst, spvInst->id); } // Set up the scope @@ -675,19 +676,93 @@ struct SPIRVEmitContext void emitOperand(SpvBuiltIn builtin) { emitOperand((SpvWord)builtin); } void emitOperand(SpvStorageClass val) { emitOperand((SpvWord)val); } - Dictionary<IRIntegerValue, SpvInst*> m_spvIntConstants; - SpvInst* emitConstant(IRIntegerValue val, IRType* type) + template<typename TConstant> + struct ConstantValueKey + { + IRType* type; + TConstant value; + HashCode getHashCode() const + { + return combineHash(Slang::getHashCode(type), Slang::getHashCode(value)); + } + bool operator==(const ConstantValueKey& other) const + { + return type == other.type && value == other.value; + } + }; + Dictionary<ConstantValueKey<IRIntegerValue>, SpvInst*> m_spvIntConstants; + Dictionary<ConstantValueKey<IRFloatingPointValue>, SpvInst*> m_spvFloatConstants; + SpvInst* emitIntConstant(IRIntegerValue val, IRType* type) { + ConstantValueKey<IRIntegerValue> key; + key.value = val; + key.type = type; SpvInst* result = nullptr; - if (m_spvIntConstants.TryGetValue(val, result)) + if (m_spvIntConstants.TryGetValue(key, result)) return result; - return emitInst( - getSection(SpvLogicalSectionID::Constants), - nullptr, - SpvOpConstant, - type, - kResultID, - (SpvWord)val); + SpvWord valWord; + memcpy(&valWord, &val, sizeof(SpvWord)); + if (type->getOp() == kIROp_Int64Type || type->getOp() == kIROp_UInt64Type) + { + SpvWord valHighWord; + memcpy(&valHighWord, (char*)(&val) + 4, sizeof(SpvWord)); + result = emitInst( + getSection(SpvLogicalSectionID::Constants), + nullptr, + SpvOpConstant, + type, + kResultID, + valWord, + valHighWord); + } + else + { + result = emitInst( + getSection(SpvLogicalSectionID::Constants), + nullptr, + SpvOpConstant, + type, + kResultID, + valWord); + } + m_spvIntConstants[key] = result; + return result; + } + SpvInst* emitFloatConstant(IRFloatingPointValue val, IRType* type) + { + ConstantValueKey<IRFloatingPointValue> key; + key.value = val; + key.type = type; + SpvInst* result = nullptr; + if (m_spvFloatConstants.TryGetValue(key, result)) + return result; + SpvWord valWord; + memcpy(&valWord, &val, sizeof(SpvWord)); + if (type->getOp() == kIROp_DoubleType) + { + SpvWord valHighWord; + memcpy(&valHighWord, (char*)(&val) + 4, sizeof(SpvWord)); + result = emitInst( + getSection(SpvLogicalSectionID::Constants), + nullptr, + SpvOpConstant, + type, + kResultID, + valWord, + valHighWord); + } + else + { + result = emitInst( + getSection(SpvLogicalSectionID::Constants), + nullptr, + SpvOpConstant, + type, + kResultID, + valWord); + } + m_spvFloatConstants[key] = result; + return result; } // As another convenience, there are often cases where // we will want to emit all of the operands of some @@ -812,6 +887,22 @@ struct SPIRVEmitContext return spvInst; } + /// The SPIRV OpExtInstImport inst that represents the GLSL450 + /// extended instruction set. + SpvInst* m_glsl450ExtInst = nullptr; + + SpvInst* getGLSL450ExtInst() + { + if (m_glsl450ExtInst) + return m_glsl450ExtInst; + m_glsl450ExtInst = emitInst( + getSection(SpvLogicalSectionID::ExtIntInstImports), + nullptr, + SpvOpExtInstImport, + UnownedStringSlice("GLSL.std.450")); + return m_glsl450ExtInst; + } + // Now that we've gotten the core infrastructure out of the way, // let's start looking at emitting some instructions that make // up a SPIR-V module. @@ -849,6 +940,66 @@ struct SPIRVEmitContext emitInst(getSection(SpvLogicalSectionID::MemoryModel), nullptr, SpvOpMemoryModel, SpvAddressingModelLogical, SpvMemoryModelGLSL450); } + Dictionary<UnownedStringSlice, SpvInst*> m_extensionInsts; + SpvInst* ensureExtensionDeclaration(UnownedStringSlice name) + { + SpvInst* result = nullptr; + if (m_extensionInsts.TryGetValue(name, result)) + return result; + result = + emitInst(getSection(SpvLogicalSectionID::Extensions), nullptr, SpvOpExtension, name); + m_extensionInsts[name] = result; + return result; + } + + struct SpvTypeInstKey + { + List<SpvWord> words; + bool operator==(const SpvTypeInstKey& other) + { + if (words.getCount() != other.words.getCount()) + return false; + for (Index i = 0; i < words.getCount(); i++) + if (words[i] != other.words[i]) + return false; + return true; + } + HashCode getHashCode() + { + HashCode result = 0; + for (auto word : words) + result = combineHash(result, word); + return result; + } + }; + + Dictionary<SpvTypeInstKey, SpvInst*> m_spvTypeInsts; + + // Emits a SPV Inst that represents a type, with deduplications since + // our IR doesn't currently guarantee types are unique in generated SPV. + SpvInst* emitTypeInst(IRInst* typeInst, SpvOp opcode, ArrayView<SpvWord> operands) + { + SpvTypeInstKey key; + key.words.add((SpvWord)opcode); + for (auto op : operands) + key.words.add(op); + SpvInst* result = nullptr; + if (m_spvTypeInsts.TryGetValue(key, result)) + { + return result; + } + result = emitInstCustomOperandFunc( + getSection(SpvLogicalSectionID::Types), typeInst, opcode, [&]() { + emitOperand(kResultID); + for (auto op : operands) + { + emitOperand(op); + } + }); + m_spvTypeInsts[key] = result; + return result; + } + // Next, let's look at emitting some of the instructions // that can occur at global scope. @@ -864,7 +1015,7 @@ struct SPIRVEmitContext // #define CASE(IROP, SPVOP) \ - case IROP: return emitInst(getSection(SpvLogicalSectionID::Types), inst, SPVOP, kResultID) + case IROP: return emitTypeInst(inst, SPVOP, ArrayView<SpvWord>()); // > OpTypeVoid CASE(kIROp_VoidType, SpvOpTypeVoid); @@ -877,7 +1028,8 @@ struct SPIRVEmitContext // > OpTypeInt #define CASE(IROP, BITS, SIGNED) \ - case IROP: return emitInst(getSection(SpvLogicalSectionID::Types), inst, SpvOpTypeInt, kResultID, BITS, SIGNED) + case IROP: \ + return emitTypeInst(inst, SpvOpTypeInt, makeArray<SpvWord>((SpvWord)BITS, (SpvWord)SIGNED).getView()); CASE(kIROp_IntType, 32, 1); CASE(kIROp_UIntType, 32, 0); @@ -889,7 +1041,9 @@ struct SPIRVEmitContext // > OpTypeFloat #define CASE(IROP, BITS) \ - case IROP: return emitInst(getSection(SpvLogicalSectionID::Types), inst, SpvOpTypeFloat, kResultID, BITS) + case IROP: \ + return emitTypeInst( \ + inst, SpvOpTypeFloat, makeArray<SpvWord>(BITS).getView()); \ CASE(kIROp_HalfType, 16); CASE(kIROp_FloatType, 32); @@ -905,17 +1059,16 @@ struct SPIRVEmitContext auto ptrType = as<IRPtrTypeBase>(inst); if (ptrType->hasAddressSpace()) storageClass = (SpvStorageClass)ptrType->getAddressSpace(); - return emitInst( - getSection(SpvLogicalSectionID::Types), - inst, - SpvOpTypePointer, - kResultID, - storageClass, - inst->getOperand(0)); + if (storageClass == SpvStorageClassStorageBuffer) + ensureExtensionDeclaration(UnownedStringSlice("SPV_KHR_storage_buffer_storage_class")); + auto operands = makeArray<SpvWord>( + (SpvWord)storageClass, getID(ensureInst(inst->getOperand(0)))); + return emitTypeInst( + inst, SpvOpTypePointer, operands.getView()); } case kIROp_StructType: { - return emitInstCustomOperandFunc( + auto spvStructType = emitInstCustomOperandFunc( getSection(SpvLogicalSectionID::Types), inst, SpvOpTypeStruct, [&]() { emitOperand(kResultID); for (auto field : static_cast<IRStructType*>(inst)->getFields()) @@ -924,6 +1077,8 @@ struct SPIRVEmitContext // TODO: decorate offset } }); + emitDecorations(inst, getID(spvStructType)); + return spvStructType; } case kIROp_VectorType: { @@ -1012,6 +1167,12 @@ struct SPIRVEmitContext // return emitInst(getSection(SpvLogicalSectionID::Types), inst, SpvOpTypeFunction, kResultID, OperandsOf(inst)); + case kIROp_RateQualifiedType: + { + auto result = emitGlobalInst(as<IRRateQualifiedType>(inst)->getValueType()); + registerInst(inst, result); + return result; + } // > OpTypeForwardPointer case kIROp_Func: @@ -1046,10 +1207,6 @@ struct SPIRVEmitContext // it is nullptr, this function will create one. SpvInst* ensureVectorType(BaseType baseType, IRIntegerValue elementCount, IRVectorType* inst) { - VectorTypeKey key = {baseType, elementCount}; - SpvInst* result = nullptr; - if (m_vectorTypes.TryGetValue(key, result)) - return result; if (!inst) { IRBuilder builder; @@ -1059,14 +1216,9 @@ struct SPIRVEmitContext builder.getBasicType(baseType), builder.getIntValue(builder.getIntType(), elementCount)); } - result = emitInst( - getSection(SpvLogicalSectionID::Types), - inst, - SpvOpTypeVector, - kResultID, - inst->getElementType(), - (SpvWord)elementCount); - m_vectorTypes[key] = result; + auto operands = + makeArray<SpvWord>(getID(ensureInst(inst->getElementType())), (SpvWord)elementCount); + auto result = emitTypeInst(inst, SpvOpTypeVector, operands.getView()); return result; } @@ -1139,16 +1291,13 @@ struct SPIRVEmitContext varInst, SpvDecorationBinding, (SpvWord)index); - if (space) - { - emitInst( - getSection(SpvLogicalSectionID::Annotations), - nullptr, - SpvOpDecorate, - varInst, - SpvDecorationDescriptorSet, - (SpvWord)space); - } + emitInst( + getSection(SpvLogicalSectionID::Annotations), + nullptr, + SpvOpDecorate, + varInst, + SpvDecorationDescriptorSet, + (SpvWord)space); break; default: break; @@ -1165,6 +1314,11 @@ struct SPIRVEmitContext if (ptrType->hasAddressSpace()) storageClass = (SpvStorageClass)ptrType->getAddressSpace(); } + if (auto systemValInst = maybeEmitSystemVal(param)) + { + registerInst(param, systemValInst); + return systemValInst; + } auto varInst = emitInst( getSection(SpvLogicalSectionID::GlobalVariables), param, @@ -1304,11 +1458,25 @@ struct SPIRVEmitContext for( auto irBlock : irFunc->getBlocks() ) { emitInst(spvFunc, irBlock, SpvOpLabel, kResultID); + + // In addition to normal basic blocks, + // all loops gets a header block. + for (auto irInst : irBlock->getChildren()) + { + if (irInst->getOp() == kIROp_loop) + { + emitInst(spvFunc, irInst, SpvOpLabel, kResultID); + } + } } // Once all the basic blocks have had instructions allocated // for them, we go through and fill them in with their bodies. // + // Each loop inst results in a loop header block. + // We will defer the emit of the contents in loop header block + // until all Phi insts are emitted. + List<IRLoop*> pendingLoopInsts; for( auto irBlock : irFunc->getBlocks() ) { // Note: because we already created the block above, @@ -1334,9 +1502,20 @@ struct SPIRVEmitContext // of the block. // emitLocalInst(spvBlock, irInst); + if (irInst->getOp() == kIROp_loop) + pendingLoopInsts.add(as<IRLoop>(irInst)); } } + // Finally, we generate the body of loop header blocks. + for (auto loopInst : pendingLoopInsts) + { + SpvInst* headerBlock = nullptr; + m_mapIRInstToSpvInst.TryGetValue(loopInst, headerBlock); + SLANG_ASSERT(headerBlock); + emitLoopHeaderBlock(loopInst, headerBlock); + } + // [3.32.9. Function Instructions] // // > OpFunctionEnd @@ -1356,6 +1535,21 @@ struct SPIRVEmitContext return spvFunc; } + /// Check if a block is a loop's target block. + bool isLoopTargetBlock(IRInst* block, IRInst*& loopInst) + { + for (auto use = block->firstUse; use; use = use->nextUse) + { + if (use->getUser()->getOp() == kIROp_loop && + as<IRLoop>(use->getUser())->getTargetBlock() == block) + { + loopInst = use->getUser(); + return true; + } + } + return false; + } + // The instructions that appear inside the basic blocks of // functions are what we will call "local" instructions. // @@ -1367,13 +1561,6 @@ struct SPIRVEmitContext /// Emit an instruction that is local to the body of the given `parent`. SpvInst* emitLocalInst(SpvInstParent* parent, IRInst* inst) { - auto getBlockID = [=](IRBlock* block) - { - SpvInst* spvInst = nullptr; - m_mapIRInstToSpvInst.TryGetValue(block, spvInst); - SLANG_ASSERT(spvInst); - return getID(spvInst); - }; switch( inst->getOp() ) { default: @@ -1401,6 +1588,9 @@ struct SPIRVEmitContext return emitSwizzle(parent, as<IRSwizzle>(inst)); case kIROp_Construct: return emitConstruct(parent, inst); + case kIROp_BitCast: + return emitInst( + parent, inst, SpvOpBitcast, inst->getDataType(), kResultID, inst->getOperand(0)); case kIROp_Add: case kIROp_Sub: case kIROp_Mul: @@ -1432,50 +1622,49 @@ struct SPIRVEmitContext case kIROp_discard: return emitInst(parent, inst, SpvOpKill); case kIROp_unconditionalBranch: - return emitInst( - parent, - inst, - SpvOpBranch, - getBlockID(as<IRUnconditionalBranch>(inst)->getTargetBlock())); - case kIROp_loop: { - auto loopInst = as<IRLoop>(inst); - - SpvWord loopControl = 0; - if (auto loopControlDecoration = - loopInst->findDecoration<IRLoopControlDecoration>()) + // If we are jumping to the main block of a loop, + // emit a branch to the loop header instead. + // The SPV id of the resulting loop header block is associated with the loop inst. + auto targetBlock = as<IRUnconditionalBranch>(inst)->getTargetBlock(); + IRInst* loopInst = nullptr; + if (isLoopTargetBlock(targetBlock, loopInst)) { - switch (loopControlDecoration->getMode()) - { - case IRLoopControl::kIRLoopControl_Unroll: - loopControl = 0x1; - break; - case IRLoopControl::kIRLoopControl_Loop: - loopControl = 0x2; - break; - default: - break; - } + return emitInst(parent, inst, SpvOpBranch, getIRInstSpvID(loopInst)); } - emitInst( + // Otherwise, emit a normal branch inst into the target block. + return emitInst( parent, - nullptr, - SpvOpLoopMerge, - getBlockID(loopInst->getBreakBlock()), - getBlockID(loopInst->getContinueBlock()), - loopControl); - - return emitInst(parent, inst, SpvOpBranch, loopInst->getTargetBlock()); + inst, + SpvOpBranch, + getIRInstSpvID(targetBlock)); + } + case kIROp_loop: + { + // Return loop header block in its own block. + auto blockId = getIRInstSpvID(inst); + SpvInst* block = nullptr; + m_mapIRInstToSpvInst.TryGetValue(inst, block); + SLANG_ASSERT(block); + + // Emit a jump to the loop header block. + // Note: the body of the loop header block is emitted + // after everything else to ensure Phi instructions (which come + // from the actual loop target block) are emitted first. + emitInst(parent, nullptr, SpvOpBranch, blockId); + + return block; } case kIROp_ifElse: { auto ifelseInst = as<IRIfElse>(inst); - auto afterBlockID = getBlockID(ifelseInst->getAfterBlock()); + auto afterBlockID = getIRInstSpvID(ifelseInst->getAfterBlock()); emitInst( parent, nullptr, SpvOpSelectionMerge, - afterBlockID); + afterBlockID, + 0); auto falseLabel = ifelseInst->getFalseBlock(); return emitInst( parent, @@ -1488,11 +1677,8 @@ struct SPIRVEmitContext case kIROp_Switch: { auto switchInst = as<IRSwitch>(inst); - auto mergeBlockID = getBlockID(switchInst->getBreakLabel()); - emitInst( - parent, - nullptr, - SpvOpSelectionMerge, mergeBlockID); + auto mergeBlockID = getIRInstSpvID(switchInst->getBreakLabel()); + emitInst(parent, nullptr, SpvOpSelectionMerge, mergeBlockID, 0); return emitInstCustomOperandFunc(parent, inst, SpvOpSwitch, [&]() { emitOperand(switchInst->getCondition()); auto defaultLabel = switchInst->getDefaultLabel(); @@ -1685,7 +1871,23 @@ struct SPIRVEmitContext auto entryPointDecor = cast<IREntryPointDecoration>(decoration); auto spvStage = mapStageToExecutionModel(entryPointDecor->getProfile().getStage()); auto name = entryPointDecor->getName()->getStringSlice(); - emitInst(section, decoration, SpvOpEntryPoint, spvStage, dstID, name); + emitInstCustomOperandFunc(section, decoration, SpvOpEntryPoint, [&]() { + emitOperand(spvStage); + emitOperand(dstID); + emitOperand(name); + // `interface` part: reference all global variables that are used by this entrypoint. + // TODO: we may want to perform more accurate tracking. + for (auto globalInst : m_irModule->getModuleInst()->getChildren()) + { + switch (globalInst->getOp()) + { + case kIROp_GlobalVar: + case kIROp_GlobalParam: + emitOperand(getIRInstSpvID(globalInst)); + break; + } + } + }); } break; @@ -1713,6 +1915,24 @@ struct SPIRVEmitContext } break; + case kIROp_SPIRVBufferBlockDecoration: + { + emitInst( + getSection(SpvLogicalSectionID::Annotations), + decoration, + SpvOpDecorate, + dstID, + SpvDecorationBlock); + emitInst( + getSection(SpvLogicalSectionID::Annotations), + nullptr, + SpvOpMemberDecorate, + dstID, + 0, + SpvDecorationOffset, + 0); + } + break; // ... } } @@ -1742,14 +1962,26 @@ struct SPIRVEmitContext } } - SpvInst* emitBuiltinSystemVal(SpvInstParent* parent, IRInst* inst, SpvBuiltIn builtinVal) + Dictionary<SpvBuiltIn, SpvInst*> m_builtinGlobalVars; + SpvInst* getBuiltinGlobalVar(IRType* type, SpvBuiltIn builtinVal) { + SpvInst* result = nullptr; + if (m_builtinGlobalVars.TryGetValue(builtinVal, result)) + { + return result; + } IRBuilder builder; builder.sharedBuilder = &m_sharedIRBuilder; - builder.setInsertBefore(inst); - - auto ptrIRType = builder.getPtrType(inst->getDataType()); - auto varInst = emitInst(parent, inst, SpvOpVariable, ptrIRType, kResultID); + builder.setInsertBefore(type); + auto ptrType = as<IRPtrTypeBase>(type); + SLANG_ASSERT(ptrType && "`getBuiltinGlobalVar`: `type` must be ptr type."); + auto varInst = emitInst( + getSection(SpvLogicalSectionID::GlobalVariables), + nullptr, + SpvOpVariable, + type, + kResultID, + (SpvStorageClass)ptrType->getAddressSpace()); emitInst( getSection(SpvLogicalSectionID::Annotations), nullptr, @@ -1757,11 +1989,15 @@ struct SPIRVEmitContext varInst, SpvDecorationBuiltIn, builtinVal); + m_builtinGlobalVars[builtinVal] = varInst; return varInst; } - SpvInst* emitParam(SpvInstParent* parent, IRInst* inst) + SpvInst* maybeEmitSystemVal(IRInst* inst) { + IRBuilder builder; + builder.sharedBuilder = &m_sharedIRBuilder; + builder.setInsertBefore(inst); if (auto layout = getVarLayout(inst)) { if (auto systemValueAttr = layout->findAttr<IRSystemValueSemanticAttr>()) @@ -1770,27 +2006,26 @@ struct SPIRVEmitContext semanticName = semanticName.toLower(); if (semanticName == "sv_dispatchthreadid") { - return emitBuiltinSystemVal(parent, inst, SpvBuiltInGlobalInvocationId); + return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInGlobalInvocationId); } } } + return nullptr; + } + + SpvInst* emitParam(SpvInstParent* parent, IRInst* inst) + { return emitInst(parent, inst, SpvOpFunctionParameter, inst->getFullType(), kResultID); } SpvInst* emitVar(SpvInstParent* parent, IRInst* inst) { - SpvWord storageClass = SpvStorageClassFunction; - auto rate = inst->getFullType()->getRate(); - if (rate) + auto ptrType = as<IRPtrTypeBase>(inst->getDataType()); + SLANG_ASSERT(ptrType); + SpvStorageClass storageClass = SpvStorageClassFunction; + if (ptrType->hasAddressSpace()) { - switch (rate->getOp()) - { - case kIROp_GroupSharedRate: - storageClass = SpvStorageClassWorkgroup; - break; - default: - break; - } + storageClass = (SpvStorageClass)ptrType->getAddressSpace(); } return emitInst(parent, inst, SpvOpVariable, inst->getFullType(), kResultID, storageClass); } @@ -1828,6 +2063,48 @@ struct SPIRVEmitContext return result; } + bool isGlobalValueInst(IRInst* inst) + { + if (as<IRConstant>(inst)) + return true; + switch (inst->getOp()) + { + case kIROp_Func: + case kIROp_GlobalParam: + case kIROp_GlobalVar: + return true; + default: + return false; + } + } + + void emitLoopHeaderBlock(IRLoop* loopInst, SpvInst* loopHeaderBlock) + { + SpvWord loopControl = 0; + if (auto loopControlDecoration = loopInst->findDecoration<IRLoopControlDecoration>()) + { + switch (loopControlDecoration->getMode()) + { + case IRLoopControl::kIRLoopControl_Unroll: + loopControl = 0x1; + break; + case IRLoopControl::kIRLoopControl_Loop: + loopControl = 0x2; + break; + default: + break; + } + } + emitInst( + loopHeaderBlock, + nullptr, + SpvOpLoopMerge, + getIRInstSpvID(loopInst->getBreakBlock()), + getIRInstSpvID(loopInst->getContinueBlock()), + loopControl); + emitInst(loopHeaderBlock, nullptr, SpvOpBranch, loopInst->getTargetBlock()); + } + SpvInst* emitPhi(SpvInstParent* parent, IRParam* inst) { // An `IRParam` in an ordinary `IRBlock` represents a phi value. @@ -1838,6 +2115,16 @@ struct SPIRVEmitContext // First, we find the index of this param. IRBlock* block = as<IRBlock>(inst->getParent()); + // Special case: if block is a loop's target block, emit phis into the header block instead. + IRInst* loopInst = nullptr; + if (isLoopTargetBlock(block, loopInst)) + { + SpvInst* loopSpvBlockInst = nullptr; + m_mapIRInstToSpvInst.TryGetValue(loopInst, loopSpvBlockInst); + SLANG_ASSERT(loopSpvBlockInst); + parent = loopSpvBlockInst; + } + SLANG_ASSERT(block); int paramIndex = getParamIndexInBlock(block, inst); @@ -1865,7 +2152,9 @@ struct SPIRVEmitContext } SLANG_ASSERT(argStartIndex + paramIndex < branchInst->getOperandCount()); auto valueInst = branchInst->getOperand(argStartIndex + paramIndex); - emitOperand(valueInst); + if (isGlobalValueInst(valueInst)) + ensureInst(valueInst); + emitOperand(getIRInstSpvID(valueInst)); auto sourceBlock = as<IRBlock>(branchInst->getParent()); SLANG_ASSERT(sourceBlock); emitOperand(getIRInstSpvID(sourceBlock)); @@ -1901,7 +2190,10 @@ struct SPIRVEmitContext { SpvSnippet* snippet = getParsedSpvSnippet(intrinsic); SpvSnippetEmitContext context; + context.irResultType = inst->getDataType(); context.resultType = ensureInst(inst->getFullType()); + context.isResultTypeFloat = isFloatType(inst->getDataType()); + context.isResultTypeSigned = isSignedType((IRType*)inst->getDataType()); for (SlangUInt i = 0; i < inst->getArgCount(); i++) { auto argInst = ensureInst(inst->getArg(i)); @@ -1933,6 +2225,89 @@ struct SPIRVEmitContext return emitSpvSnippet(parent, inst, context, snippet); } + Dictionary<SpvSnippet::ASMConstant, SpvInst*> m_spvSnippetConstantInsts; + + // Emit SPV Inst that represents a constant defined in a SpvSnippet. + SpvInst* maybeEmitSpvConstant(SpvSnippet::ASMConstant constant) + { + SpvInst* result = nullptr; + if (m_spvSnippetConstantInsts.TryGetValue(constant, result)) + return result; + + IRBuilder builder; + builder.sharedBuilder = &m_sharedIRBuilder; + builder.setInsertInto(m_irModule->getModuleInst()); + switch (constant.type) + { + case SpvSnippet::ASMType::Float: + result = emitFloatConstant(constant.floatValues[0], builder.getType(kIROp_FloatType)); + break; + case SpvSnippet::ASMType::Float2: + { + auto floatType = builder.getType(kIROp_FloatType); + auto element1 = emitFloatConstant(constant.floatValues[0], floatType); + auto element2 = emitFloatConstant(constant.floatValues[1], floatType); + result = emitInst( + getSection(SpvLogicalSectionID::Constants), + nullptr, + SpvOpConstantComposite, + builder.getVectorType(floatType, builder.getIntValue(builder.getIntType(), 2)), + kResultID, + element1, + element2); + } + case SpvSnippet::ASMType::Int: + result = emitIntConstant((IRIntegerValue)constant.intValues[0], builder.getIntType()); + break; + case SpvSnippet::ASMType::UInt2: + { + auto uintType = builder.getType(kIROp_UIntType); + auto element1 = emitIntConstant((IRIntegerValue)constant.intValues[0], uintType); + auto element2 = emitIntConstant((IRIntegerValue)constant.intValues[1], uintType); + result = emitInst( + getSection(SpvLogicalSectionID::Constants), + nullptr, + SpvOpConstantComposite, + builder.getVectorType(uintType, builder.getIntValue(builder.getIntType(), 2)), + kResultID, + element1, + element2); + } + break; + } + m_spvSnippetConstantInsts[constant] = result; + return result; + } + + // Emit SPV Inst that represents a type defined in a SpvSnippet. + void emitSpvSnippetASMTypeOperand(SpvSnippet::ASMType type) + { + IRBuilder builder; + builder.sharedBuilder = &m_sharedIRBuilder; + builder.setInsertInto(m_irModule->getModuleInst()); + IRType* irType = nullptr; + switch (type) + { + case SpvSnippet::ASMType::Float: + irType = builder.getType(kIROp_FloatType); + break; + case SpvSnippet::ASMType::Int: + irType = builder.getIntType(); + break; + case SpvSnippet::ASMType::Float2: + irType = builder.getVectorType( + builder.getType(kIROp_FloatType), builder.getIntValue(builder.getIntType(), 2)); + break; + case SpvSnippet::ASMType::UInt2: + irType = builder.getVectorType( + builder.getType(kIROp_UIntType), builder.getIntValue(builder.getIntType(), 2)); + break; + default: + break; + } + emitOperand(irType); + } + SpvInst* emitSpvSnippet( SpvInstParent* parent, IRCall* inst, @@ -1950,11 +2325,10 @@ struct SPIRVEmitContext switch (operand.type) { case SpvSnippet::ASMOperandType::SpvWord: - emitOperand((SpvWord)operand.content); + emitOperand(operand.content); break; case SpvSnippet::ASMOperandType::ObjectReference: - SLANG_ASSERT( - operand.content >= 0 && operand.content < context.argumentIds.getCount()); + SLANG_ASSERT(operand.content < (SpvWord)context.argumentIds.getCount()); emitOperand(context.argumentIds[operand.content]); break; case SpvSnippet::ASMOperandType::ResultId: @@ -1972,8 +2346,64 @@ struct SPIRVEmitContext } break; case SpvSnippet::ASMOperandType::InstReference: - SLANG_ASSERT(operand.content >= 0 && operand.content < emittedInsts.getCount()); - emitOperand(getID(emittedInsts[operand.content])); + SLANG_ASSERT(operand.content < (SpvWord)emittedInsts.getCount()); + emitOperand(emittedInsts[operand.content]); + break; + case SpvSnippet::ASMOperandType::GLSL450ExtInstSet: + emitOperand(getGLSL450ExtInst()); + break; + case SpvSnippet::ASMOperandType::FloatIntegerSelection: + if (context.isResultTypeFloat) + { + emitOperand(operand.content); + } + else + { + emitOperand(operand.content2); + } + break; + case SpvSnippet::ASMOperandType::FloatUnsignedSignedSelection: + if (context.isResultTypeFloat) + { + emitOperand(operand.content); + } + else + { + if (context.isResultTypeSigned) + { + emitOperand(operand.content3); + } + else + { + emitOperand(operand.content2); + } + } + break; + case SpvSnippet::ASMOperandType::TypeReference: + { + emitSpvSnippetASMTypeOperand((SpvSnippet::ASMType)operand.content); + } + break; + case SpvSnippet::ASMOperandType::ConstantReference: + { + auto constant = snippet->constants[operand.content]; + if (constant.type == SpvSnippet::ASMType::FloatOrDouble) + { + switch (extractBaseType(context.irResultType)) + { + case BaseType::Float: + constant.type = SpvSnippet::ASMType::Float; + break; + case BaseType::Double: + constant.type = SpvSnippet::ASMType::Double; + break; + default: + break; + } + } + SpvInst* spvConstant = maybeEmitSpvConstant(constant); + emitOperand(spvConstant); + } break; } } @@ -2048,7 +2478,7 @@ struct SPIRVEmitContext baseId = getID(varInst); } SLANG_ASSERT(baseStructType && "field_address require base to be a struct."); - auto fieldId = emitConstant( + auto fieldId = emitIntConstant( getStructFieldId(baseStructType, as<IRStructKey>(fieldAddress->getField())), builder.getIntType()); return emitInst( @@ -2069,7 +2499,7 @@ struct SPIRVEmitContext IRStructType* baseStructType = as<IRStructType>(inst->getBase()->getDataType()); SLANG_ASSERT(baseStructType && "field_extract require base to be a struct."); - auto fieldId = emitConstant( + auto fieldId = emitIntConstant( getStructFieldId(baseStructType, as<IRStructKey>(inst->getField())), builder.getIntType()); @@ -2163,17 +2593,31 @@ struct SPIRVEmitContext SpvInst* emitSwizzle(SpvInstParent* parent, IRSwizzle* inst) { - return emitInstCustomOperandFunc(parent, inst, SpvOpVectorShuffle, [&]() { - emitOperand(inst->getDataType()); - emitOperand(kResultID); - emitOperand(inst->getBase()); - emitOperand(inst->getBase()); - for (UInt i = 0; i < inst->getElementCount(); i++) - { - auto index = as<IRIntLit>(inst->getElementIndex(i)); - emitOperand((SpvWord)index->getValue()); - } - }); + if (inst->getElementCount() == 1) + { + return emitInst( + parent, + inst, + SpvOpCompositeExtract, + inst->getDataType(), + kResultID, + inst->getBase(), + (SpvWord)as<IRIntLit>(inst->getElementIndex(0))->getValue()); + } + else + { + return emitInstCustomOperandFunc(parent, inst, SpvOpVectorShuffle, [&]() { + emitOperand(inst->getDataType()); + emitOperand(kResultID); + emitOperand(inst->getBase()); + emitOperand(inst->getBase()); + for (UInt i = 0; i < inst->getElementCount(); i++) + { + auto index = as<IRIntLit>(inst->getElementIndex(i)); + emitOperand((SpvWord)index->getValue()); + } + }); + } } SpvInst* emitConstruct(SpvInstParent* parent, IRInst* inst) @@ -2183,9 +2627,21 @@ struct SPIRVEmitContext if (inst->getOperandCount() == 1) { if (inst->getDataType() == inst->getOperand(0)->getDataType()) - return emitInst(parent, inst, SpvOpCopyObject, kResultID, inst->getOperand(0)); + return emitInst( + parent, + inst, + SpvOpCopyObject, + inst->getFullType(), + kResultID, + inst->getOperand(0)); else - return emitInst(parent, inst, SpvOpBitcast, inst->getDataType(), kResultID, inst->getOperand(0)); + return emitInst( + parent, + inst, + SpvOpBitcast, + inst->getFullType(), + kResultID, + inst->getOperand(0)); } else { @@ -2205,18 +2661,39 @@ struct SPIRVEmitContext } } - bool isSignedType(IRBasicType* basicType) + bool isSignedType(IRType* type) { - switch (basicType->getBaseType()) + switch (type->getOp()) { - case BaseType::Float: - case BaseType::Double: + case kIROp_FloatType: + case kIROp_DoubleType: return true; - case BaseType::Int: - case BaseType::Int16: - case BaseType::Int64: - case BaseType::Int8: + case kIROp_IntType: + case kIROp_Int16Type: + case kIROp_Int64Type: + case kIROp_Int8Type: return true; + case kIROp_VectorType: + return isSignedType(as<IRVectorType>(type)->getElementType()); + case kIROp_MatrixType: + return isSignedType(as<IRMatrixType>(type)->getElementType()); + default: + return false; + } + } + + bool isFloatType(IRInst* type) + { + switch (type->getOp()) + { + case kIROp_FloatType: + case kIROp_DoubleType: + case kIROp_HalfType: + return true; + case kIROp_VectorType: + return isFloatType(as<IRVectorType>(type)->getElementType()); + case kIROp_MatrixType: + return isFloatType(as<IRMatrixType>(type)->getElementType()); default: return false; } @@ -2224,7 +2701,7 @@ struct SPIRVEmitContext SpvInst* emitArithmetic(SpvInstParent* parent, IRInst* inst) { - IRType* elementType = inst->getDataType(); + IRType* elementType = inst->getOperand(0)->getDataType(); if (auto vectorType = as<IRVectorType>(inst->getDataType())) { elementType = vectorType->getElementType(); @@ -2245,6 +2722,7 @@ struct SPIRVEmitContext break; case BaseType::Bool: isBool = true; + break; default: break; } @@ -2371,6 +2849,12 @@ struct SPIRVEmitContext } } + void diagnoseUnhandledInst(IRInst* inst) + { + m_sink->diagnose( + inst, Diagnostics::unimplemented, "unexpected IR opcode during code emit"); + } + SPIRVEmitContext(IRModule* module, TargetRequest* target, DiagnosticSink* sink) : SPIRVEmitSharedContext(module, target) , m_irModule(module) @@ -2390,7 +2874,7 @@ SlangResult emitSPIRVFromIR( spirvOut.clear(); SPIRVEmitContext context(irModule, targetRequest, compileRequest->getSink()); - legalizeIRForSPIRV(&context, irModule, compileRequest->getSink()); + legalizeIRForSPIRV(&context, irModule, irEntryPoints, compileRequest->getSink()); context.emitFrontMatter(); for (auto irEntryPoint : irEntryPoints) |
