summaryrefslogtreecommitdiff
path: root/source/slang/slang-emit-spirv.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'source/slang/slang-emit-spirv.cpp')
-rw-r--r--source/slang/slang-emit-spirv.cpp325
1 files changed, 213 insertions, 112 deletions
diff --git a/source/slang/slang-emit-spirv.cpp b/source/slang/slang-emit-spirv.cpp
index 096e7d8bc..32d3ba7c3 100644
--- a/source/slang/slang-emit-spirv.cpp
+++ b/source/slang/slang-emit-spirv.cpp
@@ -775,6 +775,147 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
m_operandStack.setCount(operandsStartIndex);
}
+ SpvOp _specConstantOpcodeConvert(IROp irOpCode, IRBasicType* basicType)
+ {
+ SpvOp opCode = SpvOpUndef;
+ opCode = _arithmeticOpCodeConvert(irOpCode, basicType);
+ if (opCode == SpvOpUndef)
+ {
+ switch (irOpCode)
+ {
+ case kIROp_IntCast:
+ {
+ auto typeStyle = getTypeStyle(basicType->getBaseType());
+ if (typeStyle == kIROp_FloatType)
+ {
+ return SpvOpConvertFToU;
+ }
+ else if (typeStyle == kIROp_IntType)
+ {
+ return SpvOpUConvert;
+ }
+ break;
+ }
+ default:
+ break;
+ }
+ return opCode;
+ }
+ return opCode;
+ }
+
+ SpvOp _arithmeticOpCodeConvert(IROp irOpCode, IRBasicType* basicType)
+ {
+ bool isFloatingPoint = false;
+ bool isBool = false;
+ switch (basicType->getBaseType())
+ {
+ case BaseType::Float:
+ case BaseType::Double:
+ case BaseType::Half:
+ isFloatingPoint = true;
+ break;
+ case BaseType::Bool:
+ isBool = true;
+ break;
+ default:
+ break;
+ }
+ bool isSigned = isSignedType(basicType);
+ SpvOp opCode = SpvOpUndef;
+ switch (irOpCode)
+ {
+ case kIROp_Add:
+ opCode = isFloatingPoint ? SpvOpFAdd : SpvOpIAdd;
+ break;
+ case kIROp_Sub:
+ opCode = isFloatingPoint ? SpvOpFSub : SpvOpISub;
+ break;
+ case kIROp_Mul:
+ opCode = isFloatingPoint ? SpvOpFMul : SpvOpIMul;
+ break;
+ case kIROp_Div:
+ opCode = isFloatingPoint ? SpvOpFDiv : isSigned ? SpvOpSDiv : SpvOpUDiv;
+ break;
+ case kIROp_IRem:
+ opCode = isSigned ? SpvOpSRem : SpvOpUMod;
+ break;
+ case kIROp_FRem:
+ opCode = SpvOpFRem;
+ break;
+ case kIROp_Less:
+ opCode = isFloatingPoint ? SpvOpFOrdLessThan
+ : isSigned ? SpvOpSLessThan
+ : SpvOpULessThan;
+ break;
+ case kIROp_Leq:
+ opCode = isFloatingPoint ? SpvOpFOrdLessThanEqual
+ : isSigned ? SpvOpSLessThanEqual
+ : SpvOpULessThanEqual;
+ break;
+ case kIROp_Eql:
+ opCode = isFloatingPoint ? SpvOpFOrdEqual : isBool ? SpvOpLogicalEqual : SpvOpIEqual;
+ break;
+ case kIROp_Neq:
+ opCode = isFloatingPoint ? SpvOpFUnordNotEqual
+ : isBool ? SpvOpLogicalNotEqual
+ : SpvOpINotEqual;
+ break;
+ case kIROp_Geq:
+ opCode = isFloatingPoint ? SpvOpFOrdGreaterThanEqual
+ : isSigned ? SpvOpSGreaterThanEqual
+ : SpvOpUGreaterThanEqual;
+ break;
+ case kIROp_Greater:
+ opCode = isFloatingPoint ? SpvOpFOrdGreaterThan
+ : isSigned ? SpvOpSGreaterThan
+ : SpvOpUGreaterThan;
+ break;
+ case kIROp_Neg:
+ opCode = isFloatingPoint ? SpvOpFNegate : SpvOpSNegate;
+ break;
+ case kIROp_And:
+ opCode = SpvOpLogicalAnd;
+ break;
+ case kIROp_Or:
+ opCode = SpvOpLogicalOr;
+ break;
+ case kIROp_Not:
+ opCode = SpvOpLogicalNot;
+ break;
+ case kIROp_BitAnd:
+ if (isBool)
+ opCode = SpvOpLogicalAnd;
+ else
+ opCode = SpvOpBitwiseAnd;
+ break;
+ case kIROp_BitOr:
+ if (isBool)
+ opCode = SpvOpLogicalOr;
+ else
+ opCode = SpvOpBitwiseOr;
+ break;
+ case kIROp_BitXor:
+ if (isBool)
+ opCode = SpvOpLogicalNotEqual;
+ else
+ opCode = SpvOpBitwiseXor;
+ break;
+ case kIROp_BitNot:
+ if (isBool)
+ opCode = SpvOpLogicalNot;
+ else
+ opCode = SpvOpNot;
+ break;
+ case kIROp_Rsh:
+ opCode = isSigned ? SpvOpShiftRightArithmetic : SpvOpShiftRightLogical;
+ break;
+ case kIROp_Lsh:
+ opCode = SpvOpShiftLeftLogical;
+ break;
+ }
+ return opCode;
+ }
/// Ensure that an instruction has been emitted
SpvInst* ensureInst(IRInst* irInst)
{
@@ -1972,8 +2113,12 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
as<IRDebugInlinedAt>(inst));
default:
{
- if (as<IRSPIRVAsmOperand>(inst))
+ if (isSpecConstRateType(inst->getFullType()))
+ return emitSpecializationConstantOp(inst);
+
+ else if (as<IRSPIRVAsmOperand>(inst))
return nullptr;
+
String e = "Unhandled global inst in spirv-emit:\n" +
dumpIRToString(inst, {IRDumpOptions::Mode::Detailed, 0});
SLANG_UNIMPLEMENTED_X(e.begin());
@@ -2756,6 +2901,66 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
return result;
}
+ SpvInst* emitSpecializationConstantOp(IRInst* inst)
+ {
+ SpvInst* spv = nullptr;
+ if (m_mapIRInstToSpvInst.tryGetValue(inst, spv))
+ return spv;
+
+ // For each OpSpecConstantOp, the operand must be:
+ // 1. A specialization constant
+ // 2. A literal constant
+ // 3. Another OpSpecConstantOp
+
+ // For 1 and 2, we can just emit the specialization constant or literal constant.
+ if (auto param = as<IRGlobalParam>(inst))
+ {
+ auto layout = getVarLayout(param);
+ if (layout)
+ {
+ if (auto offset =
+ layout->findOffsetAttr(LayoutResourceKind::SpecializationConstant))
+ {
+ return emitSpecializationConstant(param, offset);
+ }
+ }
+ SLANG_UNREACHABLE("Non specialization constant used in OpSpecConstantOp\n");
+ }
+ else if (as<IRConstant>(inst))
+ {
+ // We need to emit the constant as a specialization constant
+ return emitLit(inst);
+ }
+
+ IRType* type = inst->getOperand(0)->getDataType();
+ IRBasicType* basicType = as<IRBasicType>(type);
+ SpvOp opCode = _specConstantOpcodeConvert(inst->getOp(), basicType);
+ if (opCode == SpvOpUndef)
+ {
+ String e = "Unhandled inst in spirv-emit:\n" +
+ dumpIRToString(inst, {IRDumpOptions::Mode::Detailed, 0});
+ SLANG_UNIMPLEMENTED_X(e.getBuffer());
+ }
+
+ Array<SpvInst*, 3> operands;
+ for (UInt i = 0; i < inst->getOperandCount(); i++)
+ {
+ auto operand = inst->getOperand(i);
+ SpvInst* spvInst = emitSpecializationConstantOp(operand);
+ operands.add(spvInst);
+ }
+
+ auto resultType = inst->getFullType();
+ return emitInst(
+ getSection(SpvLogicalSectionID::ConstantsAndTypes),
+ inst,
+ SpvOpSpecConstantOp,
+ resultType,
+ kResultID,
+ opCode,
+ operands);
+ }
+
/// Emit a global parameter definition.
SpvInst* emitGlobalParam(IRGlobalParam* param)
{
@@ -7197,117 +7402,13 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
{
IRType* elementType = getVectorOrCoopMatrixElementType(operands[0]->getDataType());
IRBasicType* basicType = as<IRBasicType>(elementType);
- bool isFloatingPoint = false;
- bool isBool = false;
- switch (basicType->getBaseType())
- {
- case BaseType::Float:
- case BaseType::Double:
- case BaseType::Half:
- isFloatingPoint = true;
- break;
- case BaseType::Bool:
- isBool = true;
- break;
- default:
- break;
- }
- SpvOp opCode = SpvOpUndef;
- bool isSigned = isSignedType(basicType);
- switch (op)
- {
- case kIROp_Add:
- opCode = isFloatingPoint ? SpvOpFAdd : SpvOpIAdd;
- break;
- case kIROp_Sub:
- opCode = isFloatingPoint ? SpvOpFSub : SpvOpISub;
- break;
- case kIROp_Mul:
- opCode = isFloatingPoint ? SpvOpFMul : SpvOpIMul;
- break;
- case kIROp_Div:
- opCode = isFloatingPoint ? SpvOpFDiv : isSigned ? SpvOpSDiv : SpvOpUDiv;
- break;
- case kIROp_IRem:
- opCode = isSigned ? SpvOpSRem : SpvOpUMod;
- break;
- case kIROp_FRem:
- opCode = SpvOpFRem;
- break;
- case kIROp_Less:
- opCode = isFloatingPoint ? SpvOpFOrdLessThan
- : isSigned ? SpvOpSLessThan
- : SpvOpULessThan;
- break;
- case kIROp_Leq:
- opCode = isFloatingPoint ? SpvOpFOrdLessThanEqual
- : isSigned ? SpvOpSLessThanEqual
- : SpvOpULessThanEqual;
- break;
- case kIROp_Eql:
- opCode = isFloatingPoint ? SpvOpFOrdEqual : isBool ? SpvOpLogicalEqual : SpvOpIEqual;
- break;
- case kIROp_Neq:
- opCode = isFloatingPoint ? SpvOpFUnordNotEqual
- : isBool ? SpvOpLogicalNotEqual
- : SpvOpINotEqual;
- break;
- case kIROp_Geq:
- opCode = isFloatingPoint ? SpvOpFOrdGreaterThanEqual
- : isSigned ? SpvOpSGreaterThanEqual
- : SpvOpUGreaterThanEqual;
- break;
- case kIROp_Greater:
- opCode = isFloatingPoint ? SpvOpFOrdGreaterThan
- : isSigned ? SpvOpSGreaterThan
- : SpvOpUGreaterThan;
- break;
- case kIROp_Neg:
- opCode = isFloatingPoint ? SpvOpFNegate : SpvOpSNegate;
- break;
- case kIROp_And:
- opCode = SpvOpLogicalAnd;
- break;
- case kIROp_Or:
- opCode = SpvOpLogicalOr;
- break;
- case kIROp_Not:
- opCode = SpvOpLogicalNot;
- break;
- case kIROp_BitAnd:
- if (isBool)
- opCode = SpvOpLogicalAnd;
- else
- opCode = SpvOpBitwiseAnd;
- break;
- case kIROp_BitOr:
- if (isBool)
- opCode = SpvOpLogicalOr;
- else
- opCode = SpvOpBitwiseOr;
- break;
- case kIROp_BitXor:
- if (isBool)
- opCode = SpvOpLogicalNotEqual;
- else
- opCode = SpvOpBitwiseXor;
- break;
- case kIROp_BitNot:
- if (isBool)
- opCode = SpvOpLogicalNot;
- else
- opCode = SpvOpNot;
- break;
- case kIROp_Rsh:
- opCode = isSigned ? SpvOpShiftRightArithmetic : SpvOpShiftRightLogical;
- break;
- case kIROp_Lsh:
- opCode = SpvOpShiftLeftLogical;
- break;
- default:
+
+ SpvOp opCode = _arithmeticOpCodeConvert(op, basicType);
+ if (opCode == SpvOpUndef)
SLANG_ASSERT(!"unknown arithmetic opcode");
- break;
- }
+
+ bool isFloatingPoint = (getTypeStyle(basicType->getBaseType()) == kIROp_FloatType);
+
if (operandCount == 1)
{
return emitInst(parent, instToRegister, opCode, type, kResultID, operands);
@@ -7846,7 +7947,7 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
emitDebugType(arrayType->getElementType()),
sizedArrayType ? builder.getIntValue(
builder.getUIntType(),
- getIntVal(sizedArrayType->getElementCount()))
+ getArraySizeVal(sizedArrayType->getElementCount()))
: builder.getIntValue(builder.getUIntType(), 0));
}
else if (auto vectorType = as<IRVectorType>(type))