summaryrefslogtreecommitdiffstats
path: root/source/slang/slang-emit-spirv.cpp
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-08-24 16:32:33 -0700
committerGitHub <noreply@github.com>2023-08-24 16:32:33 -0700
commit0470ea05a42d6c3f35d81a433fefdd440500cdbd (patch)
tree25feb7bfd539013bfa64d8ff7698262932e39110 /source/slang/slang-emit-spirv.cpp
parentc515bf9edf0ceefa9a0c9b36626ea7c8f72ce36f (diff)
Misc. SPIRV Fixes, Part 2. (#3147)
* Misc. SPIRV Fixes, Part 2. * Fix up. * Fix. * Add system smenatic values. * 16 bit int and floats, matrix/vector reshape, bool ops. * Fix. * Fix. * Allow push constant entry point params. * entrypoint params. * swizzleSet and swizzledStore. * packoffset. * string hash. * Fix. * Matrix arithmetics. --------- Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-emit-spirv.cpp')
-rw-r--r--source/slang/slang-emit-spirv.cpp462
1 files changed, 413 insertions, 49 deletions
diff --git a/source/slang/slang-emit-spirv.cpp b/source/slang/slang-emit-spirv.cpp
index 3febbd210..0878b3494 100644
--- a/source/slang/slang-emit-spirv.cpp
+++ b/source/slang/slang-emit-spirv.cpp
@@ -1168,16 +1168,22 @@ struct SPIRVEmitContext
// > OpTypeInt
- case kIROp_UInt8Type:
case kIROp_UInt16Type:
+ case kIROp_Int16Type:
+ case kIROp_UInt8Type:
case kIROp_UIntType:
case kIROp_UInt64Type:
case kIROp_Int8Type:
- case kIROp_Int16Type:
case kIROp_IntType:
case kIROp_Int64Type:
{
const IntInfo i = getIntTypeInfo(as<IRType>(inst));
+ if (i.width == 16)
+ requireSPIRVCapability(SpvCapabilityInt16);
+ else if (i.width == 64)
+ requireSPIRVCapability(SpvCapabilityInt64);
+ else if (i.width == 8)
+ requireSPIRVCapability(SpvCapabilityInt8);
return emitOpTypeInt(
inst,
SpvLiteralInteger::from32(int32_t(i.width)),
@@ -1395,6 +1401,8 @@ struct SPIRVEmitContext
return emitGlobalParam(as<IRGlobalParam>(inst));
case kIROp_GlobalVar:
return emitGlobalVar(as<IRGlobalVar>(inst));
+ case kIROp_Var:
+ return emitVar(getSection(SpvLogicalSectionID::GlobalVariables), inst);
// ...
case kIROp_Specialize:
@@ -1436,6 +1444,8 @@ struct SPIRVEmitContext
}
return result;
}
+ case kIROp_GetStringHash:
+ return emitGetStringHash(inst);
default:
{
String e = "Unhandled global inst in spirv-emit:\n"
@@ -1468,6 +1478,8 @@ struct SPIRVEmitContext
void emitVarLayout(SpvInst* varInst, IRVarLayout* layout)
{
+ bool needDefaultSetBindingDecoration = false;
+ bool hasExplicitSetBinding = false;
for (auto rr : layout->getOffsetAttrs())
{
UInt index = rr->getOffset();
@@ -1498,15 +1510,6 @@ struct SPIRVEmitContext
varInst,
SpvLiteralInteger::from32(int32_t(index))
);
- if (space)
- {
- emitOpDecorateIndex(
- getSection(SpvLogicalSectionID::Annotations),
- nullptr,
- varInst,
- SpvLiteralInteger::from32(int32_t(space))
- );
- }
break;
case LayoutResourceKind::SpecializationConstant:
@@ -1527,24 +1530,44 @@ struct SPIRVEmitContext
getSection(SpvLogicalSectionID::Annotations),
nullptr,
varInst,
- SpvLiteralInteger::from32(int32_t(index))
- );
+ SpvLiteralInteger::from32(int32_t(index)));
+ if (space)
+ {
+ emitOpDecorateDescriptorSet(
+ getSection(SpvLogicalSectionID::Annotations),
+ nullptr,
+ varInst,
+ SpvLiteralInteger::from32(int32_t(space)));
+ }
+ else
+ {
+ needDefaultSetBindingDecoration = true;
+ }
+ break;
+ case LayoutResourceKind::RegisterSpace:
emitOpDecorateDescriptorSet(
getSection(SpvLogicalSectionID::Annotations),
nullptr,
varInst,
- SpvLiteralInteger::from32(int32_t(space))
- );
+ SpvLiteralInteger::from32(int32_t(index)));
+ hasExplicitSetBinding = true;
break;
default:
break;
}
}
+ if (needDefaultSetBindingDecoration && !hasExplicitSetBinding)
+ {
+ emitOpDecorateDescriptorSet(
+ getSection(SpvLogicalSectionID::Annotations),
+ nullptr,
+ varInst,
+ SpvLiteralInteger::from32(int32_t(0)));
+ }
}
/// Emit a global parameter definition.
SpvInst* emitGlobalParam(IRGlobalParam* param)
{
- auto layout = getVarLayout(param);
auto storageClass = SpvStorageClassUniform;
if (auto ptrType = as<IRPtrTypeBase>(param->getDataType()))
{
@@ -1562,7 +1585,8 @@ struct SPIRVEmitContext
param->getDataType(),
storageClass
);
- emitVarLayout(varInst, layout);
+ if (auto layout = getVarLayout(param))
+ emitVarLayout(varInst, layout);
return varInst;
}
@@ -1615,6 +1639,8 @@ struct SPIRVEmitContext
/// Emit a declaration for the given `irFunc`
SpvInst* emitFuncDeclaration(IRFunc* irFunc)
{
+ if (irFunc->findDecorationImpl(kIROp_SPIRVOpDecoration))
+ return nullptr;
// For now we aren't handling function declarations;
// we expect to deal only with fully linked modules.
//
@@ -1852,6 +1878,10 @@ struct SPIRVEmitContext
return emitLoad(parent, as<IRLoad>(inst));
case kIROp_Store:
return emitStore(parent, as<IRStore>(inst));
+ case kIROp_SwizzledStore:
+ return emitSwizzledStore(parent, as<IRSwizzledStore>(inst));
+ case kIROp_swizzleSet:
+ return emitSwizzleSet(parent, as<IRSwizzleSet>(inst));
case kIROp_RWStructuredBufferGetElementPtr:
return emitStructuredBufferGetElementPtr(parent, inst);
case kIROp_StructuredBufferGetDimensions:
@@ -1866,10 +1896,6 @@ struct SPIRVEmitContext
return emitIntToFloatCast(parent, as<IRCastIntToFloat>(inst));
case kIROp_CastFloatToInt:
return emitFloatToIntCast(parent, as<IRCastFloatToInt>(inst));
- case kIROp_MatrixReshape:
- case kIROp_VectorReshape:
- // TODO: break emitConstruct into separate functions for each opcode.
- return emitConstruct(parent, inst);
case kIROp_BitCast:
return emitOpBitcast(
parent,
@@ -1991,6 +2017,28 @@ struct SPIRVEmitContext
return emitInst(parent, inst, SpvOpSelect, inst->getFullType(), kResultID, OperandsOf(inst));
case kIROp_DebugLine:
return emitDebugLine(parent, as<IRDebugLine>(inst));
+ case kIROp_GetStringHash:
+ return emitGetStringHash(inst);
+
+ }
+ }
+
+ SpvInst* emitGetStringHash(IRInst* inst)
+ {
+ auto getStringHashInst = as<IRGetStringHash>(inst);
+ auto stringLit = getStringHashInst->getStringLit();
+
+ if (stringLit)
+ {
+ auto slice = stringLit->getStringSlice();
+ return emitIntConstant(getStableHashCode32(slice.begin(), slice.getLength()).hash, inst->getDataType());
+ }
+ else
+ {
+ // Couldn't handle
+ String e = "Unhandled local inst in spirv-emit:\n"
+ + dumpIRToString(inst, { IRDumpOptions::Mode::Detailed, 0 });
+ SLANG_UNIMPLEMENTED_X(e.getBuffer());
}
}
@@ -2236,7 +2284,14 @@ struct SPIRVEmitContext
for (auto field : structType->getFields())
{
IRIntegerValue offset = 0;
- getOffset(IRTypeLayoutRules::get(layoutRuleName), field, &offset);
+ if (auto offsetDecor = field->getKey()->findDecoration<IRPackOffsetDecoration>())
+ {
+ offset = (getIntVal(offsetDecor->getRegisterOffset()) * 4 + getIntVal(offsetDecor->getComponentOffset())) * 4;
+ }
+ else
+ {
+ getOffset(IRTypeLayoutRules::get(layoutRuleName), field, &offset);
+ }
emitOpMemberDecorateOffset(
getSection(SpvLogicalSectionID::Annotations),
nullptr,
@@ -2368,14 +2423,168 @@ struct SPIRVEmitContext
{
String semanticName = systemValueAttr->getName();
semanticName = semanticName.toLower();
- if (semanticName == "sv_dispatchthreadid")
+ if (semanticName == "sv_position")
+ {
+ auto importDecor = inst->findDecoration<IRImportDecoration>();
+ if (importDecor->getMangledName() == "gl_FragCoord")
+ return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInFragCoord);
+ else
+ return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInPosition);
+ }
+ else if (semanticName == "sv_target")
+ {
+ // Note: we do *not* need to generate some kind of `gl_`
+ // builtin for fragment-shader outputs: they are just
+ // ordinary `out` variables, with ordinary `location`s,
+ // as far as GLSL is concerned.
+ return nullptr;
+ }
+ else if (semanticName == "sv_clipdistance")
+ {
+ return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInClipDistance);
+ }
+ else if (semanticName == "sv_culldistance")
+ {
+ requireSPIRVCapability(SpvCapabilityCullDistance);
+ return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInCullDistance);
+ }
+ else if (semanticName == "sv_coverage")
+ {
+ return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInSampleMask);
+ }
+ else if (semanticName == "sv_innercoverage")
+ {
+ requireSPIRVCapability(SpvCapabilityFragmentFullyCoveredEXT);
+ return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInFullyCoveredEXT);
+ }
+ else if (semanticName == "sv_depth")
+ {
+ return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInFragDepth);
+ }
+ else if (semanticName == "sv_depthgreaterequal")
+ {
+ return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInFragDepth);
+ }
+ else if (semanticName == "sv_depthlessequal")
+ {
+ return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInFragDepth);
+ }
+ else if (semanticName == "sv_dispatchthreadid")
{
return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInGlobalInvocationId);
}
+ else if (semanticName == "sv_domainlocation")
+ {
+ return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInTessCoord);
+ }
+ else if (semanticName == "sv_groupid")
+ {
+ return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInWorkgroupId);
+ }
else if (semanticName == "sv_groupindex")
{
return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInLocalInvocationIndex);
}
+ else if (semanticName == "sv_groupthreadid")
+ {
+ return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInLocalInvocationId);
+ }
+ else if (semanticName == "sv_gsinstanceid")
+ {
+ return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInInvocationId);
+ }
+ else if (semanticName == "sv_instanceid")
+ {
+ return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInInstanceIndex);
+ }
+ else if (semanticName == "sv_isfrontface")
+ {
+ return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInFrontFacing);
+ }
+ else if (semanticName == "sv_outputcontrolpointid")
+ {
+ return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInInvocationId);
+ }
+ else if (semanticName == "sv_pointsize")
+ {
+ // float in hlsl & glsl
+ return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInPointSize);
+ }
+ else if (semanticName == "sv_primitiveid")
+ {
+ return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInPrimitiveId);
+ }
+ else if (semanticName == "sv_rendertargetarrayindex")
+ {
+ requireSPIRVCapability(SpvCapabilityShaderLayer);
+ return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInLayer);
+ }
+ else if (semanticName == "sv_sampleindex")
+ {
+ requireSPIRVCapability(SpvCapabilitySampleRateShading);
+ return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInSampleId);
+ }
+ else if (semanticName == "sv_stencilref")
+ {
+ requireSPIRVCapability(SpvCapabilityStencilExportEXT);
+ return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInFragStencilRefEXT);
+ }
+ else if (semanticName == "sv_tessfactor")
+ {
+ requireSPIRVCapability(SpvCapabilityTessellation);
+ return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInTessLevelOuter);
+ }
+ else if (semanticName == "sv_vertexid")
+ {
+ return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInVertexId);
+ }
+ else if (semanticName == "sv_viewid")
+ {
+ requireSPIRVCapability(SpvCapabilityMultiView);
+ return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInViewIndex);
+ }
+ else if (semanticName == "sv_viewportarrayindex")
+ {
+ requireSPIRVCapability(SpvCapabilityShaderViewportIndex);
+ return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInViewportIndex);
+ }
+ else if (semanticName == "nv_x_right")
+ {
+ SLANG_UNIMPLEMENTED_X("spirv emit for nv_x_right");
+ }
+ else if (semanticName == "nv_viewport_mask")
+ {
+ requireSPIRVCapability(SpvCapabilityPerViewAttributesNV);
+ return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInViewportMaskPerViewNV);
+ }
+ else if (semanticName == "sv_barycentrics")
+ {
+ if (m_targetRequest->getTargetCaps().implies(CapabilityAtom::GL_NV_fragment_shader_barycentric))
+ {
+ requireSPIRVCapability(SpvCapabilityFragmentBarycentricNV);
+ return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInBaryCoordNV);
+ }
+ else
+ {
+ requireSPIRVCapability(SpvCapabilityFragmentBarycentricKHR);
+ return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInBaryCoordKHR);
+ }
+
+ // TODO: There is also the `gl_BaryCoordNoPerspNV` builtin, which
+ // we ought to use if the `noperspective` modifier has been
+ // applied to this varying input.
+ }
+ else if (semanticName == "sv_cullprimitive")
+ {
+ requireSPIRVCapability(SpvCapabilityMeshShadingEXT);
+ return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInCullPrimitiveEXT);
+ }
+ else if (semanticName == "sv_shadingrate")
+ {
+ requireSPIRVCapability(SpvCapabilityFragmentShadingRateKHR);
+ return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInPrimitiveShadingRateKHR);
+ }
+ SLANG_UNREACHABLE("Unimplemented system value in spirv emit.");
}
}
return nullptr;
@@ -2571,6 +2780,14 @@ struct SPIRVEmitContext
{
return emitIntrinsicCallExpr(parent, static_cast<IRCall*>(inst), targetIntrinsic);
}
+ else if (auto spvOpDecor = funcValue->findDecorationImpl(kIROp_SPIRVOpDecoration))
+ {
+ SpvOp op = (SpvOp)getIntVal(spvOpDecor->getOperand(0));
+ List<IRInst*> args;
+ for (UInt i = 0; i < inst->getArgCount(); i++)
+ args.add(inst->getArg(i));
+ return emitInst(parent, inst, op, inst->getFullType(), kResultID, args);
+ }
else
{
return emitOpFunctionCall(
@@ -2658,6 +2875,9 @@ struct SPIRVEmitContext
case SpvSnippet::ASMType::Int:
result = emitIntConstant((IRIntegerValue)constant.intValues[0], builder.getIntType());
break;
+ case SpvSnippet::ASMType::UInt16:
+ result = emitIntConstant((IRIntegerValue)constant.intValues[0], builder.getType(kIROp_UInt16Type));
+ break;
case SpvSnippet::ASMType::UInt2:
{
auto uintType = builder.getType(kIROp_UIntType);
@@ -2686,12 +2906,18 @@ struct SPIRVEmitContext
case SpvSnippet::ASMType::Float:
irType = builder.getType(kIROp_FloatType);
break;
+ case SpvSnippet::ASMType::Half:
+ irType = builder.getType(kIROp_HalfType);
+ break;
case SpvSnippet::ASMType::Int:
irType = builder.getIntType();
break;
case SpvSnippet::ASMType::UInt:
irType = builder.getUIntType();
break;
+ case SpvSnippet::ASMType::UInt16:
+ irType = builder.getType(kIROp_UInt16Type);
+ break;
case SpvSnippet::ASMType::Float2:
irType = builder.getVectorType(
builder.getType(kIROp_FloatType), builder.getIntValue(builder.getIntType(), 2));
@@ -2969,6 +3195,46 @@ struct SPIRVEmitContext
return emitOpStore(parent, inst, inst->getPtr(), inst->getVal());
}
+ SpvInst* emitSwizzledStore(SpvInstParent* parent, IRSwizzledStore* inst)
+ {
+ auto sourceVectorType = as<IRVectorType>(inst->getSource()->getDataType());
+ SLANG_ASSERT(sourceVectorType);
+ auto sourceElementType = sourceVectorType->getElementType();
+ SLANG_ASSERT(getIntVal(sourceVectorType->getElementCount()) == (IRIntegerValue)inst->getElementCount());
+ SpvInst* result = nullptr;
+ IRBuilder builder(inst);
+ builder.setInsertBefore(inst);
+ auto destPtrType = as<IRPtrTypeBase>(inst->getDest()->getDataType());
+ SpvStorageClass addrSpace = SpvStorageClassFunction;
+ if (destPtrType->hasAddressSpace())
+ addrSpace = (SpvStorageClass)destPtrType->getAddressSpace();
+ auto ptrElementType = builder.getPtrType(kIROp_PtrType, sourceElementType, addrSpace);
+ for (UInt i = 0; i < inst->getElementCount(); i++)
+ {
+ auto index = inst->getElementIndex(i);
+ auto addr = emitOpAccessChain(parent, nullptr, ptrElementType, inst->getDest(), makeArray(index));
+ auto val = emitOpCompositeExtract(parent, nullptr, sourceElementType, inst->getSource(), makeArray(SpvLiteralInteger::from32((int32_t)i)));
+ result = emitOpStore(parent, (i == inst->getElementCount() - 1 ? inst : nullptr), addr, val);
+ }
+ return result;
+ }
+
+ SpvInst* emitSwizzleSet(SpvInstParent* parent, IRSwizzleSet* inst)
+ {
+ auto resultVectorType = as<IRVectorType>(inst->getDataType());
+ List<SpvLiteralInteger> shuffleIndices;
+ shuffleIndices.setCount((Index)getIntVal(resultVectorType->getElementCount()));
+ for (Index i = 0; i < shuffleIndices.getCount(); i++)
+ shuffleIndices[i] = SpvLiteralInteger::from32((int32_t)i);
+ for (UInt i = 0; i < inst->getElementCount(); i++)
+ {
+ auto destIndex = (int32_t)getIntVal(inst->getElementIndex(i));
+ SLANG_ASSERT(destIndex < shuffleIndices.getCount());
+ shuffleIndices[destIndex] = SpvLiteralInteger::from32((int32_t)(i + shuffleIndices.getCount()));
+ }
+ return emitOpVectorShuffle(parent, inst, inst->getFullType(), inst->getBase(), inst->getSource(), shuffleIndices.getArrayView());
+ }
+
SpvInst* emitStructuredBufferGetElementPtr(SpvInstParent* parent, IRInst* inst)
{
//"%addr = OpAccessChain resultType*StorageBuffer resultId _0 const(int, 0) _1;"
@@ -3035,6 +3301,40 @@ struct SPIRVEmitContext
SLANG_ASSERT(!as<IRVectorType>(fromTypeV) == !as<IRVectorType>(toTypeV));
const auto fromType = dropVector(fromTypeV);
const auto toType = dropVector(toTypeV);
+
+ if (as<IRBoolType>(fromType))
+ {
+ // Cast from bool to int.
+ IRBuilder builder(inst);
+ builder.setInsertBefore(inst);
+ auto zero = builder.getIntValue(toType, 0);
+ auto one = builder.getIntValue(toType, 1);
+ if (auto vecType = as<IRVectorType>(toTypeV))
+ {
+ auto zeroV = emitSplat(parent, nullptr, zero, getIntVal(vecType->getElementCount()));
+ auto oneV = emitSplat(parent, nullptr, one, getIntVal(vecType->getElementCount()));
+ return emitInst(parent, inst, SpvOpSelect, inst->getFullType(), kResultID, inst->getOperand(0),
+ oneV, zeroV);
+ }
+ return emitInst(parent, inst, SpvOpSelect, inst->getFullType(), kResultID, inst->getOperand(0), one, zero);
+ }
+ else if (as<IRBoolType>(toType))
+ {
+ // Cast from int to bool.
+ IRBuilder builder(inst);
+ builder.setInsertBefore(inst);
+ auto zero = builder.getIntValue(fromType, 0);
+ if (auto vecType = as<IRVectorType>(toTypeV))
+ {
+ auto zeroV = emitSplat(parent, nullptr, zero, getIntVal(vecType->getElementCount()));
+ return emitOpINotEqual(parent, inst, inst->getFullType(), inst->getOperand(0), zeroV);
+ }
+ else
+ {
+ return emitOpINotEqual(parent, inst, inst->getFullType(), inst->getOperand(0), zero);
+ }
+ }
+
SLANG_ASSERT(isIntegralType(fromType));
SLANG_ASSERT(isIntegralType(toType));
@@ -3100,6 +3400,24 @@ struct SPIRVEmitContext
const auto fromType = dropVector(fromTypeV);
const auto toType = dropVector(toTypeV);
SLANG_ASSERT(isFloatingType(fromType));
+
+ if (as<IRBoolType>(toType))
+ {
+ // Float to bool cast.
+ IRBuilder builder(inst);
+ builder.setInsertBefore(inst);
+ auto zero = builder.getIntValue(fromType, 0);
+ if (auto vecType = as<IRVectorType>(toTypeV))
+ {
+ auto zeroV = emitSplat(parent, nullptr, zero, getIntVal(vecType->getElementCount()));
+ return emitInst(parent, inst, SpvOpFUnordNotEqual, inst->getFullType(), kResultID, inst->getOperand(0), zeroV);
+ }
+ else
+ {
+ return emitInst(parent, inst, SpvOpFUnordNotEqual, inst->getFullType(), kResultID, inst->getOperand(0), zero);
+ }
+ }
+
SLANG_ASSERT(isIntegralType(toType));
const auto toInfo = getIntTypeInfo(toType);
@@ -3279,14 +3597,9 @@ struct SPIRVEmitContext
}
}
- SpvInst* emitArithmetic(SpvInstParent* parent, IRInst* inst)
+ SpvInst* emitVectorOrScalarArithmetic(SpvInstParent* parent, IRInst* instToRegister, IRInst* type, IROp op, UInt operandCount, ArrayView<IRInst*> operands)
{
- IRType* elementType = dropVector(inst->getOperand(0)->getDataType());
- if (const auto matrixType = as<IRMatrixType>(inst->getDataType()))
- {
- //TODO: implement.
- SLANG_ASSERT(!"unimplemented: matrix arithemetic");
- }
+ IRType* elementType = dropVector(operands[0]->getDataType());
IRBasicType* basicType = as<IRBasicType>(elementType);
bool isFloatingPoint = false;
bool isBool = false;
@@ -3305,7 +3618,7 @@ struct SPIRVEmitContext
}
SpvOp opCode = SpvOpUndef;
bool isSigned = isSignedType(basicType);
- switch (inst->getOp())
+ switch (op)
{
case kIROp_Add:
opCode = isFloatingPoint ? SpvOpFAdd : SpvOpIAdd;
@@ -3323,30 +3636,30 @@ struct SPIRVEmitContext
opCode = isSigned ? SpvOpSRem : SpvOpUMod;
break;
case kIROp_FRem:
- opCode = SpvOpFRem;
+ opCode = SpvOpFMod;
break;
case kIROp_Less:
opCode = isFloatingPoint ? SpvOpFOrdLessThan
- : isSigned ? SpvOpSLessThan : SpvOpULessThan;
+ : isSigned ? SpvOpSLessThan : SpvOpULessThan;
break;
case kIROp_Leq:
opCode = isFloatingPoint ? SpvOpFOrdLessThanEqual
- : isSigned ? SpvOpSLessThanEqual : SpvOpULessThanEqual;
+ : isSigned ? SpvOpSLessThanEqual : SpvOpULessThanEqual;
break;
case kIROp_Eql:
opCode = isFloatingPoint ? SpvOpFOrdEqual : isBool ? SpvOpLogicalEqual : SpvOpIEqual;
break;
case kIROp_Neq:
opCode = isFloatingPoint ? SpvOpFOrdNotEqual
- : isBool ? SpvOpLogicalNotEqual : SpvOpINotEqual;
+ : isBool ? SpvOpLogicalNotEqual : SpvOpINotEqual;
break;
case kIROp_Geq:
opCode = isFloatingPoint ? SpvOpFOrdGreaterThanEqual
- : isSigned ? SpvOpSGreaterThanEqual : SpvOpUGreaterThanEqual;
+ : isSigned ? SpvOpSGreaterThanEqual : SpvOpUGreaterThanEqual;
break;
case kIROp_Greater:
opCode = isFloatingPoint ? SpvOpFOrdGreaterThan
- : isSigned ? SpvOpSGreaterThan : SpvOpUGreaterThan;
+ : isSigned ? SpvOpSGreaterThan : SpvOpUGreaterThan;
break;
case kIROp_Neg:
opCode = isFloatingPoint ? SpvOpFNegate : SpvOpSNegate;
@@ -3361,16 +3674,28 @@ struct SPIRVEmitContext
opCode = SpvOpLogicalNot;
break;
case kIROp_BitAnd:
- opCode = SpvOpBitwiseAnd;
+ if (isBool)
+ opCode = SpvOpLogicalAnd;
+ else
+ opCode = SpvOpBitwiseAnd;
break;
case kIROp_BitOr:
- opCode = SpvOpBitwiseOr;
+ if (isBool)
+ opCode = SpvOpLogicalOr;
+ else
+ opCode = SpvOpBitwiseOr;
break;
case kIROp_BitXor:
- opCode = SpvOpBitwiseXor;
+ if (isBool)
+ opCode = SpvOpLogicalNotEqual;
+ else
+ opCode = SpvOpBitwiseXor;
break;
case kIROp_BitNot:
- opCode = SpvOpBitReverse;
+ if (isBool)
+ opCode = SpvOpLogicalNot;
+ else
+ opCode = SpvOpBitReverse;
break;
case kIROp_Rsh:
opCode = isSigned ? SpvOpShiftRightArithmetic : SpvOpShiftRightLogical;
@@ -3382,20 +3707,20 @@ struct SPIRVEmitContext
SLANG_ASSERT(!"unknown arithmetic opcode");
break;
}
- if(inst->getOperandCount() == 1)
+ if (operandCount == 1)
{
- return emitInst(parent, inst, opCode, inst->getDataType(), kResultID, OperandsOf(inst));
+ return emitInst(parent, instToRegister, opCode, type, kResultID, operands);
}
- else if(inst->getOperandCount() == 2)
+ else if (operandCount == 2)
{
- auto l = inst->getOperand(0);
+ auto l = operands[0];
const auto lVec = as<IRVectorType>(l->getDataType());
- auto r = inst->getOperand(1);
+ auto r = operands[1];
const auto rVec = as<IRVectorType>(r->getDataType());
- const auto go = [&](const auto l, const auto r){
- return emitInst(parent, inst, opCode, inst->getDataType(), kResultID, l, r);
+ const auto go = [&](const auto l, const auto r) {
+ return emitInst(parent, instToRegister, opCode, type, kResultID, l, r);
};
- if(lVec && !rVec)
+ if (lVec && !rVec)
{
const auto len = as<IRIntLit>(lVec->getElementCount());
SLANG_ASSERT(len);
@@ -3412,6 +3737,45 @@ struct SPIRVEmitContext
SLANG_UNREACHABLE("Arithmetic op with 0 or more than 2 operands");
}
+
+ SpvInst* emitArithmetic(SpvInstParent* parent, IRInst* inst)
+ {
+ if (const auto matrixType = as<IRMatrixType>(inst->getDataType()))
+ {
+ auto rowCount = getIntVal(matrixType->getRowCount());
+ auto colCount = getIntVal(matrixType->getColumnCount());
+ IRBuilder builder(inst);
+ builder.setInsertBefore(inst);
+ auto rowVectorType = builder.getVectorType(matrixType->getElementType(), colCount);
+ List<SpvInst*> rows;
+ for (IRIntegerValue i = 0; i < rowCount; i++)
+ {
+ List<IRInst*> operands;
+ for (UInt j = 0; j < inst->getOperandCount(); j++)
+ {
+ auto originalOperand = inst->getOperand(j);
+ if (as<IRMatrixType>(originalOperand->getDataType()))
+ {
+ auto operand = builder.emitElementExtract(originalOperand, i);
+ emitLocalInst(parent, operand);
+ operands.add(operand);
+ }
+ else
+ {
+ operands.add(originalOperand);
+ }
+ }
+ rows.add(emitVectorOrScalarArithmetic(parent, nullptr, rowVectorType, inst->getOp(), inst->getOperandCount(), operands.getArrayView()));
+ }
+ return emitCompositeConstruct(parent, inst, inst->getDataType(), rows);
+ }
+
+ Array<IRInst*, 4> operands;
+ for (UInt i = 0; i < inst->getOperandCount(); i++)
+ operands.add(inst->getOperand(i));
+ return emitVectorOrScalarArithmetic(parent, inst, inst->getDataType(), inst->getOp(), inst->getOperandCount(), operands.getView());
+ }
+
SpvInst* emitDebugLine(SpvInstParent* parent, IRDebugLine* debugLine)
{
auto scope = findDebugScope(debugLine);