summaryrefslogtreecommitdiff
path: root/source/slang/slang-emit-spirv.cpp
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2021-08-17 09:39:02 -0700
committerGitHub <noreply@github.com>2021-08-17 09:39:02 -0700
commit858c7c57b125afed9b5b2329d6b02477284e4803 (patch)
tree49f67b342448dcfb19913d8ccc089d956de14462 /source/slang/slang-emit-spirv.cpp
parent6406523511037987d8b8ab881aea41389afd57eb (diff)
Add GLSL450 intrinsics to SPIRV direct emit. (#1921)
* Add GLSL450 intrinsics to SPIRV direct emit. * Fix. * Fix compiler error. * Fix. * Fix compiler error. * Make direct-spirv tests actually run.
Diffstat (limited to 'source/slang/slang-emit-spirv.cpp')
-rw-r--r--source/slang/slang-emit-spirv.cpp786
1 files changed, 635 insertions, 151 deletions
diff --git a/source/slang/slang-emit-spirv.cpp b/source/slang/slang-emit-spirv.cpp
index 37fd673ed..21f5c1bc8 100644
--- a/source/slang/slang-emit-spirv.cpp
+++ b/source/slang/slang-emit-spirv.cpp
@@ -266,6 +266,11 @@ void SpvInstParent::dumpTo(List<SpvWord>& ioWords)
struct SpvSnippetEmitContext
{
SpvInst* resultType;
+ IRType* irResultType;
+ // True if resultType is float or vector of float.
+ bool isResultTypeFloat;
+ // True if resultType is signed.
+ bool isResultTypeSigned;
Dictionary<SpvStorageClass, IRInst*> qualifiedResultTypes;
List<SpvWord> argumentIds;
};
@@ -401,6 +406,16 @@ struct SPIRVEmitContext
void registerInst(IRInst* irInst, SpvInst* spvInst)
{
m_mapIRInstToSpvInst.Add(irInst, spvInst);
+
+ // If we have reserved an SpvID for `irInst`, make sure to use it.
+ SpvWord reservedID = 0;
+ m_mapIRInstToSpvID.TryGetValue(irInst, reservedID);
+
+ if (reservedID)
+ {
+ SLANG_ASSERT(spvInst->id == 0);
+ spvInst->id = reservedID;
+ }
}
/// Get or reserve a SpvID for an IR value.
@@ -439,18 +454,6 @@ struct SPIRVEmitContext
return id;
}
- struct VectorTypeKey
- {
- BaseType baseType;
- IRIntegerValue elementCount;
- HashCode getHashCode() { return combineHash((int)baseType, (HashCode)elementCount); }
- bool operator==(const VectorTypeKey& other)
- {
- return baseType == other.baseType && elementCount == other.elementCount;
- }
- };
- Dictionary<VectorTypeKey, SpvInst*> m_vectorTypes;
-
// We will build up `SpvInst`s in a stateful fashion,
// mostly for convenience. We could in theory compute
// the number of words each instruction needs, then allocate
@@ -509,8 +512,6 @@ struct SPIRVEmitContext
if(irInst)
{
registerInst(irInst, spvInst);
- // If we have reserved an SpvID for `irInst`, make sure to use it.
- m_mapIRInstToSpvID.TryGetValue(irInst, spvInst->id);
}
// Set up the scope
@@ -675,19 +676,93 @@ struct SPIRVEmitContext
void emitOperand(SpvBuiltIn builtin) { emitOperand((SpvWord)builtin); }
void emitOperand(SpvStorageClass val) { emitOperand((SpvWord)val); }
- Dictionary<IRIntegerValue, SpvInst*> m_spvIntConstants;
- SpvInst* emitConstant(IRIntegerValue val, IRType* type)
+ template<typename TConstant>
+ struct ConstantValueKey
+ {
+ IRType* type;
+ TConstant value;
+ HashCode getHashCode() const
+ {
+ return combineHash(Slang::getHashCode(type), Slang::getHashCode(value));
+ }
+ bool operator==(const ConstantValueKey& other) const
+ {
+ return type == other.type && value == other.value;
+ }
+ };
+ Dictionary<ConstantValueKey<IRIntegerValue>, SpvInst*> m_spvIntConstants;
+ Dictionary<ConstantValueKey<IRFloatingPointValue>, SpvInst*> m_spvFloatConstants;
+ SpvInst* emitIntConstant(IRIntegerValue val, IRType* type)
{
+ ConstantValueKey<IRIntegerValue> key;
+ key.value = val;
+ key.type = type;
SpvInst* result = nullptr;
- if (m_spvIntConstants.TryGetValue(val, result))
+ if (m_spvIntConstants.TryGetValue(key, result))
return result;
- return emitInst(
- getSection(SpvLogicalSectionID::Constants),
- nullptr,
- SpvOpConstant,
- type,
- kResultID,
- (SpvWord)val);
+ SpvWord valWord;
+ memcpy(&valWord, &val, sizeof(SpvWord));
+ if (type->getOp() == kIROp_Int64Type || type->getOp() == kIROp_UInt64Type)
+ {
+ SpvWord valHighWord;
+ memcpy(&valHighWord, (char*)(&val) + 4, sizeof(SpvWord));
+ result = emitInst(
+ getSection(SpvLogicalSectionID::Constants),
+ nullptr,
+ SpvOpConstant,
+ type,
+ kResultID,
+ valWord,
+ valHighWord);
+ }
+ else
+ {
+ result = emitInst(
+ getSection(SpvLogicalSectionID::Constants),
+ nullptr,
+ SpvOpConstant,
+ type,
+ kResultID,
+ valWord);
+ }
+ m_spvIntConstants[key] = result;
+ return result;
+ }
+ SpvInst* emitFloatConstant(IRFloatingPointValue val, IRType* type)
+ {
+ ConstantValueKey<IRFloatingPointValue> key;
+ key.value = val;
+ key.type = type;
+ SpvInst* result = nullptr;
+ if (m_spvFloatConstants.TryGetValue(key, result))
+ return result;
+ SpvWord valWord;
+ memcpy(&valWord, &val, sizeof(SpvWord));
+ if (type->getOp() == kIROp_DoubleType)
+ {
+ SpvWord valHighWord;
+ memcpy(&valHighWord, (char*)(&val) + 4, sizeof(SpvWord));
+ result = emitInst(
+ getSection(SpvLogicalSectionID::Constants),
+ nullptr,
+ SpvOpConstant,
+ type,
+ kResultID,
+ valWord,
+ valHighWord);
+ }
+ else
+ {
+ result = emitInst(
+ getSection(SpvLogicalSectionID::Constants),
+ nullptr,
+ SpvOpConstant,
+ type,
+ kResultID,
+ valWord);
+ }
+ m_spvFloatConstants[key] = result;
+ return result;
}
// As another convenience, there are often cases where
// we will want to emit all of the operands of some
@@ -812,6 +887,22 @@ struct SPIRVEmitContext
return spvInst;
}
+ /// The SPIRV OpExtInstImport inst that represents the GLSL450
+ /// extended instruction set.
+ SpvInst* m_glsl450ExtInst = nullptr;
+
+ SpvInst* getGLSL450ExtInst()
+ {
+ if (m_glsl450ExtInst)
+ return m_glsl450ExtInst;
+ m_glsl450ExtInst = emitInst(
+ getSection(SpvLogicalSectionID::ExtIntInstImports),
+ nullptr,
+ SpvOpExtInstImport,
+ UnownedStringSlice("GLSL.std.450"));
+ return m_glsl450ExtInst;
+ }
+
// Now that we've gotten the core infrastructure out of the way,
// let's start looking at emitting some instructions that make
// up a SPIR-V module.
@@ -849,6 +940,66 @@ struct SPIRVEmitContext
emitInst(getSection(SpvLogicalSectionID::MemoryModel), nullptr, SpvOpMemoryModel, SpvAddressingModelLogical, SpvMemoryModelGLSL450);
}
+ Dictionary<UnownedStringSlice, SpvInst*> m_extensionInsts;
+ SpvInst* ensureExtensionDeclaration(UnownedStringSlice name)
+ {
+ SpvInst* result = nullptr;
+ if (m_extensionInsts.TryGetValue(name, result))
+ return result;
+ result =
+ emitInst(getSection(SpvLogicalSectionID::Extensions), nullptr, SpvOpExtension, name);
+ m_extensionInsts[name] = result;
+ return result;
+ }
+
+ struct SpvTypeInstKey
+ {
+ List<SpvWord> words;
+ bool operator==(const SpvTypeInstKey& other)
+ {
+ if (words.getCount() != other.words.getCount())
+ return false;
+ for (Index i = 0; i < words.getCount(); i++)
+ if (words[i] != other.words[i])
+ return false;
+ return true;
+ }
+ HashCode getHashCode()
+ {
+ HashCode result = 0;
+ for (auto word : words)
+ result = combineHash(result, word);
+ return result;
+ }
+ };
+
+ Dictionary<SpvTypeInstKey, SpvInst*> m_spvTypeInsts;
+
+ // Emits a SPV Inst that represents a type, with deduplications since
+ // our IR doesn't currently guarantee types are unique in generated SPV.
+ SpvInst* emitTypeInst(IRInst* typeInst, SpvOp opcode, ArrayView<SpvWord> operands)
+ {
+ SpvTypeInstKey key;
+ key.words.add((SpvWord)opcode);
+ for (auto op : operands)
+ key.words.add(op);
+ SpvInst* result = nullptr;
+ if (m_spvTypeInsts.TryGetValue(key, result))
+ {
+ return result;
+ }
+ result = emitInstCustomOperandFunc(
+ getSection(SpvLogicalSectionID::Types), typeInst, opcode, [&]() {
+ emitOperand(kResultID);
+ for (auto op : operands)
+ {
+ emitOperand(op);
+ }
+ });
+ m_spvTypeInsts[key] = result;
+ return result;
+ }
+
// Next, let's look at emitting some of the instructions
// that can occur at global scope.
@@ -864,7 +1015,7 @@ struct SPIRVEmitContext
//
#define CASE(IROP, SPVOP) \
- case IROP: return emitInst(getSection(SpvLogicalSectionID::Types), inst, SPVOP, kResultID)
+ case IROP: return emitTypeInst(inst, SPVOP, ArrayView<SpvWord>());
// > OpTypeVoid
CASE(kIROp_VoidType, SpvOpTypeVoid);
@@ -877,7 +1028,8 @@ struct SPIRVEmitContext
// > OpTypeInt
#define CASE(IROP, BITS, SIGNED) \
- case IROP: return emitInst(getSection(SpvLogicalSectionID::Types), inst, SpvOpTypeInt, kResultID, BITS, SIGNED)
+ case IROP: \
+ return emitTypeInst(inst, SpvOpTypeInt, makeArray<SpvWord>((SpvWord)BITS, (SpvWord)SIGNED).getView());
CASE(kIROp_IntType, 32, 1);
CASE(kIROp_UIntType, 32, 0);
@@ -889,7 +1041,9 @@ struct SPIRVEmitContext
// > OpTypeFloat
#define CASE(IROP, BITS) \
- case IROP: return emitInst(getSection(SpvLogicalSectionID::Types), inst, SpvOpTypeFloat, kResultID, BITS)
+ case IROP: \
+ return emitTypeInst( \
+ inst, SpvOpTypeFloat, makeArray<SpvWord>(BITS).getView()); \
CASE(kIROp_HalfType, 16);
CASE(kIROp_FloatType, 32);
@@ -905,17 +1059,16 @@ struct SPIRVEmitContext
auto ptrType = as<IRPtrTypeBase>(inst);
if (ptrType->hasAddressSpace())
storageClass = (SpvStorageClass)ptrType->getAddressSpace();
- return emitInst(
- getSection(SpvLogicalSectionID::Types),
- inst,
- SpvOpTypePointer,
- kResultID,
- storageClass,
- inst->getOperand(0));
+ if (storageClass == SpvStorageClassStorageBuffer)
+ ensureExtensionDeclaration(UnownedStringSlice("SPV_KHR_storage_buffer_storage_class"));
+ auto operands = makeArray<SpvWord>(
+ (SpvWord)storageClass, getID(ensureInst(inst->getOperand(0))));
+ return emitTypeInst(
+ inst, SpvOpTypePointer, operands.getView());
}
case kIROp_StructType:
{
- return emitInstCustomOperandFunc(
+ auto spvStructType = emitInstCustomOperandFunc(
getSection(SpvLogicalSectionID::Types), inst, SpvOpTypeStruct, [&]() {
emitOperand(kResultID);
for (auto field : static_cast<IRStructType*>(inst)->getFields())
@@ -924,6 +1077,8 @@ struct SPIRVEmitContext
// TODO: decorate offset
}
});
+ emitDecorations(inst, getID(spvStructType));
+ return spvStructType;
}
case kIROp_VectorType:
{
@@ -1012,6 +1167,12 @@ struct SPIRVEmitContext
//
return emitInst(getSection(SpvLogicalSectionID::Types), inst, SpvOpTypeFunction, kResultID, OperandsOf(inst));
+ case kIROp_RateQualifiedType:
+ {
+ auto result = emitGlobalInst(as<IRRateQualifiedType>(inst)->getValueType());
+ registerInst(inst, result);
+ return result;
+ }
// > OpTypeForwardPointer
case kIROp_Func:
@@ -1046,10 +1207,6 @@ struct SPIRVEmitContext
// it is nullptr, this function will create one.
SpvInst* ensureVectorType(BaseType baseType, IRIntegerValue elementCount, IRVectorType* inst)
{
- VectorTypeKey key = {baseType, elementCount};
- SpvInst* result = nullptr;
- if (m_vectorTypes.TryGetValue(key, result))
- return result;
if (!inst)
{
IRBuilder builder;
@@ -1059,14 +1216,9 @@ struct SPIRVEmitContext
builder.getBasicType(baseType),
builder.getIntValue(builder.getIntType(), elementCount));
}
- result = emitInst(
- getSection(SpvLogicalSectionID::Types),
- inst,
- SpvOpTypeVector,
- kResultID,
- inst->getElementType(),
- (SpvWord)elementCount);
- m_vectorTypes[key] = result;
+ auto operands =
+ makeArray<SpvWord>(getID(ensureInst(inst->getElementType())), (SpvWord)elementCount);
+ auto result = emitTypeInst(inst, SpvOpTypeVector, operands.getView());
return result;
}
@@ -1139,16 +1291,13 @@ struct SPIRVEmitContext
varInst,
SpvDecorationBinding,
(SpvWord)index);
- if (space)
- {
- emitInst(
- getSection(SpvLogicalSectionID::Annotations),
- nullptr,
- SpvOpDecorate,
- varInst,
- SpvDecorationDescriptorSet,
- (SpvWord)space);
- }
+ emitInst(
+ getSection(SpvLogicalSectionID::Annotations),
+ nullptr,
+ SpvOpDecorate,
+ varInst,
+ SpvDecorationDescriptorSet,
+ (SpvWord)space);
break;
default:
break;
@@ -1165,6 +1314,11 @@ struct SPIRVEmitContext
if (ptrType->hasAddressSpace())
storageClass = (SpvStorageClass)ptrType->getAddressSpace();
}
+ if (auto systemValInst = maybeEmitSystemVal(param))
+ {
+ registerInst(param, systemValInst);
+ return systemValInst;
+ }
auto varInst = emitInst(
getSection(SpvLogicalSectionID::GlobalVariables),
param,
@@ -1304,11 +1458,25 @@ struct SPIRVEmitContext
for( auto irBlock : irFunc->getBlocks() )
{
emitInst(spvFunc, irBlock, SpvOpLabel, kResultID);
+
+ // In addition to normal basic blocks,
+ // all loops gets a header block.
+ for (auto irInst : irBlock->getChildren())
+ {
+ if (irInst->getOp() == kIROp_loop)
+ {
+ emitInst(spvFunc, irInst, SpvOpLabel, kResultID);
+ }
+ }
}
// Once all the basic blocks have had instructions allocated
// for them, we go through and fill them in with their bodies.
//
+ // Each loop inst results in a loop header block.
+ // We will defer the emit of the contents in loop header block
+ // until all Phi insts are emitted.
+ List<IRLoop*> pendingLoopInsts;
for( auto irBlock : irFunc->getBlocks() )
{
// Note: because we already created the block above,
@@ -1334,9 +1502,20 @@ struct SPIRVEmitContext
// of the block.
//
emitLocalInst(spvBlock, irInst);
+ if (irInst->getOp() == kIROp_loop)
+ pendingLoopInsts.add(as<IRLoop>(irInst));
}
}
+ // Finally, we generate the body of loop header blocks.
+ for (auto loopInst : pendingLoopInsts)
+ {
+ SpvInst* headerBlock = nullptr;
+ m_mapIRInstToSpvInst.TryGetValue(loopInst, headerBlock);
+ SLANG_ASSERT(headerBlock);
+ emitLoopHeaderBlock(loopInst, headerBlock);
+ }
+
// [3.32.9. Function Instructions]
//
// > OpFunctionEnd
@@ -1356,6 +1535,21 @@ struct SPIRVEmitContext
return spvFunc;
}
+ /// Check if a block is a loop's target block.
+ bool isLoopTargetBlock(IRInst* block, IRInst*& loopInst)
+ {
+ for (auto use = block->firstUse; use; use = use->nextUse)
+ {
+ if (use->getUser()->getOp() == kIROp_loop &&
+ as<IRLoop>(use->getUser())->getTargetBlock() == block)
+ {
+ loopInst = use->getUser();
+ return true;
+ }
+ }
+ return false;
+ }
+
// The instructions that appear inside the basic blocks of
// functions are what we will call "local" instructions.
//
@@ -1367,13 +1561,6 @@ struct SPIRVEmitContext
/// Emit an instruction that is local to the body of the given `parent`.
SpvInst* emitLocalInst(SpvInstParent* parent, IRInst* inst)
{
- auto getBlockID = [=](IRBlock* block)
- {
- SpvInst* spvInst = nullptr;
- m_mapIRInstToSpvInst.TryGetValue(block, spvInst);
- SLANG_ASSERT(spvInst);
- return getID(spvInst);
- };
switch( inst->getOp() )
{
default:
@@ -1401,6 +1588,9 @@ struct SPIRVEmitContext
return emitSwizzle(parent, as<IRSwizzle>(inst));
case kIROp_Construct:
return emitConstruct(parent, inst);
+ case kIROp_BitCast:
+ return emitInst(
+ parent, inst, SpvOpBitcast, inst->getDataType(), kResultID, inst->getOperand(0));
case kIROp_Add:
case kIROp_Sub:
case kIROp_Mul:
@@ -1432,50 +1622,49 @@ struct SPIRVEmitContext
case kIROp_discard:
return emitInst(parent, inst, SpvOpKill);
case kIROp_unconditionalBranch:
- return emitInst(
- parent,
- inst,
- SpvOpBranch,
- getBlockID(as<IRUnconditionalBranch>(inst)->getTargetBlock()));
- case kIROp_loop:
{
- auto loopInst = as<IRLoop>(inst);
-
- SpvWord loopControl = 0;
- if (auto loopControlDecoration =
- loopInst->findDecoration<IRLoopControlDecoration>())
+ // If we are jumping to the main block of a loop,
+ // emit a branch to the loop header instead.
+ // The SPV id of the resulting loop header block is associated with the loop inst.
+ auto targetBlock = as<IRUnconditionalBranch>(inst)->getTargetBlock();
+ IRInst* loopInst = nullptr;
+ if (isLoopTargetBlock(targetBlock, loopInst))
{
- switch (loopControlDecoration->getMode())
- {
- case IRLoopControl::kIRLoopControl_Unroll:
- loopControl = 0x1;
- break;
- case IRLoopControl::kIRLoopControl_Loop:
- loopControl = 0x2;
- break;
- default:
- break;
- }
+ return emitInst(parent, inst, SpvOpBranch, getIRInstSpvID(loopInst));
}
- emitInst(
+ // Otherwise, emit a normal branch inst into the target block.
+ return emitInst(
parent,
- nullptr,
- SpvOpLoopMerge,
- getBlockID(loopInst->getBreakBlock()),
- getBlockID(loopInst->getContinueBlock()),
- loopControl);
-
- return emitInst(parent, inst, SpvOpBranch, loopInst->getTargetBlock());
+ inst,
+ SpvOpBranch,
+ getIRInstSpvID(targetBlock));
+ }
+ case kIROp_loop:
+ {
+ // Return loop header block in its own block.
+ auto blockId = getIRInstSpvID(inst);
+ SpvInst* block = nullptr;
+ m_mapIRInstToSpvInst.TryGetValue(inst, block);
+ SLANG_ASSERT(block);
+
+ // Emit a jump to the loop header block.
+ // Note: the body of the loop header block is emitted
+ // after everything else to ensure Phi instructions (which come
+ // from the actual loop target block) are emitted first.
+ emitInst(parent, nullptr, SpvOpBranch, blockId);
+
+ return block;
}
case kIROp_ifElse:
{
auto ifelseInst = as<IRIfElse>(inst);
- auto afterBlockID = getBlockID(ifelseInst->getAfterBlock());
+ auto afterBlockID = getIRInstSpvID(ifelseInst->getAfterBlock());
emitInst(
parent,
nullptr,
SpvOpSelectionMerge,
- afterBlockID);
+ afterBlockID,
+ 0);
auto falseLabel = ifelseInst->getFalseBlock();
return emitInst(
parent,
@@ -1488,11 +1677,8 @@ struct SPIRVEmitContext
case kIROp_Switch:
{
auto switchInst = as<IRSwitch>(inst);
- auto mergeBlockID = getBlockID(switchInst->getBreakLabel());
- emitInst(
- parent,
- nullptr,
- SpvOpSelectionMerge, mergeBlockID);
+ auto mergeBlockID = getIRInstSpvID(switchInst->getBreakLabel());
+ emitInst(parent, nullptr, SpvOpSelectionMerge, mergeBlockID, 0);
return emitInstCustomOperandFunc(parent, inst, SpvOpSwitch, [&]() {
emitOperand(switchInst->getCondition());
auto defaultLabel = switchInst->getDefaultLabel();
@@ -1685,7 +1871,23 @@ struct SPIRVEmitContext
auto entryPointDecor = cast<IREntryPointDecoration>(decoration);
auto spvStage = mapStageToExecutionModel(entryPointDecor->getProfile().getStage());
auto name = entryPointDecor->getName()->getStringSlice();
- emitInst(section, decoration, SpvOpEntryPoint, spvStage, dstID, name);
+ emitInstCustomOperandFunc(section, decoration, SpvOpEntryPoint, [&]() {
+ emitOperand(spvStage);
+ emitOperand(dstID);
+ emitOperand(name);
+ // `interface` part: reference all global variables that are used by this entrypoint.
+ // TODO: we may want to perform more accurate tracking.
+ for (auto globalInst : m_irModule->getModuleInst()->getChildren())
+ {
+ switch (globalInst->getOp())
+ {
+ case kIROp_GlobalVar:
+ case kIROp_GlobalParam:
+ emitOperand(getIRInstSpvID(globalInst));
+ break;
+ }
+ }
+ });
}
break;
@@ -1713,6 +1915,24 @@ struct SPIRVEmitContext
}
break;
+ case kIROp_SPIRVBufferBlockDecoration:
+ {
+ emitInst(
+ getSection(SpvLogicalSectionID::Annotations),
+ decoration,
+ SpvOpDecorate,
+ dstID,
+ SpvDecorationBlock);
+ emitInst(
+ getSection(SpvLogicalSectionID::Annotations),
+ nullptr,
+ SpvOpMemberDecorate,
+ dstID,
+ 0,
+ SpvDecorationOffset,
+ 0);
+ }
+ break;
// ...
}
}
@@ -1742,14 +1962,26 @@ struct SPIRVEmitContext
}
}
- SpvInst* emitBuiltinSystemVal(SpvInstParent* parent, IRInst* inst, SpvBuiltIn builtinVal)
+ Dictionary<SpvBuiltIn, SpvInst*> m_builtinGlobalVars;
+ SpvInst* getBuiltinGlobalVar(IRType* type, SpvBuiltIn builtinVal)
{
+ SpvInst* result = nullptr;
+ if (m_builtinGlobalVars.TryGetValue(builtinVal, result))
+ {
+ return result;
+ }
IRBuilder builder;
builder.sharedBuilder = &m_sharedIRBuilder;
- builder.setInsertBefore(inst);
-
- auto ptrIRType = builder.getPtrType(inst->getDataType());
- auto varInst = emitInst(parent, inst, SpvOpVariable, ptrIRType, kResultID);
+ builder.setInsertBefore(type);
+ auto ptrType = as<IRPtrTypeBase>(type);
+ SLANG_ASSERT(ptrType && "`getBuiltinGlobalVar`: `type` must be ptr type.");
+ auto varInst = emitInst(
+ getSection(SpvLogicalSectionID::GlobalVariables),
+ nullptr,
+ SpvOpVariable,
+ type,
+ kResultID,
+ (SpvStorageClass)ptrType->getAddressSpace());
emitInst(
getSection(SpvLogicalSectionID::Annotations),
nullptr,
@@ -1757,11 +1989,15 @@ struct SPIRVEmitContext
varInst,
SpvDecorationBuiltIn,
builtinVal);
+ m_builtinGlobalVars[builtinVal] = varInst;
return varInst;
}
- SpvInst* emitParam(SpvInstParent* parent, IRInst* inst)
+ SpvInst* maybeEmitSystemVal(IRInst* inst)
{
+ IRBuilder builder;
+ builder.sharedBuilder = &m_sharedIRBuilder;
+ builder.setInsertBefore(inst);
if (auto layout = getVarLayout(inst))
{
if (auto systemValueAttr = layout->findAttr<IRSystemValueSemanticAttr>())
@@ -1770,27 +2006,26 @@ struct SPIRVEmitContext
semanticName = semanticName.toLower();
if (semanticName == "sv_dispatchthreadid")
{
- return emitBuiltinSystemVal(parent, inst, SpvBuiltInGlobalInvocationId);
+ return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInGlobalInvocationId);
}
}
}
+ return nullptr;
+ }
+
+ SpvInst* emitParam(SpvInstParent* parent, IRInst* inst)
+ {
return emitInst(parent, inst, SpvOpFunctionParameter, inst->getFullType(), kResultID);
}
SpvInst* emitVar(SpvInstParent* parent, IRInst* inst)
{
- SpvWord storageClass = SpvStorageClassFunction;
- auto rate = inst->getFullType()->getRate();
- if (rate)
+ auto ptrType = as<IRPtrTypeBase>(inst->getDataType());
+ SLANG_ASSERT(ptrType);
+ SpvStorageClass storageClass = SpvStorageClassFunction;
+ if (ptrType->hasAddressSpace())
{
- switch (rate->getOp())
- {
- case kIROp_GroupSharedRate:
- storageClass = SpvStorageClassWorkgroup;
- break;
- default:
- break;
- }
+ storageClass = (SpvStorageClass)ptrType->getAddressSpace();
}
return emitInst(parent, inst, SpvOpVariable, inst->getFullType(), kResultID, storageClass);
}
@@ -1828,6 +2063,48 @@ struct SPIRVEmitContext
return result;
}
+ bool isGlobalValueInst(IRInst* inst)
+ {
+ if (as<IRConstant>(inst))
+ return true;
+ switch (inst->getOp())
+ {
+ case kIROp_Func:
+ case kIROp_GlobalParam:
+ case kIROp_GlobalVar:
+ return true;
+ default:
+ return false;
+ }
+ }
+
+ void emitLoopHeaderBlock(IRLoop* loopInst, SpvInst* loopHeaderBlock)
+ {
+ SpvWord loopControl = 0;
+ if (auto loopControlDecoration = loopInst->findDecoration<IRLoopControlDecoration>())
+ {
+ switch (loopControlDecoration->getMode())
+ {
+ case IRLoopControl::kIRLoopControl_Unroll:
+ loopControl = 0x1;
+ break;
+ case IRLoopControl::kIRLoopControl_Loop:
+ loopControl = 0x2;
+ break;
+ default:
+ break;
+ }
+ }
+ emitInst(
+ loopHeaderBlock,
+ nullptr,
+ SpvOpLoopMerge,
+ getIRInstSpvID(loopInst->getBreakBlock()),
+ getIRInstSpvID(loopInst->getContinueBlock()),
+ loopControl);
+ emitInst(loopHeaderBlock, nullptr, SpvOpBranch, loopInst->getTargetBlock());
+ }
+
SpvInst* emitPhi(SpvInstParent* parent, IRParam* inst)
{
// An `IRParam` in an ordinary `IRBlock` represents a phi value.
@@ -1838,6 +2115,16 @@ struct SPIRVEmitContext
// First, we find the index of this param.
IRBlock* block = as<IRBlock>(inst->getParent());
+ // Special case: if block is a loop's target block, emit phis into the header block instead.
+ IRInst* loopInst = nullptr;
+ if (isLoopTargetBlock(block, loopInst))
+ {
+ SpvInst* loopSpvBlockInst = nullptr;
+ m_mapIRInstToSpvInst.TryGetValue(loopInst, loopSpvBlockInst);
+ SLANG_ASSERT(loopSpvBlockInst);
+ parent = loopSpvBlockInst;
+ }
+
SLANG_ASSERT(block);
int paramIndex = getParamIndexInBlock(block, inst);
@@ -1865,7 +2152,9 @@ struct SPIRVEmitContext
}
SLANG_ASSERT(argStartIndex + paramIndex < branchInst->getOperandCount());
auto valueInst = branchInst->getOperand(argStartIndex + paramIndex);
- emitOperand(valueInst);
+ if (isGlobalValueInst(valueInst))
+ ensureInst(valueInst);
+ emitOperand(getIRInstSpvID(valueInst));
auto sourceBlock = as<IRBlock>(branchInst->getParent());
SLANG_ASSERT(sourceBlock);
emitOperand(getIRInstSpvID(sourceBlock));
@@ -1901,7 +2190,10 @@ struct SPIRVEmitContext
{
SpvSnippet* snippet = getParsedSpvSnippet(intrinsic);
SpvSnippetEmitContext context;
+ context.irResultType = inst->getDataType();
context.resultType = ensureInst(inst->getFullType());
+ context.isResultTypeFloat = isFloatType(inst->getDataType());
+ context.isResultTypeSigned = isSignedType((IRType*)inst->getDataType());
for (SlangUInt i = 0; i < inst->getArgCount(); i++)
{
auto argInst = ensureInst(inst->getArg(i));
@@ -1933,6 +2225,89 @@ struct SPIRVEmitContext
return emitSpvSnippet(parent, inst, context, snippet);
}
+ Dictionary<SpvSnippet::ASMConstant, SpvInst*> m_spvSnippetConstantInsts;
+
+ // Emit SPV Inst that represents a constant defined in a SpvSnippet.
+ SpvInst* maybeEmitSpvConstant(SpvSnippet::ASMConstant constant)
+ {
+ SpvInst* result = nullptr;
+ if (m_spvSnippetConstantInsts.TryGetValue(constant, result))
+ return result;
+
+ IRBuilder builder;
+ builder.sharedBuilder = &m_sharedIRBuilder;
+ builder.setInsertInto(m_irModule->getModuleInst());
+ switch (constant.type)
+ {
+ case SpvSnippet::ASMType::Float:
+ result = emitFloatConstant(constant.floatValues[0], builder.getType(kIROp_FloatType));
+ break;
+ case SpvSnippet::ASMType::Float2:
+ {
+ auto floatType = builder.getType(kIROp_FloatType);
+ auto element1 = emitFloatConstant(constant.floatValues[0], floatType);
+ auto element2 = emitFloatConstant(constant.floatValues[1], floatType);
+ result = emitInst(
+ getSection(SpvLogicalSectionID::Constants),
+ nullptr,
+ SpvOpConstantComposite,
+ builder.getVectorType(floatType, builder.getIntValue(builder.getIntType(), 2)),
+ kResultID,
+ element1,
+ element2);
+ }
+ case SpvSnippet::ASMType::Int:
+ result = emitIntConstant((IRIntegerValue)constant.intValues[0], builder.getIntType());
+ break;
+ case SpvSnippet::ASMType::UInt2:
+ {
+ auto uintType = builder.getType(kIROp_UIntType);
+ auto element1 = emitIntConstant((IRIntegerValue)constant.intValues[0], uintType);
+ auto element2 = emitIntConstant((IRIntegerValue)constant.intValues[1], uintType);
+ result = emitInst(
+ getSection(SpvLogicalSectionID::Constants),
+ nullptr,
+ SpvOpConstantComposite,
+ builder.getVectorType(uintType, builder.getIntValue(builder.getIntType(), 2)),
+ kResultID,
+ element1,
+ element2);
+ }
+ break;
+ }
+ m_spvSnippetConstantInsts[constant] = result;
+ return result;
+ }
+
+ // Emit SPV Inst that represents a type defined in a SpvSnippet.
+ void emitSpvSnippetASMTypeOperand(SpvSnippet::ASMType type)
+ {
+ IRBuilder builder;
+ builder.sharedBuilder = &m_sharedIRBuilder;
+ builder.setInsertInto(m_irModule->getModuleInst());
+ IRType* irType = nullptr;
+ switch (type)
+ {
+ case SpvSnippet::ASMType::Float:
+ irType = builder.getType(kIROp_FloatType);
+ break;
+ case SpvSnippet::ASMType::Int:
+ irType = builder.getIntType();
+ break;
+ case SpvSnippet::ASMType::Float2:
+ irType = builder.getVectorType(
+ builder.getType(kIROp_FloatType), builder.getIntValue(builder.getIntType(), 2));
+ break;
+ case SpvSnippet::ASMType::UInt2:
+ irType = builder.getVectorType(
+ builder.getType(kIROp_UIntType), builder.getIntValue(builder.getIntType(), 2));
+ break;
+ default:
+ break;
+ }
+ emitOperand(irType);
+ }
+
SpvInst* emitSpvSnippet(
SpvInstParent* parent,
IRCall* inst,
@@ -1950,11 +2325,10 @@ struct SPIRVEmitContext
switch (operand.type)
{
case SpvSnippet::ASMOperandType::SpvWord:
- emitOperand((SpvWord)operand.content);
+ emitOperand(operand.content);
break;
case SpvSnippet::ASMOperandType::ObjectReference:
- SLANG_ASSERT(
- operand.content >= 0 && operand.content < context.argumentIds.getCount());
+ SLANG_ASSERT(operand.content < (SpvWord)context.argumentIds.getCount());
emitOperand(context.argumentIds[operand.content]);
break;
case SpvSnippet::ASMOperandType::ResultId:
@@ -1972,8 +2346,64 @@ struct SPIRVEmitContext
}
break;
case SpvSnippet::ASMOperandType::InstReference:
- SLANG_ASSERT(operand.content >= 0 && operand.content < emittedInsts.getCount());
- emitOperand(getID(emittedInsts[operand.content]));
+ SLANG_ASSERT(operand.content < (SpvWord)emittedInsts.getCount());
+ emitOperand(emittedInsts[operand.content]);
+ break;
+ case SpvSnippet::ASMOperandType::GLSL450ExtInstSet:
+ emitOperand(getGLSL450ExtInst());
+ break;
+ case SpvSnippet::ASMOperandType::FloatIntegerSelection:
+ if (context.isResultTypeFloat)
+ {
+ emitOperand(operand.content);
+ }
+ else
+ {
+ emitOperand(operand.content2);
+ }
+ break;
+ case SpvSnippet::ASMOperandType::FloatUnsignedSignedSelection:
+ if (context.isResultTypeFloat)
+ {
+ emitOperand(operand.content);
+ }
+ else
+ {
+ if (context.isResultTypeSigned)
+ {
+ emitOperand(operand.content3);
+ }
+ else
+ {
+ emitOperand(operand.content2);
+ }
+ }
+ break;
+ case SpvSnippet::ASMOperandType::TypeReference:
+ {
+ emitSpvSnippetASMTypeOperand((SpvSnippet::ASMType)operand.content);
+ }
+ break;
+ case SpvSnippet::ASMOperandType::ConstantReference:
+ {
+ auto constant = snippet->constants[operand.content];
+ if (constant.type == SpvSnippet::ASMType::FloatOrDouble)
+ {
+ switch (extractBaseType(context.irResultType))
+ {
+ case BaseType::Float:
+ constant.type = SpvSnippet::ASMType::Float;
+ break;
+ case BaseType::Double:
+ constant.type = SpvSnippet::ASMType::Double;
+ break;
+ default:
+ break;
+ }
+ }
+ SpvInst* spvConstant = maybeEmitSpvConstant(constant);
+ emitOperand(spvConstant);
+ }
break;
}
}
@@ -2048,7 +2478,7 @@ struct SPIRVEmitContext
baseId = getID(varInst);
}
SLANG_ASSERT(baseStructType && "field_address require base to be a struct.");
- auto fieldId = emitConstant(
+ auto fieldId = emitIntConstant(
getStructFieldId(baseStructType, as<IRStructKey>(fieldAddress->getField())),
builder.getIntType());
return emitInst(
@@ -2069,7 +2499,7 @@ struct SPIRVEmitContext
IRStructType* baseStructType = as<IRStructType>(inst->getBase()->getDataType());
SLANG_ASSERT(baseStructType && "field_extract require base to be a struct.");
- auto fieldId = emitConstant(
+ auto fieldId = emitIntConstant(
getStructFieldId(baseStructType, as<IRStructKey>(inst->getField())),
builder.getIntType());
@@ -2163,17 +2593,31 @@ struct SPIRVEmitContext
SpvInst* emitSwizzle(SpvInstParent* parent, IRSwizzle* inst)
{
- return emitInstCustomOperandFunc(parent, inst, SpvOpVectorShuffle, [&]() {
- emitOperand(inst->getDataType());
- emitOperand(kResultID);
- emitOperand(inst->getBase());
- emitOperand(inst->getBase());
- for (UInt i = 0; i < inst->getElementCount(); i++)
- {
- auto index = as<IRIntLit>(inst->getElementIndex(i));
- emitOperand((SpvWord)index->getValue());
- }
- });
+ if (inst->getElementCount() == 1)
+ {
+ return emitInst(
+ parent,
+ inst,
+ SpvOpCompositeExtract,
+ inst->getDataType(),
+ kResultID,
+ inst->getBase(),
+ (SpvWord)as<IRIntLit>(inst->getElementIndex(0))->getValue());
+ }
+ else
+ {
+ return emitInstCustomOperandFunc(parent, inst, SpvOpVectorShuffle, [&]() {
+ emitOperand(inst->getDataType());
+ emitOperand(kResultID);
+ emitOperand(inst->getBase());
+ emitOperand(inst->getBase());
+ for (UInt i = 0; i < inst->getElementCount(); i++)
+ {
+ auto index = as<IRIntLit>(inst->getElementIndex(i));
+ emitOperand((SpvWord)index->getValue());
+ }
+ });
+ }
}
SpvInst* emitConstruct(SpvInstParent* parent, IRInst* inst)
@@ -2183,9 +2627,21 @@ struct SPIRVEmitContext
if (inst->getOperandCount() == 1)
{
if (inst->getDataType() == inst->getOperand(0)->getDataType())
- return emitInst(parent, inst, SpvOpCopyObject, kResultID, inst->getOperand(0));
+ return emitInst(
+ parent,
+ inst,
+ SpvOpCopyObject,
+ inst->getFullType(),
+ kResultID,
+ inst->getOperand(0));
else
- return emitInst(parent, inst, SpvOpBitcast, inst->getDataType(), kResultID, inst->getOperand(0));
+ return emitInst(
+ parent,
+ inst,
+ SpvOpBitcast,
+ inst->getFullType(),
+ kResultID,
+ inst->getOperand(0));
}
else
{
@@ -2205,18 +2661,39 @@ struct SPIRVEmitContext
}
}
- bool isSignedType(IRBasicType* basicType)
+ bool isSignedType(IRType* type)
{
- switch (basicType->getBaseType())
+ switch (type->getOp())
{
- case BaseType::Float:
- case BaseType::Double:
+ case kIROp_FloatType:
+ case kIROp_DoubleType:
return true;
- case BaseType::Int:
- case BaseType::Int16:
- case BaseType::Int64:
- case BaseType::Int8:
+ case kIROp_IntType:
+ case kIROp_Int16Type:
+ case kIROp_Int64Type:
+ case kIROp_Int8Type:
return true;
+ case kIROp_VectorType:
+ return isSignedType(as<IRVectorType>(type)->getElementType());
+ case kIROp_MatrixType:
+ return isSignedType(as<IRMatrixType>(type)->getElementType());
+ default:
+ return false;
+ }
+ }
+
+ bool isFloatType(IRInst* type)
+ {
+ switch (type->getOp())
+ {
+ case kIROp_FloatType:
+ case kIROp_DoubleType:
+ case kIROp_HalfType:
+ return true;
+ case kIROp_VectorType:
+ return isFloatType(as<IRVectorType>(type)->getElementType());
+ case kIROp_MatrixType:
+ return isFloatType(as<IRMatrixType>(type)->getElementType());
default:
return false;
}
@@ -2224,7 +2701,7 @@ struct SPIRVEmitContext
SpvInst* emitArithmetic(SpvInstParent* parent, IRInst* inst)
{
- IRType* elementType = inst->getDataType();
+ IRType* elementType = inst->getOperand(0)->getDataType();
if (auto vectorType = as<IRVectorType>(inst->getDataType()))
{
elementType = vectorType->getElementType();
@@ -2245,6 +2722,7 @@ struct SPIRVEmitContext
break;
case BaseType::Bool:
isBool = true;
+ break;
default:
break;
}
@@ -2371,6 +2849,12 @@ struct SPIRVEmitContext
}
}
+ void diagnoseUnhandledInst(IRInst* inst)
+ {
+ m_sink->diagnose(
+ inst, Diagnostics::unimplemented, "unexpected IR opcode during code emit");
+ }
+
SPIRVEmitContext(IRModule* module, TargetRequest* target, DiagnosticSink* sink)
: SPIRVEmitSharedContext(module, target)
, m_irModule(module)
@@ -2390,7 +2874,7 @@ SlangResult emitSPIRVFromIR(
spirvOut.clear();
SPIRVEmitContext context(irModule, targetRequest, compileRequest->getSink());
- legalizeIRForSPIRV(&context, irModule, compileRequest->getSink());
+ legalizeIRForSPIRV(&context, irModule, irEntryPoints, compileRequest->getSink());
context.emitFrontMatter();
for (auto irEntryPoint : irEntryPoints)