summaryrefslogtreecommitdiff
path: root/source/slang/slang-emit-spirv.cpp
diff options
context:
space:
mode:
authorkaizhangNV <149626564+kaizhangNV@users.noreply.github.com>2025-05-14 12:11:53 -0500
committerGitHub <noreply@github.com>2025-05-14 10:11:53 -0700
commit375ecfe2903b09f07abeba2eafb88d9a564c1458 (patch)
treea507ffcdbe118f5d69ffb3e6c341d8f954e0bfef /source/slang/slang-emit-spirv.cpp
parent39c9e25f6d728e970b68a9452330e754991b4ac5 (diff)
support specialization constant sized array (#6871)
Close #6859 Goal of this PR We want to support an array whose size can be specialization constant for shared/global variable e.g. layout (constant_id = 0) const uint BLOCK_SIZE = 64; shared float buf_a[(BLOCK_SIZE + 5) * 4]; Overview of the solution: During IndexExpr check, we will loose the restriction to allow SpecConst passing, but the size parameter will not be a constant value because it cannot be folded into a constant, so we will make it follow the same logic as generic parameter value, and the size will be represented by FuncCallIntVal/PolynomialIntVal/DeclRefIntVal. During IR lowering, we will detect whether there is spec constant in the IntVal, and wrap the IRInst with a SpecConstRateType, and propagate the type though the lowering logic, such that the IntVal representing the array size will have SpecConstRateType. During spirv emit stage, if we detect that a IRInst has SpecConstRateType, we will emit it as SpecConstantOp. We have to implement new logic to emit OpSpecConstantOp, the existing emit logic doesn't support emitting OpSpecConstantOp, especially this op can embed arithmetic operation at global scope, where we can only emit arithmetic instruct at local. But there are only few instructs we need to support. Overview of the solution: This PR doesn't support generic, and we will create a separate PR to extend that, tracked in #6840.
Diffstat (limited to 'source/slang/slang-emit-spirv.cpp')
-rw-r--r--source/slang/slang-emit-spirv.cpp325
1 files changed, 213 insertions, 112 deletions
diff --git a/source/slang/slang-emit-spirv.cpp b/source/slang/slang-emit-spirv.cpp
index 096e7d8bc..32d3ba7c3 100644
--- a/source/slang/slang-emit-spirv.cpp
+++ b/source/slang/slang-emit-spirv.cpp
@@ -775,6 +775,147 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
m_operandStack.setCount(operandsStartIndex);
}
+ SpvOp _specConstantOpcodeConvert(IROp irOpCode, IRBasicType* basicType)
+ {
+ SpvOp opCode = SpvOpUndef;
+ opCode = _arithmeticOpCodeConvert(irOpCode, basicType);
+ if (opCode == SpvOpUndef)
+ {
+ switch (irOpCode)
+ {
+ case kIROp_IntCast:
+ {
+ auto typeStyle = getTypeStyle(basicType->getBaseType());
+ if (typeStyle == kIROp_FloatType)
+ {
+ return SpvOpConvertFToU;
+ }
+ else if (typeStyle == kIROp_IntType)
+ {
+ return SpvOpUConvert;
+ }
+ break;
+ }
+ default:
+ break;
+ }
+ return opCode;
+ }
+ return opCode;
+ }
+
+ SpvOp _arithmeticOpCodeConvert(IROp irOpCode, IRBasicType* basicType)
+ {
+ bool isFloatingPoint = false;
+ bool isBool = false;
+ switch (basicType->getBaseType())
+ {
+ case BaseType::Float:
+ case BaseType::Double:
+ case BaseType::Half:
+ isFloatingPoint = true;
+ break;
+ case BaseType::Bool:
+ isBool = true;
+ break;
+ default:
+ break;
+ }
+ bool isSigned = isSignedType(basicType);
+ SpvOp opCode = SpvOpUndef;
+ switch (irOpCode)
+ {
+ case kIROp_Add:
+ opCode = isFloatingPoint ? SpvOpFAdd : SpvOpIAdd;
+ break;
+ case kIROp_Sub:
+ opCode = isFloatingPoint ? SpvOpFSub : SpvOpISub;
+ break;
+ case kIROp_Mul:
+ opCode = isFloatingPoint ? SpvOpFMul : SpvOpIMul;
+ break;
+ case kIROp_Div:
+ opCode = isFloatingPoint ? SpvOpFDiv : isSigned ? SpvOpSDiv : SpvOpUDiv;
+ break;
+ case kIROp_IRem:
+ opCode = isSigned ? SpvOpSRem : SpvOpUMod;
+ break;
+ case kIROp_FRem:
+ opCode = SpvOpFRem;
+ break;
+ case kIROp_Less:
+ opCode = isFloatingPoint ? SpvOpFOrdLessThan
+ : isSigned ? SpvOpSLessThan
+ : SpvOpULessThan;
+ break;
+ case kIROp_Leq:
+ opCode = isFloatingPoint ? SpvOpFOrdLessThanEqual
+ : isSigned ? SpvOpSLessThanEqual
+ : SpvOpULessThanEqual;
+ break;
+ case kIROp_Eql:
+ opCode = isFloatingPoint ? SpvOpFOrdEqual : isBool ? SpvOpLogicalEqual : SpvOpIEqual;
+ break;
+ case kIROp_Neq:
+ opCode = isFloatingPoint ? SpvOpFUnordNotEqual
+ : isBool ? SpvOpLogicalNotEqual
+ : SpvOpINotEqual;
+ break;
+ case kIROp_Geq:
+ opCode = isFloatingPoint ? SpvOpFOrdGreaterThanEqual
+ : isSigned ? SpvOpSGreaterThanEqual
+ : SpvOpUGreaterThanEqual;
+ break;
+ case kIROp_Greater:
+ opCode = isFloatingPoint ? SpvOpFOrdGreaterThan
+ : isSigned ? SpvOpSGreaterThan
+ : SpvOpUGreaterThan;
+ break;
+ case kIROp_Neg:
+ opCode = isFloatingPoint ? SpvOpFNegate : SpvOpSNegate;
+ break;
+ case kIROp_And:
+ opCode = SpvOpLogicalAnd;
+ break;
+ case kIROp_Or:
+ opCode = SpvOpLogicalOr;
+ break;
+ case kIROp_Not:
+ opCode = SpvOpLogicalNot;
+ break;
+ case kIROp_BitAnd:
+ if (isBool)
+ opCode = SpvOpLogicalAnd;
+ else
+ opCode = SpvOpBitwiseAnd;
+ break;
+ case kIROp_BitOr:
+ if (isBool)
+ opCode = SpvOpLogicalOr;
+ else
+ opCode = SpvOpBitwiseOr;
+ break;
+ case kIROp_BitXor:
+ if (isBool)
+ opCode = SpvOpLogicalNotEqual;
+ else
+ opCode = SpvOpBitwiseXor;
+ break;
+ case kIROp_BitNot:
+ if (isBool)
+ opCode = SpvOpLogicalNot;
+ else
+ opCode = SpvOpNot;
+ break;
+ case kIROp_Rsh:
+ opCode = isSigned ? SpvOpShiftRightArithmetic : SpvOpShiftRightLogical;
+ break;
+ case kIROp_Lsh:
+ opCode = SpvOpShiftLeftLogical;
+ break;
+ }
+ return opCode;
+ }
/// Ensure that an instruction has been emitted
SpvInst* ensureInst(IRInst* irInst)
{
@@ -1972,8 +2113,12 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
as<IRDebugInlinedAt>(inst));
default:
{
- if (as<IRSPIRVAsmOperand>(inst))
+ if (isSpecConstRateType(inst->getFullType()))
+ return emitSpecializationConstantOp(inst);
+
+ else if (as<IRSPIRVAsmOperand>(inst))
return nullptr;
+
String e = "Unhandled global inst in spirv-emit:\n" +
dumpIRToString(inst, {IRDumpOptions::Mode::Detailed, 0});
SLANG_UNIMPLEMENTED_X(e.begin());
@@ -2756,6 +2901,66 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
return result;
}
+ SpvInst* emitSpecializationConstantOp(IRInst* inst)
+ {
+ SpvInst* spv = nullptr;
+ if (m_mapIRInstToSpvInst.tryGetValue(inst, spv))
+ return spv;
+
+ // For each OpSpecConstantOp, the operand must be:
+ // 1. A specialization constant
+ // 2. A literal constant
+ // 3. Another OpSpecConstantOp
+
+ // For 1 and 2, we can just emit the specialization constant or literal constant.
+ if (auto param = as<IRGlobalParam>(inst))
+ {
+ auto layout = getVarLayout(param);
+ if (layout)
+ {
+ if (auto offset =
+ layout->findOffsetAttr(LayoutResourceKind::SpecializationConstant))
+ {
+ return emitSpecializationConstant(param, offset);
+ }
+ }
+ SLANG_UNREACHABLE("Non specialization constant used in OpSpecConstantOp\n");
+ }
+ else if (as<IRConstant>(inst))
+ {
+ // We need to emit the constant as a specialization constant
+ return emitLit(inst);
+ }
+
+ IRType* type = inst->getOperand(0)->getDataType();
+ IRBasicType* basicType = as<IRBasicType>(type);
+ SpvOp opCode = _specConstantOpcodeConvert(inst->getOp(), basicType);
+ if (opCode == SpvOpUndef)
+ {
+ String e = "Unhandled inst in spirv-emit:\n" +
+ dumpIRToString(inst, {IRDumpOptions::Mode::Detailed, 0});
+ SLANG_UNIMPLEMENTED_X(e.getBuffer());
+ }
+
+ Array<SpvInst*, 3> operands;
+ for (UInt i = 0; i < inst->getOperandCount(); i++)
+ {
+ auto operand = inst->getOperand(i);
+ SpvInst* spvInst = emitSpecializationConstantOp(operand);
+ operands.add(spvInst);
+ }
+
+ auto resultType = inst->getFullType();
+ return emitInst(
+ getSection(SpvLogicalSectionID::ConstantsAndTypes),
+ inst,
+ SpvOpSpecConstantOp,
+ resultType,
+ kResultID,
+ opCode,
+ operands);
+ }
+
/// Emit a global parameter definition.
SpvInst* emitGlobalParam(IRGlobalParam* param)
{
@@ -7197,117 +7402,13 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
{
IRType* elementType = getVectorOrCoopMatrixElementType(operands[0]->getDataType());
IRBasicType* basicType = as<IRBasicType>(elementType);
- bool isFloatingPoint = false;
- bool isBool = false;
- switch (basicType->getBaseType())
- {
- case BaseType::Float:
- case BaseType::Double:
- case BaseType::Half:
- isFloatingPoint = true;
- break;
- case BaseType::Bool:
- isBool = true;
- break;
- default:
- break;
- }
- SpvOp opCode = SpvOpUndef;
- bool isSigned = isSignedType(basicType);
- switch (op)
- {
- case kIROp_Add:
- opCode = isFloatingPoint ? SpvOpFAdd : SpvOpIAdd;
- break;
- case kIROp_Sub:
- opCode = isFloatingPoint ? SpvOpFSub : SpvOpISub;
- break;
- case kIROp_Mul:
- opCode = isFloatingPoint ? SpvOpFMul : SpvOpIMul;
- break;
- case kIROp_Div:
- opCode = isFloatingPoint ? SpvOpFDiv : isSigned ? SpvOpSDiv : SpvOpUDiv;
- break;
- case kIROp_IRem:
- opCode = isSigned ? SpvOpSRem : SpvOpUMod;
- break;
- case kIROp_FRem:
- opCode = SpvOpFRem;
- break;
- case kIROp_Less:
- opCode = isFloatingPoint ? SpvOpFOrdLessThan
- : isSigned ? SpvOpSLessThan
- : SpvOpULessThan;
- break;
- case kIROp_Leq:
- opCode = isFloatingPoint ? SpvOpFOrdLessThanEqual
- : isSigned ? SpvOpSLessThanEqual
- : SpvOpULessThanEqual;
- break;
- case kIROp_Eql:
- opCode = isFloatingPoint ? SpvOpFOrdEqual : isBool ? SpvOpLogicalEqual : SpvOpIEqual;
- break;
- case kIROp_Neq:
- opCode = isFloatingPoint ? SpvOpFUnordNotEqual
- : isBool ? SpvOpLogicalNotEqual
- : SpvOpINotEqual;
- break;
- case kIROp_Geq:
- opCode = isFloatingPoint ? SpvOpFOrdGreaterThanEqual
- : isSigned ? SpvOpSGreaterThanEqual
- : SpvOpUGreaterThanEqual;
- break;
- case kIROp_Greater:
- opCode = isFloatingPoint ? SpvOpFOrdGreaterThan
- : isSigned ? SpvOpSGreaterThan
- : SpvOpUGreaterThan;
- break;
- case kIROp_Neg:
- opCode = isFloatingPoint ? SpvOpFNegate : SpvOpSNegate;
- break;
- case kIROp_And:
- opCode = SpvOpLogicalAnd;
- break;
- case kIROp_Or:
- opCode = SpvOpLogicalOr;
- break;
- case kIROp_Not:
- opCode = SpvOpLogicalNot;
- break;
- case kIROp_BitAnd:
- if (isBool)
- opCode = SpvOpLogicalAnd;
- else
- opCode = SpvOpBitwiseAnd;
- break;
- case kIROp_BitOr:
- if (isBool)
- opCode = SpvOpLogicalOr;
- else
- opCode = SpvOpBitwiseOr;
- break;
- case kIROp_BitXor:
- if (isBool)
- opCode = SpvOpLogicalNotEqual;
- else
- opCode = SpvOpBitwiseXor;
- break;
- case kIROp_BitNot:
- if (isBool)
- opCode = SpvOpLogicalNot;
- else
- opCode = SpvOpNot;
- break;
- case kIROp_Rsh:
- opCode = isSigned ? SpvOpShiftRightArithmetic : SpvOpShiftRightLogical;
- break;
- case kIROp_Lsh:
- opCode = SpvOpShiftLeftLogical;
- break;
- default:
+
+ SpvOp opCode = _arithmeticOpCodeConvert(op, basicType);
+ if (opCode == SpvOpUndef)
SLANG_ASSERT(!"unknown arithmetic opcode");
- break;
- }
+
+ bool isFloatingPoint = (getTypeStyle(basicType->getBaseType()) == kIROp_FloatType);
+
if (operandCount == 1)
{
return emitInst(parent, instToRegister, opCode, type, kResultID, operands);
@@ -7846,7 +7947,7 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
emitDebugType(arrayType->getElementType()),
sizedArrayType ? builder.getIntValue(
builder.getUIntType(),
- getIntVal(sizedArrayType->getElementCount()))
+ getArraySizeVal(sizedArrayType->getElementCount()))
: builder.getIntValue(builder.getUIntType(), 0));
}
else if (auto vectorType = as<IRVectorType>(type))