diff options
| author | Ellie Hermaszewska <ellieh@nvidia.com> | 2023-08-08 06:01:55 +0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-08-07 15:01:55 -0700 |
| commit | 03c6cda7552ab2abe0443fbb4b0ea37b43f60fa5 (patch) | |
| tree | e50eec86333ba788374d4d2382ff874725fe6964 /source | |
| parent | 0d803a4c934ccfbb1922b86a7b09a7e98c77211a (diff) | |
Casting and vector/scalar correct arithmetic ops for SPIR-V (#3056)
* types for cast instructions
* Information getting functions for int and float types
* Implement spirv casting
* Broadcast operands for SPIR-V arithmetic
SPIR-V doesn't allow vector/sclar arithmetic ops
* Simplify spirv int/float type generation
---------
Co-authored-by: Yong He <yonghe@outlook.com>
Diffstat (limited to 'source')
| -rw-r--r-- | source/slang/slang-emit-spirv.cpp | 196 | ||||
| -rw-r--r-- | source/slang/slang-ir-insts.h | 19 | ||||
| -rw-r--r-- | source/slang/slang-ir.cpp | 49 | ||||
| -rw-r--r-- | source/slang/slang-ir.h | 19 |
4 files changed, 257 insertions, 26 deletions
diff --git a/source/slang/slang-emit-spirv.cpp b/source/slang/slang-emit-spirv.cpp index 30d1b4ecb..b6e1d15c0 100644 --- a/source/slang/slang-emit-spirv.cpp +++ b/source/slang/slang-emit-spirv.cpp @@ -1035,29 +1035,32 @@ struct SPIRVEmitContext // > OpTypeInt -#define CASE(IROP, BITS, SIGNED) \ - case IROP: \ - return emitTypeInst(inst, SpvOpTypeInt, makeArray<SpvWord>((SpvWord)BITS, (SpvWord)SIGNED).getView()); - - CASE(kIROp_IntType, 32, 1); - CASE(kIROp_UIntType, 32, 0); - CASE(kIROp_Int64Type, 64, 1); - CASE(kIROp_UInt64Type, 64, 0); - -#undef CASE + case kIROp_UInt8Type: + case kIROp_UInt16Type: + case kIROp_UIntType: + case kIROp_UInt64Type: + case kIROp_Int8Type: + case kIROp_Int16Type: + case kIROp_IntType: + case kIROp_Int64Type: + { + const IntInfo i = getIntTypeInfo(as<IRType>(inst)); + return emitTypeInst( + inst, + SpvOpTypeInt, + makeArray(static_cast<SpvWord>(i.width), SpvWord{i.isSigned}).getView()); + } // > OpTypeFloat -#define CASE(IROP, BITS) \ - case IROP: \ - return emitTypeInst( \ - inst, SpvOpTypeFloat, makeArray<SpvWord>(BITS).getView()); \ - - CASE(kIROp_HalfType, 16); - CASE(kIROp_FloatType, 32); - CASE(kIROp_DoubleType, 64); + case kIROp_HalfType: + case kIROp_FloatType: + case kIROp_DoubleType: + { + const FloatInfo i = getFloatingTypeInfo(as<IRType>(inst)); + return emitTypeInst(inst, SpvOpTypeFloat, makeArray(static_cast<SpvWord>(i.width)).getView()); + } -#undef CASE case kIROp_PtrType: case kIROp_RefType: case kIROp_OutType: @@ -1619,9 +1622,13 @@ struct SPIRVEmitContext case kIROp_swizzle: return emitSwizzle(parent, as<IRSwizzle>(inst)); case kIROp_IntCast: + return emitIntCast(parent, as<IRIntCast>(inst)); case kIROp_FloatCast: + return emitFloatCast(parent, as<IRFloatCast>(inst)); case kIROp_CastIntToFloat: + return emitIntToFloatCast(parent, as<IRCastIntToFloat>(inst)); case kIROp_CastFloatToInt: + return emitFloatToIntCast(parent, as<IRCastFloatToInt>(inst)); case kIROp_MatrixReshape: case kIROp_VectorReshape: // TODO: break emitConstruct into separate functions for each opcode. @@ -2667,6 +2674,101 @@ struct SPIRVEmitContext } } + IRType* dropVector(IRType* t) + { + if(const auto v = as<IRVectorType>(t)) + return v->getElementType(); + return t; + }; + + SpvInst* emitIntCast(SpvInstParent* parent, IRIntCast* inst) + { + const auto fromTypeV = inst->getOperand(0)->getDataType(); + const auto toTypeV = inst->getDataType(); + SLANG_ASSERT(!as<IRVectorType>(fromTypeV) == !as<IRVectorType>(toTypeV)); + const auto fromType = dropVector(fromTypeV); + const auto toType = dropVector(toTypeV); + SLANG_ASSERT(isIntegralType(fromType)); + SLANG_ASSERT(isIntegralType(toType)); + + const auto fromInfo = getIntTypeInfo(fromType); + const auto toInfo = getIntTypeInfo(toType); + + const auto convertWith = [&](auto op){ + return emitInst(parent, inst, op, toTypeV, kResultID, inst->getOperand(0)); + }; + if(fromInfo == toInfo) + return convertWith(SpvOpCopyObject); + else if(fromInfo.width == toInfo.width) + return convertWith(SpvOpBitcast); + else if(!fromInfo.isSigned && !toInfo.isSigned) + // unsigned to unsigned, don't sign extend + return convertWith(SpvOpUConvert); + else if(toInfo.isSigned) + // unsigned to signed, sign extend + return convertWith(SpvOpSConvert); + else if(fromInfo.isSigned) + // signed to unsigned, sign extend + return convertWith(SpvOpSConvert); + else if(fromInfo.isSigned && toInfo.isSigned) + // signed to signed, sign extend + return convertWith(SpvOpSConvert); + + SLANG_UNREACHABLE(__func__); + } + + SpvInst* emitFloatCast(SpvInstParent* parent, IRFloatCast* inst) + { + const auto fromTypeV = inst->getOperand(0)->getDataType(); + const auto toTypeV = inst->getDataType(); + SLANG_ASSERT(!as<IRVectorType>(fromTypeV) == !as<IRVectorType>(toTypeV)); + const auto fromType = dropVector(fromTypeV); + const auto toType = dropVector(toTypeV); + SLANG_ASSERT(isFloatingType(fromType)); + SLANG_ASSERT(isFloatingType(toType)); + SLANG_ASSERT(!isTypeEqual(fromType, toType)); + + return emitInst(parent, inst, SpvOpFConvert, toTypeV, kResultID, inst->getOperand(0)); + } + + SpvInst* emitIntToFloatCast(SpvInstParent* parent, IRCastIntToFloat* inst) + { + const auto fromTypeV = inst->getOperand(0)->getDataType(); + const auto toTypeV = inst->getDataType(); + SLANG_ASSERT(!as<IRVectorType>(fromTypeV) == !as<IRVectorType>(toTypeV)); + const auto fromType = dropVector(fromTypeV); + const auto toType = dropVector(toTypeV); + SLANG_ASSERT(isIntegralType(fromType)); + SLANG_ASSERT(isFloatingType(toType)); + + const auto fromInfo = getIntTypeInfo(fromType); + + const auto convertWith = [&](auto op){ + return emitInst(parent, inst, op, toTypeV, kResultID, inst->getOperand(0)); + }; + + return convertWith(fromInfo.isSigned ? SpvOpConvertSToF : SpvOpConvertUToF); + } + + SpvInst* emitFloatToIntCast(SpvInstParent* parent, IRCastFloatToInt* inst) + { + const auto fromTypeV = inst->getOperand(0)->getDataType(); + const auto toTypeV = inst->getDataType(); + SLANG_ASSERT(!as<IRVectorType>(fromTypeV) == !as<IRVectorType>(toTypeV)); + const auto fromType = dropVector(fromTypeV); + const auto toType = dropVector(toTypeV); + SLANG_ASSERT(isFloatingType(fromType)); + SLANG_ASSERT(isIntegralType(toType)); + + const auto toInfo = getIntTypeInfo(toType); + + const auto convertWith = [&](auto op){ + return emitInst(parent, inst, op, toTypeV, kResultID, inst->getOperand(0)); + }; + + return convertWith(toInfo.isSigned ? SpvOpConvertFToS : SpvOpConvertFToU); + } + SpvInst* emitConstruct(SpvInstParent* parent, IRInst* inst) { if (as<IRBasicType>(inst->getDataType())) @@ -2708,6 +2810,25 @@ struct SPIRVEmitContext } } + SpvInst* emitSplat(SpvInstParent* parent, IRInst* scalar, IRIntegerValue numElems) + { + const auto scalarTy = as<IRBasicType>(scalar->getDataType()); + const auto spvVecTy = ensureVectorType( + scalarTy->getBaseType(), + numElems, + nullptr); + return emitInstCustomOperandFunc( + parent, + nullptr, + SpvOpCompositeConstruct, + [&](){ + emitOperand(spvVecTy); + emitOperand(kResultID); + for(Int i = 0; i < numElems; ++i) + emitOperand(scalar); + }); + } + bool isSignedType(IRType* type) { switch (type->getOp()) @@ -2748,12 +2869,8 @@ struct SPIRVEmitContext SpvInst* emitArithmetic(SpvInstParent* parent, IRInst* inst) { - IRType* elementType = inst->getOperand(0)->getDataType(); - if (auto vectorType = as<IRVectorType>(inst->getDataType())) - { - elementType = vectorType->getElementType(); - } - else if (const auto matrixType = as<IRMatrixType>(inst->getDataType())) + IRType* elementType = dropVector(inst->getOperand(0)->getDataType()); + if (const auto matrixType = as<IRMatrixType>(inst->getDataType())) { //TODO: implement. SLANG_ASSERT(!"unimplemented: matrix arithemetic"); @@ -2852,7 +2969,34 @@ struct SPIRVEmitContext SLANG_ASSERT(!"unknown arithmetic opcode"); break; } - return emitInst(parent, inst, opCode, inst->getDataType(), kResultID, OperandsOf(inst)); + if(inst->getOperandCount() == 1) + { + return emitInst(parent, inst, opCode, inst->getDataType(), kResultID, OperandsOf(inst)); + } + else if(inst->getOperandCount() == 2) + { + auto l = inst->getOperand(0); + const auto lVec = as<IRVectorType>(l->getDataType()); + auto r = inst->getOperand(1); + const auto rVec = as<IRVectorType>(r->getDataType()); + const auto go = [&](const auto l, const auto r){ + return emitInst(parent, inst, opCode, inst->getDataType(), kResultID, l, r); + }; + if(lVec && !rVec) + { + const auto len = as<IRIntLit>(lVec->getElementCount()); + SLANG_ASSERT(len); + return go(l, emitSplat(parent, r, len->getValue())); + } + else if (!lVec && rVec) + { + const auto len = as<IRIntLit>(rVec->getElementCount()); + SLANG_ASSERT(len); + return go(emitSplat(parent, l, len->getValue()), r); + } + return go(l, r); + } + SLANG_UNREACHABLE("Arithmetic op with 0 or more than 2 operands"); } OrderedHashSet<SpvCapability> m_capabilities; diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index 95f72b3cd..123cc33c6 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -2673,6 +2673,25 @@ struct IRGetRegisterSpace : IRBindingQuery IR_LEAF_ISA(GetRegisterSpace); }; +struct IRIntCast : IRInst +{ + IR_LEAF_ISA(IntCast) +}; + +struct IRFloatCast : IRInst +{ + IR_LEAF_ISA(FloatCast) +}; + +struct IRCastIntToFloat : IRInst +{ + IR_LEAF_ISA(CastIntToFloat) +}; + +struct IRCastFloatToInt : IRInst +{ + IR_LEAF_ISA(CastFloatToInt) +}; struct IRBuilderSourceLocRAII; diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index 38d1eb520..0a79cec57 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -6690,6 +6690,55 @@ namespace Slang return false; } + bool isFloatingType(IRType *t) + { + if(auto basic = as<IRBasicType>(t)) + { + switch(basic->getBaseType()) + { + case BaseType::Float: + case BaseType::Half: + case BaseType::Double: + return true; + default: + return false; + } + } + return false; + } + + IntInfo getIntTypeInfo(const IRType* intType) + { + switch(intType->getOp()) + { + case kIROp_UInt8Type: return {8, false}; + case kIROp_UInt16Type: return {16, false}; + case kIROp_UIntType: return {32, false}; + case kIROp_UInt64Type: return {64, false}; + case kIROp_Int8Type: return {8, true}; + case kIROp_Int16Type: return {16, true}; + case kIROp_IntType: return {32, true}; + case kIROp_Int64Type: return {64, true}; + + case kIROp_IntPtrType: // target platform dependent + case kIROp_UIntPtrType: // target platform dependent + default: + SLANG_UNEXPECTED("Unhandled type passed to getIntTypeInfo"); + } + } + + FloatInfo getFloatingTypeInfo(const IRType* floatType) + { + switch(floatType->getOp()) + { + case kIROp_HalfType: return {16}; + case kIROp_FloatType: return {32}; + case kIROp_DoubleType: return {64}; + default: + SLANG_UNEXPECTED("Unhandled type passed to getFloatTypeInfo"); + } + } + bool isIntegralScalarOrCompositeType(IRType* t) { if (!t) diff --git a/source/slang/slang-ir.h b/source/slang/slang-ir.h index 97f98fce2..3cd8e9126 100644 --- a/source/slang/slang-ir.h +++ b/source/slang/slang-ir.h @@ -1005,6 +1005,25 @@ bool isTypeEqual(IRType* a, IRType* b); // True if this is an integral IRBasicType, not including Char or Ptr types bool isIntegralType(IRType* t); +bool isFloatingType(IRType* t); + +struct IntInfo +{ + Int width; + bool isSigned; + bool operator==(const IntInfo& i) const { return width == i.width && isSigned == i.isSigned; } +}; + +IntInfo getIntTypeInfo(const IRType* intType); + +struct FloatInfo +{ + Int width; + bool operator==(const FloatInfo& i) const { return width == i.width; } +}; + +FloatInfo getFloatingTypeInfo(const IRType* floatType); + bool isIntegralScalarOrCompositeType(IRType* t); void findAllInstsBreadthFirst(IRInst* inst, List<IRInst*>& outInsts); |
