diff options
| author | Yong He <yonghe@outlook.com> | 2023-08-24 16:32:33 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-08-24 16:32:33 -0700 |
| commit | 0470ea05a42d6c3f35d81a433fefdd440500cdbd (patch) | |
| tree | 25feb7bfd539013bfa64d8ff7698262932e39110 /source/slang/slang-emit-spirv.cpp | |
| parent | c515bf9edf0ceefa9a0c9b36626ea7c8f72ce36f (diff) | |
Misc. SPIRV Fixes, Part 2. (#3147)
* Misc. SPIRV Fixes, Part 2.
* Fix up.
* Fix.
* Add system smenatic values.
* 16 bit int and floats, matrix/vector reshape, bool ops.
* Fix.
* Fix.
* Allow push constant entry point params.
* entrypoint params.
* swizzleSet and swizzledStore.
* packoffset.
* string hash.
* Fix.
* Matrix arithmetics.
---------
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-emit-spirv.cpp')
| -rw-r--r-- | source/slang/slang-emit-spirv.cpp | 462 |
1 files changed, 413 insertions, 49 deletions
diff --git a/source/slang/slang-emit-spirv.cpp b/source/slang/slang-emit-spirv.cpp index 3febbd210..0878b3494 100644 --- a/source/slang/slang-emit-spirv.cpp +++ b/source/slang/slang-emit-spirv.cpp @@ -1168,16 +1168,22 @@ struct SPIRVEmitContext // > OpTypeInt - case kIROp_UInt8Type: case kIROp_UInt16Type: + case kIROp_Int16Type: + case kIROp_UInt8Type: case kIROp_UIntType: case kIROp_UInt64Type: case kIROp_Int8Type: - case kIROp_Int16Type: case kIROp_IntType: case kIROp_Int64Type: { const IntInfo i = getIntTypeInfo(as<IRType>(inst)); + if (i.width == 16) + requireSPIRVCapability(SpvCapabilityInt16); + else if (i.width == 64) + requireSPIRVCapability(SpvCapabilityInt64); + else if (i.width == 8) + requireSPIRVCapability(SpvCapabilityInt8); return emitOpTypeInt( inst, SpvLiteralInteger::from32(int32_t(i.width)), @@ -1395,6 +1401,8 @@ struct SPIRVEmitContext return emitGlobalParam(as<IRGlobalParam>(inst)); case kIROp_GlobalVar: return emitGlobalVar(as<IRGlobalVar>(inst)); + case kIROp_Var: + return emitVar(getSection(SpvLogicalSectionID::GlobalVariables), inst); // ... case kIROp_Specialize: @@ -1436,6 +1444,8 @@ struct SPIRVEmitContext } return result; } + case kIROp_GetStringHash: + return emitGetStringHash(inst); default: { String e = "Unhandled global inst in spirv-emit:\n" @@ -1468,6 +1478,8 @@ struct SPIRVEmitContext void emitVarLayout(SpvInst* varInst, IRVarLayout* layout) { + bool needDefaultSetBindingDecoration = false; + bool hasExplicitSetBinding = false; for (auto rr : layout->getOffsetAttrs()) { UInt index = rr->getOffset(); @@ -1498,15 +1510,6 @@ struct SPIRVEmitContext varInst, SpvLiteralInteger::from32(int32_t(index)) ); - if (space) - { - emitOpDecorateIndex( - getSection(SpvLogicalSectionID::Annotations), - nullptr, - varInst, - SpvLiteralInteger::from32(int32_t(space)) - ); - } break; case LayoutResourceKind::SpecializationConstant: @@ -1527,24 +1530,44 @@ struct SPIRVEmitContext getSection(SpvLogicalSectionID::Annotations), nullptr, varInst, - SpvLiteralInteger::from32(int32_t(index)) - ); + SpvLiteralInteger::from32(int32_t(index))); + if (space) + { + emitOpDecorateDescriptorSet( + getSection(SpvLogicalSectionID::Annotations), + nullptr, + varInst, + SpvLiteralInteger::from32(int32_t(space))); + } + else + { + needDefaultSetBindingDecoration = true; + } + break; + case LayoutResourceKind::RegisterSpace: emitOpDecorateDescriptorSet( getSection(SpvLogicalSectionID::Annotations), nullptr, varInst, - SpvLiteralInteger::from32(int32_t(space)) - ); + SpvLiteralInteger::from32(int32_t(index))); + hasExplicitSetBinding = true; break; default: break; } } + if (needDefaultSetBindingDecoration && !hasExplicitSetBinding) + { + emitOpDecorateDescriptorSet( + getSection(SpvLogicalSectionID::Annotations), + nullptr, + varInst, + SpvLiteralInteger::from32(int32_t(0))); + } } /// Emit a global parameter definition. SpvInst* emitGlobalParam(IRGlobalParam* param) { - auto layout = getVarLayout(param); auto storageClass = SpvStorageClassUniform; if (auto ptrType = as<IRPtrTypeBase>(param->getDataType())) { @@ -1562,7 +1585,8 @@ struct SPIRVEmitContext param->getDataType(), storageClass ); - emitVarLayout(varInst, layout); + if (auto layout = getVarLayout(param)) + emitVarLayout(varInst, layout); return varInst; } @@ -1615,6 +1639,8 @@ struct SPIRVEmitContext /// Emit a declaration for the given `irFunc` SpvInst* emitFuncDeclaration(IRFunc* irFunc) { + if (irFunc->findDecorationImpl(kIROp_SPIRVOpDecoration)) + return nullptr; // For now we aren't handling function declarations; // we expect to deal only with fully linked modules. // @@ -1852,6 +1878,10 @@ struct SPIRVEmitContext return emitLoad(parent, as<IRLoad>(inst)); case kIROp_Store: return emitStore(parent, as<IRStore>(inst)); + case kIROp_SwizzledStore: + return emitSwizzledStore(parent, as<IRSwizzledStore>(inst)); + case kIROp_swizzleSet: + return emitSwizzleSet(parent, as<IRSwizzleSet>(inst)); case kIROp_RWStructuredBufferGetElementPtr: return emitStructuredBufferGetElementPtr(parent, inst); case kIROp_StructuredBufferGetDimensions: @@ -1866,10 +1896,6 @@ struct SPIRVEmitContext return emitIntToFloatCast(parent, as<IRCastIntToFloat>(inst)); case kIROp_CastFloatToInt: return emitFloatToIntCast(parent, as<IRCastFloatToInt>(inst)); - case kIROp_MatrixReshape: - case kIROp_VectorReshape: - // TODO: break emitConstruct into separate functions for each opcode. - return emitConstruct(parent, inst); case kIROp_BitCast: return emitOpBitcast( parent, @@ -1991,6 +2017,28 @@ struct SPIRVEmitContext return emitInst(parent, inst, SpvOpSelect, inst->getFullType(), kResultID, OperandsOf(inst)); case kIROp_DebugLine: return emitDebugLine(parent, as<IRDebugLine>(inst)); + case kIROp_GetStringHash: + return emitGetStringHash(inst); + + } + } + + SpvInst* emitGetStringHash(IRInst* inst) + { + auto getStringHashInst = as<IRGetStringHash>(inst); + auto stringLit = getStringHashInst->getStringLit(); + + if (stringLit) + { + auto slice = stringLit->getStringSlice(); + return emitIntConstant(getStableHashCode32(slice.begin(), slice.getLength()).hash, inst->getDataType()); + } + else + { + // Couldn't handle + String e = "Unhandled local inst in spirv-emit:\n" + + dumpIRToString(inst, { IRDumpOptions::Mode::Detailed, 0 }); + SLANG_UNIMPLEMENTED_X(e.getBuffer()); } } @@ -2236,7 +2284,14 @@ struct SPIRVEmitContext for (auto field : structType->getFields()) { IRIntegerValue offset = 0; - getOffset(IRTypeLayoutRules::get(layoutRuleName), field, &offset); + if (auto offsetDecor = field->getKey()->findDecoration<IRPackOffsetDecoration>()) + { + offset = (getIntVal(offsetDecor->getRegisterOffset()) * 4 + getIntVal(offsetDecor->getComponentOffset())) * 4; + } + else + { + getOffset(IRTypeLayoutRules::get(layoutRuleName), field, &offset); + } emitOpMemberDecorateOffset( getSection(SpvLogicalSectionID::Annotations), nullptr, @@ -2368,14 +2423,168 @@ struct SPIRVEmitContext { String semanticName = systemValueAttr->getName(); semanticName = semanticName.toLower(); - if (semanticName == "sv_dispatchthreadid") + if (semanticName == "sv_position") + { + auto importDecor = inst->findDecoration<IRImportDecoration>(); + if (importDecor->getMangledName() == "gl_FragCoord") + return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInFragCoord); + else + return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInPosition); + } + else if (semanticName == "sv_target") + { + // Note: we do *not* need to generate some kind of `gl_` + // builtin for fragment-shader outputs: they are just + // ordinary `out` variables, with ordinary `location`s, + // as far as GLSL is concerned. + return nullptr; + } + else if (semanticName == "sv_clipdistance") + { + return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInClipDistance); + } + else if (semanticName == "sv_culldistance") + { + requireSPIRVCapability(SpvCapabilityCullDistance); + return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInCullDistance); + } + else if (semanticName == "sv_coverage") + { + return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInSampleMask); + } + else if (semanticName == "sv_innercoverage") + { + requireSPIRVCapability(SpvCapabilityFragmentFullyCoveredEXT); + return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInFullyCoveredEXT); + } + else if (semanticName == "sv_depth") + { + return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInFragDepth); + } + else if (semanticName == "sv_depthgreaterequal") + { + return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInFragDepth); + } + else if (semanticName == "sv_depthlessequal") + { + return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInFragDepth); + } + else if (semanticName == "sv_dispatchthreadid") { return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInGlobalInvocationId); } + else if (semanticName == "sv_domainlocation") + { + return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInTessCoord); + } + else if (semanticName == "sv_groupid") + { + return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInWorkgroupId); + } else if (semanticName == "sv_groupindex") { return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInLocalInvocationIndex); } + else if (semanticName == "sv_groupthreadid") + { + return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInLocalInvocationId); + } + else if (semanticName == "sv_gsinstanceid") + { + return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInInvocationId); + } + else if (semanticName == "sv_instanceid") + { + return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInInstanceIndex); + } + else if (semanticName == "sv_isfrontface") + { + return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInFrontFacing); + } + else if (semanticName == "sv_outputcontrolpointid") + { + return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInInvocationId); + } + else if (semanticName == "sv_pointsize") + { + // float in hlsl & glsl + return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInPointSize); + } + else if (semanticName == "sv_primitiveid") + { + return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInPrimitiveId); + } + else if (semanticName == "sv_rendertargetarrayindex") + { + requireSPIRVCapability(SpvCapabilityShaderLayer); + return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInLayer); + } + else if (semanticName == "sv_sampleindex") + { + requireSPIRVCapability(SpvCapabilitySampleRateShading); + return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInSampleId); + } + else if (semanticName == "sv_stencilref") + { + requireSPIRVCapability(SpvCapabilityStencilExportEXT); + return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInFragStencilRefEXT); + } + else if (semanticName == "sv_tessfactor") + { + requireSPIRVCapability(SpvCapabilityTessellation); + return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInTessLevelOuter); + } + else if (semanticName == "sv_vertexid") + { + return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInVertexId); + } + else if (semanticName == "sv_viewid") + { + requireSPIRVCapability(SpvCapabilityMultiView); + return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInViewIndex); + } + else if (semanticName == "sv_viewportarrayindex") + { + requireSPIRVCapability(SpvCapabilityShaderViewportIndex); + return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInViewportIndex); + } + else if (semanticName == "nv_x_right") + { + SLANG_UNIMPLEMENTED_X("spirv emit for nv_x_right"); + } + else if (semanticName == "nv_viewport_mask") + { + requireSPIRVCapability(SpvCapabilityPerViewAttributesNV); + return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInViewportMaskPerViewNV); + } + else if (semanticName == "sv_barycentrics") + { + if (m_targetRequest->getTargetCaps().implies(CapabilityAtom::GL_NV_fragment_shader_barycentric)) + { + requireSPIRVCapability(SpvCapabilityFragmentBarycentricNV); + return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInBaryCoordNV); + } + else + { + requireSPIRVCapability(SpvCapabilityFragmentBarycentricKHR); + return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInBaryCoordKHR); + } + + // TODO: There is also the `gl_BaryCoordNoPerspNV` builtin, which + // we ought to use if the `noperspective` modifier has been + // applied to this varying input. + } + else if (semanticName == "sv_cullprimitive") + { + requireSPIRVCapability(SpvCapabilityMeshShadingEXT); + return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInCullPrimitiveEXT); + } + else if (semanticName == "sv_shadingrate") + { + requireSPIRVCapability(SpvCapabilityFragmentShadingRateKHR); + return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInPrimitiveShadingRateKHR); + } + SLANG_UNREACHABLE("Unimplemented system value in spirv emit."); } } return nullptr; @@ -2571,6 +2780,14 @@ struct SPIRVEmitContext { return emitIntrinsicCallExpr(parent, static_cast<IRCall*>(inst), targetIntrinsic); } + else if (auto spvOpDecor = funcValue->findDecorationImpl(kIROp_SPIRVOpDecoration)) + { + SpvOp op = (SpvOp)getIntVal(spvOpDecor->getOperand(0)); + List<IRInst*> args; + for (UInt i = 0; i < inst->getArgCount(); i++) + args.add(inst->getArg(i)); + return emitInst(parent, inst, op, inst->getFullType(), kResultID, args); + } else { return emitOpFunctionCall( @@ -2658,6 +2875,9 @@ struct SPIRVEmitContext case SpvSnippet::ASMType::Int: result = emitIntConstant((IRIntegerValue)constant.intValues[0], builder.getIntType()); break; + case SpvSnippet::ASMType::UInt16: + result = emitIntConstant((IRIntegerValue)constant.intValues[0], builder.getType(kIROp_UInt16Type)); + break; case SpvSnippet::ASMType::UInt2: { auto uintType = builder.getType(kIROp_UIntType); @@ -2686,12 +2906,18 @@ struct SPIRVEmitContext case SpvSnippet::ASMType::Float: irType = builder.getType(kIROp_FloatType); break; + case SpvSnippet::ASMType::Half: + irType = builder.getType(kIROp_HalfType); + break; case SpvSnippet::ASMType::Int: irType = builder.getIntType(); break; case SpvSnippet::ASMType::UInt: irType = builder.getUIntType(); break; + case SpvSnippet::ASMType::UInt16: + irType = builder.getType(kIROp_UInt16Type); + break; case SpvSnippet::ASMType::Float2: irType = builder.getVectorType( builder.getType(kIROp_FloatType), builder.getIntValue(builder.getIntType(), 2)); @@ -2969,6 +3195,46 @@ struct SPIRVEmitContext return emitOpStore(parent, inst, inst->getPtr(), inst->getVal()); } + SpvInst* emitSwizzledStore(SpvInstParent* parent, IRSwizzledStore* inst) + { + auto sourceVectorType = as<IRVectorType>(inst->getSource()->getDataType()); + SLANG_ASSERT(sourceVectorType); + auto sourceElementType = sourceVectorType->getElementType(); + SLANG_ASSERT(getIntVal(sourceVectorType->getElementCount()) == (IRIntegerValue)inst->getElementCount()); + SpvInst* result = nullptr; + IRBuilder builder(inst); + builder.setInsertBefore(inst); + auto destPtrType = as<IRPtrTypeBase>(inst->getDest()->getDataType()); + SpvStorageClass addrSpace = SpvStorageClassFunction; + if (destPtrType->hasAddressSpace()) + addrSpace = (SpvStorageClass)destPtrType->getAddressSpace(); + auto ptrElementType = builder.getPtrType(kIROp_PtrType, sourceElementType, addrSpace); + for (UInt i = 0; i < inst->getElementCount(); i++) + { + auto index = inst->getElementIndex(i); + auto addr = emitOpAccessChain(parent, nullptr, ptrElementType, inst->getDest(), makeArray(index)); + auto val = emitOpCompositeExtract(parent, nullptr, sourceElementType, inst->getSource(), makeArray(SpvLiteralInteger::from32((int32_t)i))); + result = emitOpStore(parent, (i == inst->getElementCount() - 1 ? inst : nullptr), addr, val); + } + return result; + } + + SpvInst* emitSwizzleSet(SpvInstParent* parent, IRSwizzleSet* inst) + { + auto resultVectorType = as<IRVectorType>(inst->getDataType()); + List<SpvLiteralInteger> shuffleIndices; + shuffleIndices.setCount((Index)getIntVal(resultVectorType->getElementCount())); + for (Index i = 0; i < shuffleIndices.getCount(); i++) + shuffleIndices[i] = SpvLiteralInteger::from32((int32_t)i); + for (UInt i = 0; i < inst->getElementCount(); i++) + { + auto destIndex = (int32_t)getIntVal(inst->getElementIndex(i)); + SLANG_ASSERT(destIndex < shuffleIndices.getCount()); + shuffleIndices[destIndex] = SpvLiteralInteger::from32((int32_t)(i + shuffleIndices.getCount())); + } + return emitOpVectorShuffle(parent, inst, inst->getFullType(), inst->getBase(), inst->getSource(), shuffleIndices.getArrayView()); + } + SpvInst* emitStructuredBufferGetElementPtr(SpvInstParent* parent, IRInst* inst) { //"%addr = OpAccessChain resultType*StorageBuffer resultId _0 const(int, 0) _1;" @@ -3035,6 +3301,40 @@ struct SPIRVEmitContext SLANG_ASSERT(!as<IRVectorType>(fromTypeV) == !as<IRVectorType>(toTypeV)); const auto fromType = dropVector(fromTypeV); const auto toType = dropVector(toTypeV); + + if (as<IRBoolType>(fromType)) + { + // Cast from bool to int. + IRBuilder builder(inst); + builder.setInsertBefore(inst); + auto zero = builder.getIntValue(toType, 0); + auto one = builder.getIntValue(toType, 1); + if (auto vecType = as<IRVectorType>(toTypeV)) + { + auto zeroV = emitSplat(parent, nullptr, zero, getIntVal(vecType->getElementCount())); + auto oneV = emitSplat(parent, nullptr, one, getIntVal(vecType->getElementCount())); + return emitInst(parent, inst, SpvOpSelect, inst->getFullType(), kResultID, inst->getOperand(0), + oneV, zeroV); + } + return emitInst(parent, inst, SpvOpSelect, inst->getFullType(), kResultID, inst->getOperand(0), one, zero); + } + else if (as<IRBoolType>(toType)) + { + // Cast from int to bool. + IRBuilder builder(inst); + builder.setInsertBefore(inst); + auto zero = builder.getIntValue(fromType, 0); + if (auto vecType = as<IRVectorType>(toTypeV)) + { + auto zeroV = emitSplat(parent, nullptr, zero, getIntVal(vecType->getElementCount())); + return emitOpINotEqual(parent, inst, inst->getFullType(), inst->getOperand(0), zeroV); + } + else + { + return emitOpINotEqual(parent, inst, inst->getFullType(), inst->getOperand(0), zero); + } + } + SLANG_ASSERT(isIntegralType(fromType)); SLANG_ASSERT(isIntegralType(toType)); @@ -3100,6 +3400,24 @@ struct SPIRVEmitContext const auto fromType = dropVector(fromTypeV); const auto toType = dropVector(toTypeV); SLANG_ASSERT(isFloatingType(fromType)); + + if (as<IRBoolType>(toType)) + { + // Float to bool cast. + IRBuilder builder(inst); + builder.setInsertBefore(inst); + auto zero = builder.getIntValue(fromType, 0); + if (auto vecType = as<IRVectorType>(toTypeV)) + { + auto zeroV = emitSplat(parent, nullptr, zero, getIntVal(vecType->getElementCount())); + return emitInst(parent, inst, SpvOpFUnordNotEqual, inst->getFullType(), kResultID, inst->getOperand(0), zeroV); + } + else + { + return emitInst(parent, inst, SpvOpFUnordNotEqual, inst->getFullType(), kResultID, inst->getOperand(0), zero); + } + } + SLANG_ASSERT(isIntegralType(toType)); const auto toInfo = getIntTypeInfo(toType); @@ -3279,14 +3597,9 @@ struct SPIRVEmitContext } } - SpvInst* emitArithmetic(SpvInstParent* parent, IRInst* inst) + SpvInst* emitVectorOrScalarArithmetic(SpvInstParent* parent, IRInst* instToRegister, IRInst* type, IROp op, UInt operandCount, ArrayView<IRInst*> operands) { - IRType* elementType = dropVector(inst->getOperand(0)->getDataType()); - if (const auto matrixType = as<IRMatrixType>(inst->getDataType())) - { - //TODO: implement. - SLANG_ASSERT(!"unimplemented: matrix arithemetic"); - } + IRType* elementType = dropVector(operands[0]->getDataType()); IRBasicType* basicType = as<IRBasicType>(elementType); bool isFloatingPoint = false; bool isBool = false; @@ -3305,7 +3618,7 @@ struct SPIRVEmitContext } SpvOp opCode = SpvOpUndef; bool isSigned = isSignedType(basicType); - switch (inst->getOp()) + switch (op) { case kIROp_Add: opCode = isFloatingPoint ? SpvOpFAdd : SpvOpIAdd; @@ -3323,30 +3636,30 @@ struct SPIRVEmitContext opCode = isSigned ? SpvOpSRem : SpvOpUMod; break; case kIROp_FRem: - opCode = SpvOpFRem; + opCode = SpvOpFMod; break; case kIROp_Less: opCode = isFloatingPoint ? SpvOpFOrdLessThan - : isSigned ? SpvOpSLessThan : SpvOpULessThan; + : isSigned ? SpvOpSLessThan : SpvOpULessThan; break; case kIROp_Leq: opCode = isFloatingPoint ? SpvOpFOrdLessThanEqual - : isSigned ? SpvOpSLessThanEqual : SpvOpULessThanEqual; + : isSigned ? SpvOpSLessThanEqual : SpvOpULessThanEqual; break; case kIROp_Eql: opCode = isFloatingPoint ? SpvOpFOrdEqual : isBool ? SpvOpLogicalEqual : SpvOpIEqual; break; case kIROp_Neq: opCode = isFloatingPoint ? SpvOpFOrdNotEqual - : isBool ? SpvOpLogicalNotEqual : SpvOpINotEqual; + : isBool ? SpvOpLogicalNotEqual : SpvOpINotEqual; break; case kIROp_Geq: opCode = isFloatingPoint ? SpvOpFOrdGreaterThanEqual - : isSigned ? SpvOpSGreaterThanEqual : SpvOpUGreaterThanEqual; + : isSigned ? SpvOpSGreaterThanEqual : SpvOpUGreaterThanEqual; break; case kIROp_Greater: opCode = isFloatingPoint ? SpvOpFOrdGreaterThan - : isSigned ? SpvOpSGreaterThan : SpvOpUGreaterThan; + : isSigned ? SpvOpSGreaterThan : SpvOpUGreaterThan; break; case kIROp_Neg: opCode = isFloatingPoint ? SpvOpFNegate : SpvOpSNegate; @@ -3361,16 +3674,28 @@ struct SPIRVEmitContext opCode = SpvOpLogicalNot; break; case kIROp_BitAnd: - opCode = SpvOpBitwiseAnd; + if (isBool) + opCode = SpvOpLogicalAnd; + else + opCode = SpvOpBitwiseAnd; break; case kIROp_BitOr: - opCode = SpvOpBitwiseOr; + if (isBool) + opCode = SpvOpLogicalOr; + else + opCode = SpvOpBitwiseOr; break; case kIROp_BitXor: - opCode = SpvOpBitwiseXor; + if (isBool) + opCode = SpvOpLogicalNotEqual; + else + opCode = SpvOpBitwiseXor; break; case kIROp_BitNot: - opCode = SpvOpBitReverse; + if (isBool) + opCode = SpvOpLogicalNot; + else + opCode = SpvOpBitReverse; break; case kIROp_Rsh: opCode = isSigned ? SpvOpShiftRightArithmetic : SpvOpShiftRightLogical; @@ -3382,20 +3707,20 @@ struct SPIRVEmitContext SLANG_ASSERT(!"unknown arithmetic opcode"); break; } - if(inst->getOperandCount() == 1) + if (operandCount == 1) { - return emitInst(parent, inst, opCode, inst->getDataType(), kResultID, OperandsOf(inst)); + return emitInst(parent, instToRegister, opCode, type, kResultID, operands); } - else if(inst->getOperandCount() == 2) + else if (operandCount == 2) { - auto l = inst->getOperand(0); + auto l = operands[0]; const auto lVec = as<IRVectorType>(l->getDataType()); - auto r = inst->getOperand(1); + auto r = operands[1]; const auto rVec = as<IRVectorType>(r->getDataType()); - const auto go = [&](const auto l, const auto r){ - return emitInst(parent, inst, opCode, inst->getDataType(), kResultID, l, r); + const auto go = [&](const auto l, const auto r) { + return emitInst(parent, instToRegister, opCode, type, kResultID, l, r); }; - if(lVec && !rVec) + if (lVec && !rVec) { const auto len = as<IRIntLit>(lVec->getElementCount()); SLANG_ASSERT(len); @@ -3412,6 +3737,45 @@ struct SPIRVEmitContext SLANG_UNREACHABLE("Arithmetic op with 0 or more than 2 operands"); } + + SpvInst* emitArithmetic(SpvInstParent* parent, IRInst* inst) + { + if (const auto matrixType = as<IRMatrixType>(inst->getDataType())) + { + auto rowCount = getIntVal(matrixType->getRowCount()); + auto colCount = getIntVal(matrixType->getColumnCount()); + IRBuilder builder(inst); + builder.setInsertBefore(inst); + auto rowVectorType = builder.getVectorType(matrixType->getElementType(), colCount); + List<SpvInst*> rows; + for (IRIntegerValue i = 0; i < rowCount; i++) + { + List<IRInst*> operands; + for (UInt j = 0; j < inst->getOperandCount(); j++) + { + auto originalOperand = inst->getOperand(j); + if (as<IRMatrixType>(originalOperand->getDataType())) + { + auto operand = builder.emitElementExtract(originalOperand, i); + emitLocalInst(parent, operand); + operands.add(operand); + } + else + { + operands.add(originalOperand); + } + } + rows.add(emitVectorOrScalarArithmetic(parent, nullptr, rowVectorType, inst->getOp(), inst->getOperandCount(), operands.getArrayView())); + } + return emitCompositeConstruct(parent, inst, inst->getDataType(), rows); + } + + Array<IRInst*, 4> operands; + for (UInt i = 0; i < inst->getOperandCount(); i++) + operands.add(inst->getOperand(i)); + return emitVectorOrScalarArithmetic(parent, inst, inst->getDataType(), inst->getOp(), inst->getOperandCount(), operands.getView()); + } + SpvInst* emitDebugLine(SpvInstParent* parent, IRDebugLine* debugLine) { auto scope = findDebugScope(debugLine); |
