diff options
Diffstat (limited to 'source/slang/slang-emit-spirv.cpp')
| -rw-r--r-- | source/slang/slang-emit-spirv.cpp | 157 |
1 files changed, 149 insertions, 8 deletions
diff --git a/source/slang/slang-emit-spirv.cpp b/source/slang/slang-emit-spirv.cpp index 9d8c4d89a..3febbd210 100644 --- a/source/slang/slang-emit-spirv.cpp +++ b/source/slang/slang-emit-spirv.cpp @@ -1192,6 +1192,10 @@ struct SPIRVEmitContext case kIROp_DoubleType: { const FloatInfo i = getFloatingTypeInfo(as<IRType>(inst)); + if (inst->getOp() == kIROp_DoubleType) + requireSPIRVCapability(SpvCapabilityFloat64); + else if (inst->getOp() == kIROp_HalfType) + requireSPIRVCapability(SpvCapabilityFloat16); return emitOpTypeFloat(inst, SpvLiteralInteger::from32(int32_t(i.width))); } case kIROp_PtrType: @@ -1359,12 +1363,34 @@ struct SPIRVEmitContext // return emitFunc(as<IRFunc>(inst)); - case kIROp_BoolLit: - case kIROp_IntLit: - case kIROp_FloatLit: - case kIROp_StringLit: - return emitLit(inst); - + case kIROp_BoolLit: + case kIROp_IntLit: + case kIROp_FloatLit: + case kIROp_StringLit: + return emitLit(inst); + case kIROp_MakeVectorFromScalar: + { + const auto scalar = inst->getOperand(0); + const auto vecTy = as<IRVectorType>(inst->getDataType()); + SLANG_ASSERT(vecTy); + const auto numElems = as<IRIntLit>(vecTy->getElementCount()); + SLANG_ASSERT(numElems); + return emitSplat( + getSection(SpvLogicalSectionID::ConstantsAndTypes), + inst, + scalar, + numElems->getValue()); + } + case kIROp_MakeVector: + case kIROp_MakeArray: + case kIROp_MakeStruct: + return emitCompositeConstruct(getSection(SpvLogicalSectionID::ConstantsAndTypes), inst); + case kIROp_MakeArrayFromElement: + return emitMakeArrayFromElement(getSection(SpvLogicalSectionID::ConstantsAndTypes), inst); + case kIROp_MakeMatrix: + return emitMakeMatrix(getSection(SpvLogicalSectionID::ConstantsAndTypes), inst); + case kIROp_MakeMatrixFromScalar: + return emitMakeMatrixFromScalar(getSection(SpvLogicalSectionID::ConstantsAndTypes), inst); case kIROp_GlobalParam: return emitGlobalParam(as<IRGlobalParam>(inst)); case kIROp_GlobalVar: @@ -1816,6 +1842,12 @@ struct SPIRVEmitContext return emitGetElement(parent, as<IRGetElement>(inst)); case kIROp_MakeStruct: return emitCompositeConstruct(parent, inst); + case kIROp_MakeArrayFromElement: + return emitMakeArrayFromElement(parent, inst); + case kIROp_MakeMatrixFromScalar: + return emitMakeMatrixFromScalar(parent, inst); + case kIROp_MakeMatrix: + return emitMakeMatrix(parent, inst); case kIROp_Load: return emitLoad(parent, as<IRLoad>(inst)); case kIROp_Store: @@ -1955,7 +1987,8 @@ struct SPIRVEmitContext } case kIROp_MakeArray: return emitConstruct(parent, inst); - + case kIROp_Select: + return emitInst(parent, inst, SpvOpSelect, inst->getFullType(), kResultID, OperandsOf(inst)); case kIROp_DebugLine: return emitDebugLine(parent, as<IRDebugLine>(inst)); } @@ -2415,6 +2448,33 @@ struct SPIRVEmitContext void emitLoopHeaderBlock(IRLoop* loopInst, SpvInst* loopHeaderBlock) { + bool hasBackJump = false; + for (auto use = loopInst->getTargetBlock()->firstUse; use; use = use->nextUse) + { + if (use->getUser() == loopInst) + continue; + hasBackJump = true; + break; + } + if (!hasBackJump) + { + // If the loop does not have a back jump, it is used as a breakable region. + // SPIRV does not allow loops without a back jump, so we are going to emit + // a switch instead. + IRBuilder builder(loopInst); + builder.setInsertBefore(loopInst); + emitOpSelectionMerge( + loopHeaderBlock, + nullptr, + getIRInstSpvID(loopInst->getBreakBlock()), + SpvSelectionControlMaskNone + ); + emitInst(loopHeaderBlock, nullptr, SpvOpSwitch, + emitIntConstant(0, builder.getIntType()), + getIRInstSpvID(loopInst->getTargetBlock())); + return; + } + SpvLoopControlMask loopControl = SpvLoopControlMaskNone; if (auto loopControlDecoration = loopInst->findDecoration<IRLoopControlDecoration>()) { @@ -3049,11 +3109,91 @@ struct SPIRVEmitContext : emitOpConvertFToU(parent, inst, toTypeV, inst->getOperand(0)); } + template<typename T, typename Ts> + SpvInst* emitCompositeConstruct( + SpvInstParent* parent, + IRInst* inst, + const T& idResultType, + const Ts& constituents) + { + if (parent == getSection(SpvLogicalSectionID::ConstantsAndTypes)) + return emitOpConstantComposite(parent, inst, idResultType, constituents); + return emitOpCompositeConstruct(parent, inst, idResultType, constituents); + } + SpvInst* emitCompositeConstruct(SpvInstParent* parent, IRInst* inst) { + if (parent == getSection(SpvLogicalSectionID::ConstantsAndTypes)) + return emitOpConstantComposite(parent, inst, inst->getDataType(), OperandsOf(inst)); return emitOpCompositeConstruct(parent, inst, inst->getDataType(), OperandsOf(inst)); } + SpvInst* emitMakeArrayFromElement(SpvInstParent* parent, IRInst* inst) + { + List<IRInst*> elements; + auto arrayType = as<IRArrayType>(inst->getDataType()); + auto elementCount = getIntVal(arrayType->getElementCount()); + for (IRIntegerValue i = 0; i < elementCount; i++) + { + elements.add(inst->getOperand(0)); + } + return emitCompositeConstruct(parent, inst, inst->getDataType(), elements); + } + + SpvInst* emitMakeMatrixFromScalar(SpvInstParent* parent, IRInst* inst) + { + List<SpvInst*> rowVectors; + auto matrixType = as<IRMatrixType>(inst->getDataType()); + auto rowCount = getIntVal(matrixType->getRowCount()); + auto colCount = getIntVal(matrixType->getColumnCount()); + IRBuilder builder(inst); + builder.setInsertBefore(inst); + auto rowVectorType = builder.getVectorType(matrixType->getElementType(), colCount); + List<IRInst*> colElements; + for (IRIntegerValue i = 0; i < colCount; i++) + { + colElements.add(inst->getOperand(0)); + } + auto rowVector = emitCompositeConstruct(parent, nullptr, rowVectorType, colElements); + for (IRIntegerValue i = 0; i < rowCount; i++) + { + rowVectors.add(rowVector); + } + return emitCompositeConstruct(parent, inst, inst->getDataType(), rowVectors); + } + + SpvInst* emitMakeMatrix(SpvInstParent* parent, IRInst* inst) + { + // If operands are already row vectors, use CompositeConstruct directly. + if (as<IRVectorType>(inst->getOperand(0)->getDataType())) + { + return emitCompositeConstruct(parent, inst); + } + // Otherwise, operands are raw elements, we need to construct row vectors first, + // then construct matrix from row vectors. + List<SpvInst*> rowVectors; + auto matrixType = as<IRMatrixType>(inst->getDataType()); + auto rowCount = getIntVal(matrixType->getRowCount()); + auto colCount = getIntVal(matrixType->getColumnCount()); + IRBuilder builder(inst); + builder.setInsertBefore(inst); + auto rowVectorType = builder.getVectorType(matrixType->getElementType(), colCount); + List<IRInst*> colElements; + UInt index = 0; + for (IRIntegerValue j = 0; j < rowCount; j++) + { + colElements.clear(); + for (IRIntegerValue i = 0; i < colCount; i++) + { + colElements.add(inst->getOperand(index)); + index++; + } + auto rowVector = emitCompositeConstruct(parent, nullptr, rowVectorType, colElements); + rowVectors.add(rowVector); + } + return emitCompositeConstruct(parent, inst, inst->getDataType(), rowVectors); + } + SpvInst* emitConstruct(SpvInstParent* parent, IRInst* inst) { if (as<IRBasicType>(inst->getDataType())) @@ -3093,7 +3233,7 @@ struct SPIRVEmitContext scalarTy->getBaseType(), numElems, nullptr); - return emitOpCompositeConstruct( + return emitCompositeConstruct( parent, inst, spvVecTy, @@ -3154,6 +3294,7 @@ struct SPIRVEmitContext { case BaseType::Float: case BaseType::Double: + case BaseType::Half: isFloatingPoint = true; break; case BaseType::Bool: |
