summaryrefslogtreecommitdiffstats
path: root/source/slang/slang-emit-spirv.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'source/slang/slang-emit-spirv.cpp')
-rw-r--r--source/slang/slang-emit-spirv.cpp157
1 files changed, 149 insertions, 8 deletions
diff --git a/source/slang/slang-emit-spirv.cpp b/source/slang/slang-emit-spirv.cpp
index 9d8c4d89a..3febbd210 100644
--- a/source/slang/slang-emit-spirv.cpp
+++ b/source/slang/slang-emit-spirv.cpp
@@ -1192,6 +1192,10 @@ struct SPIRVEmitContext
case kIROp_DoubleType:
{
const FloatInfo i = getFloatingTypeInfo(as<IRType>(inst));
+ if (inst->getOp() == kIROp_DoubleType)
+ requireSPIRVCapability(SpvCapabilityFloat64);
+ else if (inst->getOp() == kIROp_HalfType)
+ requireSPIRVCapability(SpvCapabilityFloat16);
return emitOpTypeFloat(inst, SpvLiteralInteger::from32(int32_t(i.width)));
}
case kIROp_PtrType:
@@ -1359,12 +1363,34 @@ struct SPIRVEmitContext
//
return emitFunc(as<IRFunc>(inst));
- case kIROp_BoolLit:
- case kIROp_IntLit:
- case kIROp_FloatLit:
- case kIROp_StringLit:
- return emitLit(inst);
-
+ case kIROp_BoolLit:
+ case kIROp_IntLit:
+ case kIROp_FloatLit:
+ case kIROp_StringLit:
+ return emitLit(inst);
+ case kIROp_MakeVectorFromScalar:
+ {
+ const auto scalar = inst->getOperand(0);
+ const auto vecTy = as<IRVectorType>(inst->getDataType());
+ SLANG_ASSERT(vecTy);
+ const auto numElems = as<IRIntLit>(vecTy->getElementCount());
+ SLANG_ASSERT(numElems);
+ return emitSplat(
+ getSection(SpvLogicalSectionID::ConstantsAndTypes),
+ inst,
+ scalar,
+ numElems->getValue());
+ }
+ case kIROp_MakeVector:
+ case kIROp_MakeArray:
+ case kIROp_MakeStruct:
+ return emitCompositeConstruct(getSection(SpvLogicalSectionID::ConstantsAndTypes), inst);
+ case kIROp_MakeArrayFromElement:
+ return emitMakeArrayFromElement(getSection(SpvLogicalSectionID::ConstantsAndTypes), inst);
+ case kIROp_MakeMatrix:
+ return emitMakeMatrix(getSection(SpvLogicalSectionID::ConstantsAndTypes), inst);
+ case kIROp_MakeMatrixFromScalar:
+ return emitMakeMatrixFromScalar(getSection(SpvLogicalSectionID::ConstantsAndTypes), inst);
case kIROp_GlobalParam:
return emitGlobalParam(as<IRGlobalParam>(inst));
case kIROp_GlobalVar:
@@ -1816,6 +1842,12 @@ struct SPIRVEmitContext
return emitGetElement(parent, as<IRGetElement>(inst));
case kIROp_MakeStruct:
return emitCompositeConstruct(parent, inst);
+ case kIROp_MakeArrayFromElement:
+ return emitMakeArrayFromElement(parent, inst);
+ case kIROp_MakeMatrixFromScalar:
+ return emitMakeMatrixFromScalar(parent, inst);
+ case kIROp_MakeMatrix:
+ return emitMakeMatrix(parent, inst);
case kIROp_Load:
return emitLoad(parent, as<IRLoad>(inst));
case kIROp_Store:
@@ -1955,7 +1987,8 @@ struct SPIRVEmitContext
}
case kIROp_MakeArray:
return emitConstruct(parent, inst);
-
+ case kIROp_Select:
+ return emitInst(parent, inst, SpvOpSelect, inst->getFullType(), kResultID, OperandsOf(inst));
case kIROp_DebugLine:
return emitDebugLine(parent, as<IRDebugLine>(inst));
}
@@ -2415,6 +2448,33 @@ struct SPIRVEmitContext
void emitLoopHeaderBlock(IRLoop* loopInst, SpvInst* loopHeaderBlock)
{
+ bool hasBackJump = false;
+ for (auto use = loopInst->getTargetBlock()->firstUse; use; use = use->nextUse)
+ {
+ if (use->getUser() == loopInst)
+ continue;
+ hasBackJump = true;
+ break;
+ }
+ if (!hasBackJump)
+ {
+ // If the loop does not have a back jump, it is used as a breakable region.
+ // SPIRV does not allow loops without a back jump, so we are going to emit
+ // a switch instead.
+ IRBuilder builder(loopInst);
+ builder.setInsertBefore(loopInst);
+ emitOpSelectionMerge(
+ loopHeaderBlock,
+ nullptr,
+ getIRInstSpvID(loopInst->getBreakBlock()),
+ SpvSelectionControlMaskNone
+ );
+ emitInst(loopHeaderBlock, nullptr, SpvOpSwitch,
+ emitIntConstant(0, builder.getIntType()),
+ getIRInstSpvID(loopInst->getTargetBlock()));
+ return;
+ }
+
SpvLoopControlMask loopControl = SpvLoopControlMaskNone;
if (auto loopControlDecoration = loopInst->findDecoration<IRLoopControlDecoration>())
{
@@ -3049,11 +3109,91 @@ struct SPIRVEmitContext
: emitOpConvertFToU(parent, inst, toTypeV, inst->getOperand(0));
}
+ template<typename T, typename Ts>
+ SpvInst* emitCompositeConstruct(
+ SpvInstParent* parent,
+ IRInst* inst,
+ const T& idResultType,
+ const Ts& constituents)
+ {
+ if (parent == getSection(SpvLogicalSectionID::ConstantsAndTypes))
+ return emitOpConstantComposite(parent, inst, idResultType, constituents);
+ return emitOpCompositeConstruct(parent, inst, idResultType, constituents);
+ }
+
SpvInst* emitCompositeConstruct(SpvInstParent* parent, IRInst* inst)
{
+ if (parent == getSection(SpvLogicalSectionID::ConstantsAndTypes))
+ return emitOpConstantComposite(parent, inst, inst->getDataType(), OperandsOf(inst));
return emitOpCompositeConstruct(parent, inst, inst->getDataType(), OperandsOf(inst));
}
+ SpvInst* emitMakeArrayFromElement(SpvInstParent* parent, IRInst* inst)
+ {
+ List<IRInst*> elements;
+ auto arrayType = as<IRArrayType>(inst->getDataType());
+ auto elementCount = getIntVal(arrayType->getElementCount());
+ for (IRIntegerValue i = 0; i < elementCount; i++)
+ {
+ elements.add(inst->getOperand(0));
+ }
+ return emitCompositeConstruct(parent, inst, inst->getDataType(), elements);
+ }
+
+ SpvInst* emitMakeMatrixFromScalar(SpvInstParent* parent, IRInst* inst)
+ {
+ List<SpvInst*> rowVectors;
+ auto matrixType = as<IRMatrixType>(inst->getDataType());
+ auto rowCount = getIntVal(matrixType->getRowCount());
+ auto colCount = getIntVal(matrixType->getColumnCount());
+ IRBuilder builder(inst);
+ builder.setInsertBefore(inst);
+ auto rowVectorType = builder.getVectorType(matrixType->getElementType(), colCount);
+ List<IRInst*> colElements;
+ for (IRIntegerValue i = 0; i < colCount; i++)
+ {
+ colElements.add(inst->getOperand(0));
+ }
+ auto rowVector = emitCompositeConstruct(parent, nullptr, rowVectorType, colElements);
+ for (IRIntegerValue i = 0; i < rowCount; i++)
+ {
+ rowVectors.add(rowVector);
+ }
+ return emitCompositeConstruct(parent, inst, inst->getDataType(), rowVectors);
+ }
+
+ SpvInst* emitMakeMatrix(SpvInstParent* parent, IRInst* inst)
+ {
+ // If operands are already row vectors, use CompositeConstruct directly.
+ if (as<IRVectorType>(inst->getOperand(0)->getDataType()))
+ {
+ return emitCompositeConstruct(parent, inst);
+ }
+ // Otherwise, operands are raw elements, we need to construct row vectors first,
+ // then construct matrix from row vectors.
+ List<SpvInst*> rowVectors;
+ auto matrixType = as<IRMatrixType>(inst->getDataType());
+ auto rowCount = getIntVal(matrixType->getRowCount());
+ auto colCount = getIntVal(matrixType->getColumnCount());
+ IRBuilder builder(inst);
+ builder.setInsertBefore(inst);
+ auto rowVectorType = builder.getVectorType(matrixType->getElementType(), colCount);
+ List<IRInst*> colElements;
+ UInt index = 0;
+ for (IRIntegerValue j = 0; j < rowCount; j++)
+ {
+ colElements.clear();
+ for (IRIntegerValue i = 0; i < colCount; i++)
+ {
+ colElements.add(inst->getOperand(index));
+ index++;
+ }
+ auto rowVector = emitCompositeConstruct(parent, nullptr, rowVectorType, colElements);
+ rowVectors.add(rowVector);
+ }
+ return emitCompositeConstruct(parent, inst, inst->getDataType(), rowVectors);
+ }
+
SpvInst* emitConstruct(SpvInstParent* parent, IRInst* inst)
{
if (as<IRBasicType>(inst->getDataType()))
@@ -3093,7 +3233,7 @@ struct SPIRVEmitContext
scalarTy->getBaseType(),
numElems,
nullptr);
- return emitOpCompositeConstruct(
+ return emitCompositeConstruct(
parent,
inst,
spvVecTy,
@@ -3154,6 +3294,7 @@ struct SPIRVEmitContext
{
case BaseType::Float:
case BaseType::Double:
+ case BaseType::Half:
isFloatingPoint = true;
break;
case BaseType::Bool: