summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
authorEllie Hermaszewska <ellieh@nvidia.com>2023-08-08 06:01:55 +0800
committerGitHub <noreply@github.com>2023-08-07 15:01:55 -0700
commit03c6cda7552ab2abe0443fbb4b0ea37b43f60fa5 (patch)
treee50eec86333ba788374d4d2382ff874725fe6964 /source
parent0d803a4c934ccfbb1922b86a7b09a7e98c77211a (diff)
Casting and vector/scalar correct arithmetic ops for SPIR-V (#3056)
* types for cast instructions * Information getting functions for int and float types * Implement spirv casting * Broadcast operands for SPIR-V arithmetic SPIR-V doesn't allow vector/sclar arithmetic ops * Simplify spirv int/float type generation --------- Co-authored-by: Yong He <yonghe@outlook.com>
Diffstat (limited to 'source')
-rw-r--r--source/slang/slang-emit-spirv.cpp196
-rw-r--r--source/slang/slang-ir-insts.h19
-rw-r--r--source/slang/slang-ir.cpp49
-rw-r--r--source/slang/slang-ir.h19
4 files changed, 257 insertions, 26 deletions
diff --git a/source/slang/slang-emit-spirv.cpp b/source/slang/slang-emit-spirv.cpp
index 30d1b4ecb..b6e1d15c0 100644
--- a/source/slang/slang-emit-spirv.cpp
+++ b/source/slang/slang-emit-spirv.cpp
@@ -1035,29 +1035,32 @@ struct SPIRVEmitContext
// > OpTypeInt
-#define CASE(IROP, BITS, SIGNED) \
- case IROP: \
- return emitTypeInst(inst, SpvOpTypeInt, makeArray<SpvWord>((SpvWord)BITS, (SpvWord)SIGNED).getView());
-
- CASE(kIROp_IntType, 32, 1);
- CASE(kIROp_UIntType, 32, 0);
- CASE(kIROp_Int64Type, 64, 1);
- CASE(kIROp_UInt64Type, 64, 0);
-
-#undef CASE
+ case kIROp_UInt8Type:
+ case kIROp_UInt16Type:
+ case kIROp_UIntType:
+ case kIROp_UInt64Type:
+ case kIROp_Int8Type:
+ case kIROp_Int16Type:
+ case kIROp_IntType:
+ case kIROp_Int64Type:
+ {
+ const IntInfo i = getIntTypeInfo(as<IRType>(inst));
+ return emitTypeInst(
+ inst,
+ SpvOpTypeInt,
+ makeArray(static_cast<SpvWord>(i.width), SpvWord{i.isSigned}).getView());
+ }
// > OpTypeFloat
-#define CASE(IROP, BITS) \
- case IROP: \
- return emitTypeInst( \
- inst, SpvOpTypeFloat, makeArray<SpvWord>(BITS).getView()); \
-
- CASE(kIROp_HalfType, 16);
- CASE(kIROp_FloatType, 32);
- CASE(kIROp_DoubleType, 64);
+ case kIROp_HalfType:
+ case kIROp_FloatType:
+ case kIROp_DoubleType:
+ {
+ const FloatInfo i = getFloatingTypeInfo(as<IRType>(inst));
+ return emitTypeInst(inst, SpvOpTypeFloat, makeArray(static_cast<SpvWord>(i.width)).getView());
+ }
-#undef CASE
case kIROp_PtrType:
case kIROp_RefType:
case kIROp_OutType:
@@ -1619,9 +1622,13 @@ struct SPIRVEmitContext
case kIROp_swizzle:
return emitSwizzle(parent, as<IRSwizzle>(inst));
case kIROp_IntCast:
+ return emitIntCast(parent, as<IRIntCast>(inst));
case kIROp_FloatCast:
+ return emitFloatCast(parent, as<IRFloatCast>(inst));
case kIROp_CastIntToFloat:
+ return emitIntToFloatCast(parent, as<IRCastIntToFloat>(inst));
case kIROp_CastFloatToInt:
+ return emitFloatToIntCast(parent, as<IRCastFloatToInt>(inst));
case kIROp_MatrixReshape:
case kIROp_VectorReshape:
// TODO: break emitConstruct into separate functions for each opcode.
@@ -2667,6 +2674,101 @@ struct SPIRVEmitContext
}
}
+ IRType* dropVector(IRType* t)
+ {
+ if(const auto v = as<IRVectorType>(t))
+ return v->getElementType();
+ return t;
+ };
+
+ SpvInst* emitIntCast(SpvInstParent* parent, IRIntCast* inst)
+ {
+ const auto fromTypeV = inst->getOperand(0)->getDataType();
+ const auto toTypeV = inst->getDataType();
+ SLANG_ASSERT(!as<IRVectorType>(fromTypeV) == !as<IRVectorType>(toTypeV));
+ const auto fromType = dropVector(fromTypeV);
+ const auto toType = dropVector(toTypeV);
+ SLANG_ASSERT(isIntegralType(fromType));
+ SLANG_ASSERT(isIntegralType(toType));
+
+ const auto fromInfo = getIntTypeInfo(fromType);
+ const auto toInfo = getIntTypeInfo(toType);
+
+ const auto convertWith = [&](auto op){
+ return emitInst(parent, inst, op, toTypeV, kResultID, inst->getOperand(0));
+ };
+ if(fromInfo == toInfo)
+ return convertWith(SpvOpCopyObject);
+ else if(fromInfo.width == toInfo.width)
+ return convertWith(SpvOpBitcast);
+ else if(!fromInfo.isSigned && !toInfo.isSigned)
+ // unsigned to unsigned, don't sign extend
+ return convertWith(SpvOpUConvert);
+ else if(toInfo.isSigned)
+ // unsigned to signed, sign extend
+ return convertWith(SpvOpSConvert);
+ else if(fromInfo.isSigned)
+ // signed to unsigned, sign extend
+ return convertWith(SpvOpSConvert);
+ else if(fromInfo.isSigned && toInfo.isSigned)
+ // signed to signed, sign extend
+ return convertWith(SpvOpSConvert);
+
+ SLANG_UNREACHABLE(__func__);
+ }
+
+ SpvInst* emitFloatCast(SpvInstParent* parent, IRFloatCast* inst)
+ {
+ const auto fromTypeV = inst->getOperand(0)->getDataType();
+ const auto toTypeV = inst->getDataType();
+ SLANG_ASSERT(!as<IRVectorType>(fromTypeV) == !as<IRVectorType>(toTypeV));
+ const auto fromType = dropVector(fromTypeV);
+ const auto toType = dropVector(toTypeV);
+ SLANG_ASSERT(isFloatingType(fromType));
+ SLANG_ASSERT(isFloatingType(toType));
+ SLANG_ASSERT(!isTypeEqual(fromType, toType));
+
+ return emitInst(parent, inst, SpvOpFConvert, toTypeV, kResultID, inst->getOperand(0));
+ }
+
+ SpvInst* emitIntToFloatCast(SpvInstParent* parent, IRCastIntToFloat* inst)
+ {
+ const auto fromTypeV = inst->getOperand(0)->getDataType();
+ const auto toTypeV = inst->getDataType();
+ SLANG_ASSERT(!as<IRVectorType>(fromTypeV) == !as<IRVectorType>(toTypeV));
+ const auto fromType = dropVector(fromTypeV);
+ const auto toType = dropVector(toTypeV);
+ SLANG_ASSERT(isIntegralType(fromType));
+ SLANG_ASSERT(isFloatingType(toType));
+
+ const auto fromInfo = getIntTypeInfo(fromType);
+
+ const auto convertWith = [&](auto op){
+ return emitInst(parent, inst, op, toTypeV, kResultID, inst->getOperand(0));
+ };
+
+ return convertWith(fromInfo.isSigned ? SpvOpConvertSToF : SpvOpConvertUToF);
+ }
+
+ SpvInst* emitFloatToIntCast(SpvInstParent* parent, IRCastFloatToInt* inst)
+ {
+ const auto fromTypeV = inst->getOperand(0)->getDataType();
+ const auto toTypeV = inst->getDataType();
+ SLANG_ASSERT(!as<IRVectorType>(fromTypeV) == !as<IRVectorType>(toTypeV));
+ const auto fromType = dropVector(fromTypeV);
+ const auto toType = dropVector(toTypeV);
+ SLANG_ASSERT(isFloatingType(fromType));
+ SLANG_ASSERT(isIntegralType(toType));
+
+ const auto toInfo = getIntTypeInfo(toType);
+
+ const auto convertWith = [&](auto op){
+ return emitInst(parent, inst, op, toTypeV, kResultID, inst->getOperand(0));
+ };
+
+ return convertWith(toInfo.isSigned ? SpvOpConvertFToS : SpvOpConvertFToU);
+ }
+
SpvInst* emitConstruct(SpvInstParent* parent, IRInst* inst)
{
if (as<IRBasicType>(inst->getDataType()))
@@ -2708,6 +2810,25 @@ struct SPIRVEmitContext
}
}
+ SpvInst* emitSplat(SpvInstParent* parent, IRInst* scalar, IRIntegerValue numElems)
+ {
+ const auto scalarTy = as<IRBasicType>(scalar->getDataType());
+ const auto spvVecTy = ensureVectorType(
+ scalarTy->getBaseType(),
+ numElems,
+ nullptr);
+ return emitInstCustomOperandFunc(
+ parent,
+ nullptr,
+ SpvOpCompositeConstruct,
+ [&](){
+ emitOperand(spvVecTy);
+ emitOperand(kResultID);
+ for(Int i = 0; i < numElems; ++i)
+ emitOperand(scalar);
+ });
+ }
+
bool isSignedType(IRType* type)
{
switch (type->getOp())
@@ -2748,12 +2869,8 @@ struct SPIRVEmitContext
SpvInst* emitArithmetic(SpvInstParent* parent, IRInst* inst)
{
- IRType* elementType = inst->getOperand(0)->getDataType();
- if (auto vectorType = as<IRVectorType>(inst->getDataType()))
- {
- elementType = vectorType->getElementType();
- }
- else if (const auto matrixType = as<IRMatrixType>(inst->getDataType()))
+ IRType* elementType = dropVector(inst->getOperand(0)->getDataType());
+ if (const auto matrixType = as<IRMatrixType>(inst->getDataType()))
{
//TODO: implement.
SLANG_ASSERT(!"unimplemented: matrix arithemetic");
@@ -2852,7 +2969,34 @@ struct SPIRVEmitContext
SLANG_ASSERT(!"unknown arithmetic opcode");
break;
}
- return emitInst(parent, inst, opCode, inst->getDataType(), kResultID, OperandsOf(inst));
+ if(inst->getOperandCount() == 1)
+ {
+ return emitInst(parent, inst, opCode, inst->getDataType(), kResultID, OperandsOf(inst));
+ }
+ else if(inst->getOperandCount() == 2)
+ {
+ auto l = inst->getOperand(0);
+ const auto lVec = as<IRVectorType>(l->getDataType());
+ auto r = inst->getOperand(1);
+ const auto rVec = as<IRVectorType>(r->getDataType());
+ const auto go = [&](const auto l, const auto r){
+ return emitInst(parent, inst, opCode, inst->getDataType(), kResultID, l, r);
+ };
+ if(lVec && !rVec)
+ {
+ const auto len = as<IRIntLit>(lVec->getElementCount());
+ SLANG_ASSERT(len);
+ return go(l, emitSplat(parent, r, len->getValue()));
+ }
+ else if (!lVec && rVec)
+ {
+ const auto len = as<IRIntLit>(rVec->getElementCount());
+ SLANG_ASSERT(len);
+ return go(emitSplat(parent, l, len->getValue()), r);
+ }
+ return go(l, r);
+ }
+ SLANG_UNREACHABLE("Arithmetic op with 0 or more than 2 operands");
}
OrderedHashSet<SpvCapability> m_capabilities;
diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h
index 95f72b3cd..123cc33c6 100644
--- a/source/slang/slang-ir-insts.h
+++ b/source/slang/slang-ir-insts.h
@@ -2673,6 +2673,25 @@ struct IRGetRegisterSpace : IRBindingQuery
IR_LEAF_ISA(GetRegisterSpace);
};
+struct IRIntCast : IRInst
+{
+ IR_LEAF_ISA(IntCast)
+};
+
+struct IRFloatCast : IRInst
+{
+ IR_LEAF_ISA(FloatCast)
+};
+
+struct IRCastIntToFloat : IRInst
+{
+ IR_LEAF_ISA(CastIntToFloat)
+};
+
+struct IRCastFloatToInt : IRInst
+{
+ IR_LEAF_ISA(CastFloatToInt)
+};
struct IRBuilderSourceLocRAII;
diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp
index 38d1eb520..0a79cec57 100644
--- a/source/slang/slang-ir.cpp
+++ b/source/slang/slang-ir.cpp
@@ -6690,6 +6690,55 @@ namespace Slang
return false;
}
+ bool isFloatingType(IRType *t)
+ {
+ if(auto basic = as<IRBasicType>(t))
+ {
+ switch(basic->getBaseType())
+ {
+ case BaseType::Float:
+ case BaseType::Half:
+ case BaseType::Double:
+ return true;
+ default:
+ return false;
+ }
+ }
+ return false;
+ }
+
+ IntInfo getIntTypeInfo(const IRType* intType)
+ {
+ switch(intType->getOp())
+ {
+ case kIROp_UInt8Type: return {8, false};
+ case kIROp_UInt16Type: return {16, false};
+ case kIROp_UIntType: return {32, false};
+ case kIROp_UInt64Type: return {64, false};
+ case kIROp_Int8Type: return {8, true};
+ case kIROp_Int16Type: return {16, true};
+ case kIROp_IntType: return {32, true};
+ case kIROp_Int64Type: return {64, true};
+
+ case kIROp_IntPtrType: // target platform dependent
+ case kIROp_UIntPtrType: // target platform dependent
+ default:
+ SLANG_UNEXPECTED("Unhandled type passed to getIntTypeInfo");
+ }
+ }
+
+ FloatInfo getFloatingTypeInfo(const IRType* floatType)
+ {
+ switch(floatType->getOp())
+ {
+ case kIROp_HalfType: return {16};
+ case kIROp_FloatType: return {32};
+ case kIROp_DoubleType: return {64};
+ default:
+ SLANG_UNEXPECTED("Unhandled type passed to getFloatTypeInfo");
+ }
+ }
+
bool isIntegralScalarOrCompositeType(IRType* t)
{
if (!t)
diff --git a/source/slang/slang-ir.h b/source/slang/slang-ir.h
index 97f98fce2..3cd8e9126 100644
--- a/source/slang/slang-ir.h
+++ b/source/slang/slang-ir.h
@@ -1005,6 +1005,25 @@ bool isTypeEqual(IRType* a, IRType* b);
// True if this is an integral IRBasicType, not including Char or Ptr types
bool isIntegralType(IRType* t);
+bool isFloatingType(IRType* t);
+
+struct IntInfo
+{
+ Int width;
+ bool isSigned;
+ bool operator==(const IntInfo& i) const { return width == i.width && isSigned == i.isSigned; }
+};
+
+IntInfo getIntTypeInfo(const IRType* intType);
+
+struct FloatInfo
+{
+ Int width;
+ bool operator==(const FloatInfo& i) const { return width == i.width; }
+};
+
+FloatInfo getFloatingTypeInfo(const IRType* floatType);
+
bool isIntegralScalarOrCompositeType(IRType* t);
void findAllInstsBreadthFirst(IRInst* inst, List<IRInst*>& outInsts);