summaryrefslogtreecommitdiffstats
path: root/source/slang/slang-emit-spirv.cpp
diff options
context:
space:
mode:
authorEllie Hermaszewska <ellieh@nvidia.com>2023-08-08 06:01:55 +0800
committerGitHub <noreply@github.com>2023-08-07 15:01:55 -0700
commit03c6cda7552ab2abe0443fbb4b0ea37b43f60fa5 (patch)
treee50eec86333ba788374d4d2382ff874725fe6964 /source/slang/slang-emit-spirv.cpp
parent0d803a4c934ccfbb1922b86a7b09a7e98c77211a (diff)
Casting and vector/scalar correct arithmetic ops for SPIR-V (#3056)
* types for cast instructions * Information getting functions for int and float types * Implement spirv casting * Broadcast operands for SPIR-V arithmetic SPIR-V doesn't allow vector/sclar arithmetic ops * Simplify spirv int/float type generation --------- Co-authored-by: Yong He <yonghe@outlook.com>
Diffstat (limited to 'source/slang/slang-emit-spirv.cpp')
-rw-r--r--source/slang/slang-emit-spirv.cpp196
1 files changed, 170 insertions, 26 deletions
diff --git a/source/slang/slang-emit-spirv.cpp b/source/slang/slang-emit-spirv.cpp
index 30d1b4ecb..b6e1d15c0 100644
--- a/source/slang/slang-emit-spirv.cpp
+++ b/source/slang/slang-emit-spirv.cpp
@@ -1035,29 +1035,32 @@ struct SPIRVEmitContext
// > OpTypeInt
-#define CASE(IROP, BITS, SIGNED) \
- case IROP: \
- return emitTypeInst(inst, SpvOpTypeInt, makeArray<SpvWord>((SpvWord)BITS, (SpvWord)SIGNED).getView());
-
- CASE(kIROp_IntType, 32, 1);
- CASE(kIROp_UIntType, 32, 0);
- CASE(kIROp_Int64Type, 64, 1);
- CASE(kIROp_UInt64Type, 64, 0);
-
-#undef CASE
+ case kIROp_UInt8Type:
+ case kIROp_UInt16Type:
+ case kIROp_UIntType:
+ case kIROp_UInt64Type:
+ case kIROp_Int8Type:
+ case kIROp_Int16Type:
+ case kIROp_IntType:
+ case kIROp_Int64Type:
+ {
+ const IntInfo i = getIntTypeInfo(as<IRType>(inst));
+ return emitTypeInst(
+ inst,
+ SpvOpTypeInt,
+ makeArray(static_cast<SpvWord>(i.width), SpvWord{i.isSigned}).getView());
+ }
// > OpTypeFloat
-#define CASE(IROP, BITS) \
- case IROP: \
- return emitTypeInst( \
- inst, SpvOpTypeFloat, makeArray<SpvWord>(BITS).getView()); \
-
- CASE(kIROp_HalfType, 16);
- CASE(kIROp_FloatType, 32);
- CASE(kIROp_DoubleType, 64);
+ case kIROp_HalfType:
+ case kIROp_FloatType:
+ case kIROp_DoubleType:
+ {
+ const FloatInfo i = getFloatingTypeInfo(as<IRType>(inst));
+ return emitTypeInst(inst, SpvOpTypeFloat, makeArray(static_cast<SpvWord>(i.width)).getView());
+ }
-#undef CASE
case kIROp_PtrType:
case kIROp_RefType:
case kIROp_OutType:
@@ -1619,9 +1622,13 @@ struct SPIRVEmitContext
case kIROp_swizzle:
return emitSwizzle(parent, as<IRSwizzle>(inst));
case kIROp_IntCast:
+ return emitIntCast(parent, as<IRIntCast>(inst));
case kIROp_FloatCast:
+ return emitFloatCast(parent, as<IRFloatCast>(inst));
case kIROp_CastIntToFloat:
+ return emitIntToFloatCast(parent, as<IRCastIntToFloat>(inst));
case kIROp_CastFloatToInt:
+ return emitFloatToIntCast(parent, as<IRCastFloatToInt>(inst));
case kIROp_MatrixReshape:
case kIROp_VectorReshape:
// TODO: break emitConstruct into separate functions for each opcode.
@@ -2667,6 +2674,101 @@ struct SPIRVEmitContext
}
}
+ IRType* dropVector(IRType* t)
+ {
+ if(const auto v = as<IRVectorType>(t))
+ return v->getElementType();
+ return t;
+ };
+
+ SpvInst* emitIntCast(SpvInstParent* parent, IRIntCast* inst)
+ {
+ const auto fromTypeV = inst->getOperand(0)->getDataType();
+ const auto toTypeV = inst->getDataType();
+ SLANG_ASSERT(!as<IRVectorType>(fromTypeV) == !as<IRVectorType>(toTypeV));
+ const auto fromType = dropVector(fromTypeV);
+ const auto toType = dropVector(toTypeV);
+ SLANG_ASSERT(isIntegralType(fromType));
+ SLANG_ASSERT(isIntegralType(toType));
+
+ const auto fromInfo = getIntTypeInfo(fromType);
+ const auto toInfo = getIntTypeInfo(toType);
+
+ const auto convertWith = [&](auto op){
+ return emitInst(parent, inst, op, toTypeV, kResultID, inst->getOperand(0));
+ };
+ if(fromInfo == toInfo)
+ return convertWith(SpvOpCopyObject);
+ else if(fromInfo.width == toInfo.width)
+ return convertWith(SpvOpBitcast);
+ else if(!fromInfo.isSigned && !toInfo.isSigned)
+ // unsigned to unsigned, don't sign extend
+ return convertWith(SpvOpUConvert);
+ else if(toInfo.isSigned)
+ // unsigned to signed, sign extend
+ return convertWith(SpvOpSConvert);
+ else if(fromInfo.isSigned)
+ // signed to unsigned, sign extend
+ return convertWith(SpvOpSConvert);
+ else if(fromInfo.isSigned && toInfo.isSigned)
+ // signed to signed, sign extend
+ return convertWith(SpvOpSConvert);
+
+ SLANG_UNREACHABLE(__func__);
+ }
+
+ SpvInst* emitFloatCast(SpvInstParent* parent, IRFloatCast* inst)
+ {
+ const auto fromTypeV = inst->getOperand(0)->getDataType();
+ const auto toTypeV = inst->getDataType();
+ SLANG_ASSERT(!as<IRVectorType>(fromTypeV) == !as<IRVectorType>(toTypeV));
+ const auto fromType = dropVector(fromTypeV);
+ const auto toType = dropVector(toTypeV);
+ SLANG_ASSERT(isFloatingType(fromType));
+ SLANG_ASSERT(isFloatingType(toType));
+ SLANG_ASSERT(!isTypeEqual(fromType, toType));
+
+ return emitInst(parent, inst, SpvOpFConvert, toTypeV, kResultID, inst->getOperand(0));
+ }
+
+ SpvInst* emitIntToFloatCast(SpvInstParent* parent, IRCastIntToFloat* inst)
+ {
+ const auto fromTypeV = inst->getOperand(0)->getDataType();
+ const auto toTypeV = inst->getDataType();
+ SLANG_ASSERT(!as<IRVectorType>(fromTypeV) == !as<IRVectorType>(toTypeV));
+ const auto fromType = dropVector(fromTypeV);
+ const auto toType = dropVector(toTypeV);
+ SLANG_ASSERT(isIntegralType(fromType));
+ SLANG_ASSERT(isFloatingType(toType));
+
+ const auto fromInfo = getIntTypeInfo(fromType);
+
+ const auto convertWith = [&](auto op){
+ return emitInst(parent, inst, op, toTypeV, kResultID, inst->getOperand(0));
+ };
+
+ return convertWith(fromInfo.isSigned ? SpvOpConvertSToF : SpvOpConvertUToF);
+ }
+
+ SpvInst* emitFloatToIntCast(SpvInstParent* parent, IRCastFloatToInt* inst)
+ {
+ const auto fromTypeV = inst->getOperand(0)->getDataType();
+ const auto toTypeV = inst->getDataType();
+ SLANG_ASSERT(!as<IRVectorType>(fromTypeV) == !as<IRVectorType>(toTypeV));
+ const auto fromType = dropVector(fromTypeV);
+ const auto toType = dropVector(toTypeV);
+ SLANG_ASSERT(isFloatingType(fromType));
+ SLANG_ASSERT(isIntegralType(toType));
+
+ const auto toInfo = getIntTypeInfo(toType);
+
+ const auto convertWith = [&](auto op){
+ return emitInst(parent, inst, op, toTypeV, kResultID, inst->getOperand(0));
+ };
+
+ return convertWith(toInfo.isSigned ? SpvOpConvertFToS : SpvOpConvertFToU);
+ }
+
SpvInst* emitConstruct(SpvInstParent* parent, IRInst* inst)
{
if (as<IRBasicType>(inst->getDataType()))
@@ -2708,6 +2810,25 @@ struct SPIRVEmitContext
}
}
+ SpvInst* emitSplat(SpvInstParent* parent, IRInst* scalar, IRIntegerValue numElems)
+ {
+ const auto scalarTy = as<IRBasicType>(scalar->getDataType());
+ const auto spvVecTy = ensureVectorType(
+ scalarTy->getBaseType(),
+ numElems,
+ nullptr);
+ return emitInstCustomOperandFunc(
+ parent,
+ nullptr,
+ SpvOpCompositeConstruct,
+ [&](){
+ emitOperand(spvVecTy);
+ emitOperand(kResultID);
+ for(Int i = 0; i < numElems; ++i)
+ emitOperand(scalar);
+ });
+ }
+
bool isSignedType(IRType* type)
{
switch (type->getOp())
@@ -2748,12 +2869,8 @@ struct SPIRVEmitContext
SpvInst* emitArithmetic(SpvInstParent* parent, IRInst* inst)
{
- IRType* elementType = inst->getOperand(0)->getDataType();
- if (auto vectorType = as<IRVectorType>(inst->getDataType()))
- {
- elementType = vectorType->getElementType();
- }
- else if (const auto matrixType = as<IRMatrixType>(inst->getDataType()))
+ IRType* elementType = dropVector(inst->getOperand(0)->getDataType());
+ if (const auto matrixType = as<IRMatrixType>(inst->getDataType()))
{
//TODO: implement.
SLANG_ASSERT(!"unimplemented: matrix arithemetic");
@@ -2852,7 +2969,34 @@ struct SPIRVEmitContext
SLANG_ASSERT(!"unknown arithmetic opcode");
break;
}
- return emitInst(parent, inst, opCode, inst->getDataType(), kResultID, OperandsOf(inst));
+ if(inst->getOperandCount() == 1)
+ {
+ return emitInst(parent, inst, opCode, inst->getDataType(), kResultID, OperandsOf(inst));
+ }
+ else if(inst->getOperandCount() == 2)
+ {
+ auto l = inst->getOperand(0);
+ const auto lVec = as<IRVectorType>(l->getDataType());
+ auto r = inst->getOperand(1);
+ const auto rVec = as<IRVectorType>(r->getDataType());
+ const auto go = [&](const auto l, const auto r){
+ return emitInst(parent, inst, opCode, inst->getDataType(), kResultID, l, r);
+ };
+ if(lVec && !rVec)
+ {
+ const auto len = as<IRIntLit>(lVec->getElementCount());
+ SLANG_ASSERT(len);
+ return go(l, emitSplat(parent, r, len->getValue()));
+ }
+ else if (!lVec && rVec)
+ {
+ const auto len = as<IRIntLit>(rVec->getElementCount());
+ SLANG_ASSERT(len);
+ return go(emitSplat(parent, l, len->getValue()), r);
+ }
+ return go(l, r);
+ }
+ SLANG_UNREACHABLE("Arithmetic op with 0 or more than 2 operands");
}
OrderedHashSet<SpvCapability> m_capabilities;