diff options
Diffstat (limited to 'source/slang/slang-emit-spirv.cpp')
| -rw-r--r-- | source/slang/slang-emit-spirv.cpp | 325 |
1 files changed, 213 insertions, 112 deletions
diff --git a/source/slang/slang-emit-spirv.cpp b/source/slang/slang-emit-spirv.cpp index 096e7d8bc..32d3ba7c3 100644 --- a/source/slang/slang-emit-spirv.cpp +++ b/source/slang/slang-emit-spirv.cpp @@ -775,6 +775,147 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex m_operandStack.setCount(operandsStartIndex); } + SpvOp _specConstantOpcodeConvert(IROp irOpCode, IRBasicType* basicType) + { + SpvOp opCode = SpvOpUndef; + opCode = _arithmeticOpCodeConvert(irOpCode, basicType); + if (opCode == SpvOpUndef) + { + switch (irOpCode) + { + case kIROp_IntCast: + { + auto typeStyle = getTypeStyle(basicType->getBaseType()); + if (typeStyle == kIROp_FloatType) + { + return SpvOpConvertFToU; + } + else if (typeStyle == kIROp_IntType) + { + return SpvOpUConvert; + } + break; + } + default: + break; + } + return opCode; + } + return opCode; + } + + SpvOp _arithmeticOpCodeConvert(IROp irOpCode, IRBasicType* basicType) + { + bool isFloatingPoint = false; + bool isBool = false; + switch (basicType->getBaseType()) + { + case BaseType::Float: + case BaseType::Double: + case BaseType::Half: + isFloatingPoint = true; + break; + case BaseType::Bool: + isBool = true; + break; + default: + break; + } + bool isSigned = isSignedType(basicType); + SpvOp opCode = SpvOpUndef; + switch (irOpCode) + { + case kIROp_Add: + opCode = isFloatingPoint ? SpvOpFAdd : SpvOpIAdd; + break; + case kIROp_Sub: + opCode = isFloatingPoint ? SpvOpFSub : SpvOpISub; + break; + case kIROp_Mul: + opCode = isFloatingPoint ? SpvOpFMul : SpvOpIMul; + break; + case kIROp_Div: + opCode = isFloatingPoint ? SpvOpFDiv : isSigned ? SpvOpSDiv : SpvOpUDiv; + break; + case kIROp_IRem: + opCode = isSigned ? SpvOpSRem : SpvOpUMod; + break; + case kIROp_FRem: + opCode = SpvOpFRem; + break; + case kIROp_Less: + opCode = isFloatingPoint ? SpvOpFOrdLessThan + : isSigned ? SpvOpSLessThan + : SpvOpULessThan; + break; + case kIROp_Leq: + opCode = isFloatingPoint ? SpvOpFOrdLessThanEqual + : isSigned ? SpvOpSLessThanEqual + : SpvOpULessThanEqual; + break; + case kIROp_Eql: + opCode = isFloatingPoint ? SpvOpFOrdEqual : isBool ? SpvOpLogicalEqual : SpvOpIEqual; + break; + case kIROp_Neq: + opCode = isFloatingPoint ? SpvOpFUnordNotEqual + : isBool ? SpvOpLogicalNotEqual + : SpvOpINotEqual; + break; + case kIROp_Geq: + opCode = isFloatingPoint ? SpvOpFOrdGreaterThanEqual + : isSigned ? SpvOpSGreaterThanEqual + : SpvOpUGreaterThanEqual; + break; + case kIROp_Greater: + opCode = isFloatingPoint ? SpvOpFOrdGreaterThan + : isSigned ? SpvOpSGreaterThan + : SpvOpUGreaterThan; + break; + case kIROp_Neg: + opCode = isFloatingPoint ? SpvOpFNegate : SpvOpSNegate; + break; + case kIROp_And: + opCode = SpvOpLogicalAnd; + break; + case kIROp_Or: + opCode = SpvOpLogicalOr; + break; + case kIROp_Not: + opCode = SpvOpLogicalNot; + break; + case kIROp_BitAnd: + if (isBool) + opCode = SpvOpLogicalAnd; + else + opCode = SpvOpBitwiseAnd; + break; + case kIROp_BitOr: + if (isBool) + opCode = SpvOpLogicalOr; + else + opCode = SpvOpBitwiseOr; + break; + case kIROp_BitXor: + if (isBool) + opCode = SpvOpLogicalNotEqual; + else + opCode = SpvOpBitwiseXor; + break; + case kIROp_BitNot: + if (isBool) + opCode = SpvOpLogicalNot; + else + opCode = SpvOpNot; + break; + case kIROp_Rsh: + opCode = isSigned ? SpvOpShiftRightArithmetic : SpvOpShiftRightLogical; + break; + case kIROp_Lsh: + opCode = SpvOpShiftLeftLogical; + break; + } + return opCode; + } /// Ensure that an instruction has been emitted SpvInst* ensureInst(IRInst* irInst) { @@ -1972,8 +2113,12 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex as<IRDebugInlinedAt>(inst)); default: { - if (as<IRSPIRVAsmOperand>(inst)) + if (isSpecConstRateType(inst->getFullType())) + return emitSpecializationConstantOp(inst); + + else if (as<IRSPIRVAsmOperand>(inst)) return nullptr; + String e = "Unhandled global inst in spirv-emit:\n" + dumpIRToString(inst, {IRDumpOptions::Mode::Detailed, 0}); SLANG_UNIMPLEMENTED_X(e.begin()); @@ -2756,6 +2901,66 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex return result; } + SpvInst* emitSpecializationConstantOp(IRInst* inst) + { + SpvInst* spv = nullptr; + if (m_mapIRInstToSpvInst.tryGetValue(inst, spv)) + return spv; + + // For each OpSpecConstantOp, the operand must be: + // 1. A specialization constant + // 2. A literal constant + // 3. Another OpSpecConstantOp + + // For 1 and 2, we can just emit the specialization constant or literal constant. + if (auto param = as<IRGlobalParam>(inst)) + { + auto layout = getVarLayout(param); + if (layout) + { + if (auto offset = + layout->findOffsetAttr(LayoutResourceKind::SpecializationConstant)) + { + return emitSpecializationConstant(param, offset); + } + } + SLANG_UNREACHABLE("Non specialization constant used in OpSpecConstantOp\n"); + } + else if (as<IRConstant>(inst)) + { + // We need to emit the constant as a specialization constant + return emitLit(inst); + } + + IRType* type = inst->getOperand(0)->getDataType(); + IRBasicType* basicType = as<IRBasicType>(type); + SpvOp opCode = _specConstantOpcodeConvert(inst->getOp(), basicType); + if (opCode == SpvOpUndef) + { + String e = "Unhandled inst in spirv-emit:\n" + + dumpIRToString(inst, {IRDumpOptions::Mode::Detailed, 0}); + SLANG_UNIMPLEMENTED_X(e.getBuffer()); + } + + Array<SpvInst*, 3> operands; + for (UInt i = 0; i < inst->getOperandCount(); i++) + { + auto operand = inst->getOperand(i); + SpvInst* spvInst = emitSpecializationConstantOp(operand); + operands.add(spvInst); + } + + auto resultType = inst->getFullType(); + return emitInst( + getSection(SpvLogicalSectionID::ConstantsAndTypes), + inst, + SpvOpSpecConstantOp, + resultType, + kResultID, + opCode, + operands); + } + /// Emit a global parameter definition. SpvInst* emitGlobalParam(IRGlobalParam* param) { @@ -7197,117 +7402,13 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex { IRType* elementType = getVectorOrCoopMatrixElementType(operands[0]->getDataType()); IRBasicType* basicType = as<IRBasicType>(elementType); - bool isFloatingPoint = false; - bool isBool = false; - switch (basicType->getBaseType()) - { - case BaseType::Float: - case BaseType::Double: - case BaseType::Half: - isFloatingPoint = true; - break; - case BaseType::Bool: - isBool = true; - break; - default: - break; - } - SpvOp opCode = SpvOpUndef; - bool isSigned = isSignedType(basicType); - switch (op) - { - case kIROp_Add: - opCode = isFloatingPoint ? SpvOpFAdd : SpvOpIAdd; - break; - case kIROp_Sub: - opCode = isFloatingPoint ? SpvOpFSub : SpvOpISub; - break; - case kIROp_Mul: - opCode = isFloatingPoint ? SpvOpFMul : SpvOpIMul; - break; - case kIROp_Div: - opCode = isFloatingPoint ? SpvOpFDiv : isSigned ? SpvOpSDiv : SpvOpUDiv; - break; - case kIROp_IRem: - opCode = isSigned ? SpvOpSRem : SpvOpUMod; - break; - case kIROp_FRem: - opCode = SpvOpFRem; - break; - case kIROp_Less: - opCode = isFloatingPoint ? SpvOpFOrdLessThan - : isSigned ? SpvOpSLessThan - : SpvOpULessThan; - break; - case kIROp_Leq: - opCode = isFloatingPoint ? SpvOpFOrdLessThanEqual - : isSigned ? SpvOpSLessThanEqual - : SpvOpULessThanEqual; - break; - case kIROp_Eql: - opCode = isFloatingPoint ? SpvOpFOrdEqual : isBool ? SpvOpLogicalEqual : SpvOpIEqual; - break; - case kIROp_Neq: - opCode = isFloatingPoint ? SpvOpFUnordNotEqual - : isBool ? SpvOpLogicalNotEqual - : SpvOpINotEqual; - break; - case kIROp_Geq: - opCode = isFloatingPoint ? SpvOpFOrdGreaterThanEqual - : isSigned ? SpvOpSGreaterThanEqual - : SpvOpUGreaterThanEqual; - break; - case kIROp_Greater: - opCode = isFloatingPoint ? SpvOpFOrdGreaterThan - : isSigned ? SpvOpSGreaterThan - : SpvOpUGreaterThan; - break; - case kIROp_Neg: - opCode = isFloatingPoint ? SpvOpFNegate : SpvOpSNegate; - break; - case kIROp_And: - opCode = SpvOpLogicalAnd; - break; - case kIROp_Or: - opCode = SpvOpLogicalOr; - break; - case kIROp_Not: - opCode = SpvOpLogicalNot; - break; - case kIROp_BitAnd: - if (isBool) - opCode = SpvOpLogicalAnd; - else - opCode = SpvOpBitwiseAnd; - break; - case kIROp_BitOr: - if (isBool) - opCode = SpvOpLogicalOr; - else - opCode = SpvOpBitwiseOr; - break; - case kIROp_BitXor: - if (isBool) - opCode = SpvOpLogicalNotEqual; - else - opCode = SpvOpBitwiseXor; - break; - case kIROp_BitNot: - if (isBool) - opCode = SpvOpLogicalNot; - else - opCode = SpvOpNot; - break; - case kIROp_Rsh: - opCode = isSigned ? SpvOpShiftRightArithmetic : SpvOpShiftRightLogical; - break; - case kIROp_Lsh: - opCode = SpvOpShiftLeftLogical; - break; - default: + + SpvOp opCode = _arithmeticOpCodeConvert(op, basicType); + if (opCode == SpvOpUndef) SLANG_ASSERT(!"unknown arithmetic opcode"); - break; - } + + bool isFloatingPoint = (getTypeStyle(basicType->getBaseType()) == kIROp_FloatType); + if (operandCount == 1) { return emitInst(parent, instToRegister, opCode, type, kResultID, operands); @@ -7846,7 +7947,7 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex emitDebugType(arrayType->getElementType()), sizedArrayType ? builder.getIntValue( builder.getUIntType(), - getIntVal(sizedArrayType->getElementCount())) + getArraySizeVal(sizedArrayType->getElementCount())) : builder.getIntValue(builder.getUIntType(), 0)); } else if (auto vectorType = as<IRVectorType>(type)) |
