diff options
Diffstat (limited to 'source')
22 files changed, 866 insertions, 134 deletions
diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang index 73c20f2d0..a3900826e 100644 --- a/source/slang/hlsl.meta.slang +++ b/source/slang/hlsl.meta.slang @@ -1048,6 +1048,9 @@ __generic<T : __BuiltinType> __target_intrinsic(cpp, "bool($0)") __target_intrinsic(cuda, "bool($0)") __target_intrinsic(glsl, "bool($0)") +__target_intrinsic(spirv, boolean(T), "OpCopyObject resultType resultId _0") +__target_intrinsic(spirv, integral(T), "OpINotEqual resultType resultId _0 const(,0)") +__target_intrinsic(spirv, floating(T), "OpFUnordNotEqual resultType resultId _0 const(,0)") [__readNone] bool any(T x); @@ -1375,6 +1378,7 @@ matrix<uint,N,M> asuint(matrix<uint,N,M> x) __target_intrinsic(hlsl) __target_intrinsic(glsl, "uint16_t(packHalf2x16(vec2($0, 0.0)))") __target_intrinsic(cuda, "__half_as_ushort") +__target_intrinsic(spirv, "OpBitcast resultType resultId _0") [__readNone] uint16_t asuint16(float16_t value); @@ -1391,6 +1395,7 @@ matrix<uint16_t,R,C> asuint16<let R : int, let C : int>(matrix<float16_t,R,C> va __target_intrinsic(hlsl) __target_intrinsic(glsl, "float16_t(unpackHalf2x16($0).x)") __target_intrinsic(cuda, "__ushort_as_half") +__target_intrinsic(spirv, "OpBitcast resultType resultId _0") [__readNone] float16_t asfloat16(uint16_t value); @@ -1406,12 +1411,14 @@ matrix<float16_t,R,C> asfloat16<let R : int, let C : int>(matrix<uint16_t,R,C> v __target_intrinsic(hlsl) __target_intrinsic(cuda, "__half_as_short") +__target_intrinsic(spirv, "OpBitcast resultType resultId _0") [__unsafeForceInlineEarly][__readNone] int16_t asint16(float16_t value) { return asuint16(value); } __target_intrinsic(hlsl) [__unsafeForceInlineEarly][__readNone] vector<int16_t,N> asint16<let N : int>(vector<float16_t,N> value) { return asuint16(value); } __target_intrinsic(hlsl) [__unsafeForceInlineEarly][__readNone] matrix<int16_t,R,C> asint16<let R : int, let C : int>(matrix<float16_t,R,C> value) { return asuint16(value); } __target_intrinsic(hlsl) __target_intrinsic(cuda, "__short_as_half") +__target_intrinsic(spirv, "OpBitcast resultType resultId _0") [__readNone] [__unsafeForceInlineEarly] float16_t asfloat16(int16_t value) { return asfloat16(asuint16(value)); } @@ -2086,6 +2093,10 @@ __glsl_version(420) __target_intrinsic(hlsl) __cuda_sm_version(6.0) __target_intrinsic(cuda, "__half2float(__ushort_as_half($0))") +__target_intrinsic(spirv, R"( + %lowBits = OpUConvert _type(uint16_t) resultId _0; + %half = OpBitcast _type(half) resultId %lowBits; + OpFConvert resultType resultId %half)") [__readNone] float f16tof32(uint value); @@ -2105,6 +2116,10 @@ __glsl_version(420) __target_intrinsic(hlsl) __cuda_sm_version(6.0) __target_intrinsic(cuda, "__half_as_ushort(__float2half($0))") +__target_intrinsic(spirv, R"( + %half = OpFConvert _type(half) resultId _0; + %lowBits = OpBitcast _type(uint16_t) resultId %half; + OpUConvert resultType resultId %lowBits)") [__readNone] uint f32tof16(float value); @@ -2123,6 +2138,7 @@ vector<uint, N> f32tof16(vector<float, N> value) __target_intrinsic(glsl, "unpackHalf2x16($0).x") __target_intrinsic(cuda, "__half2float") +__target_intrinsic(spirv, "OpFConvert resultType resultId _0") __glsl_version(420) [__readNone] float f16tof32(float16_t value); @@ -2130,6 +2146,7 @@ float f16tof32(float16_t value); __generic<let N : int> __target_intrinsic(hlsl) __target_intrinsic(cuda, "__half2float") +__target_intrinsic(spirv, "OpFConvert resultType resultId _0") [__readNone] vector<float, N> f16tof32(vector<float16_t, N> value) { @@ -2140,11 +2157,13 @@ vector<float, N> f16tof32(vector<float16_t, N> value) __target_intrinsic(glsl, "packHalf2x16(vec2($0,0.0))") __glsl_version(420) __target_intrinsic(cuda, "__float2half") +__target_intrinsic(spirv, "OpFConvert resultType resultId _0") [__readNone] float16_t f32tof16_(float value); __generic<let N : int> __target_intrinsic(cuda, "__float2half") +__target_intrinsic(spirv, "OpFConvert resultType resultId _0") [__readNone] vector<float16_t, N> f32tof16_(vector<float, N> value) { diff --git a/source/slang/slang-emit-base.cpp b/source/slang/slang-emit-base.cpp index d565edb24..7ee3ea9ca 100644 --- a/source/slang/slang-emit-base.cpp +++ b/source/slang/slang-emit-base.cpp @@ -1,4 +1,5 @@ #include "slang-emit-base.h" +#include "slang-ir-util.h" namespace Slang { @@ -45,11 +46,7 @@ void SourceEmitterBase::handleRequiredCapabilities(IRInst* inst) IRVarLayout* SourceEmitterBase::getVarLayout(IRInst* var) { - auto decoration = var->findDecoration<IRLayoutDecoration>(); - if (!decoration) - return nullptr; - - return as<IRVarLayout>(decoration->getLayout()); + return findVarLayout(var); } BaseType SourceEmitterBase::extractBaseType(IRType* inType) diff --git a/source/slang/slang-emit-spirv-ops.h b/source/slang/slang-emit-spirv-ops.h index 60431b76c..57c074751 100644 --- a/source/slang/slang-emit-spirv-ops.h +++ b/source/slang/slang-emit-spirv-ops.h @@ -1012,14 +1012,14 @@ SpvInst* emitOpMemberDecorateUserSemantic( } // https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpVectorShuffle -template<typename T1, typename T2, typename T3, Index N> +template<typename T1, typename T2, typename T3> SpvInst* emitOpVectorShuffle( SpvInstParent* parent, IRInst* inst, const T1& idResultType, const T2& vector1, const T3& vector2, - const Array<SpvLiteralInteger, N>& components + ArrayView<SpvLiteralInteger> components ) { static_assert(isSingular<T1>); diff --git a/source/slang/slang-emit-spirv.cpp b/source/slang/slang-emit-spirv.cpp index 3febbd210..0878b3494 100644 --- a/source/slang/slang-emit-spirv.cpp +++ b/source/slang/slang-emit-spirv.cpp @@ -1168,16 +1168,22 @@ struct SPIRVEmitContext // > OpTypeInt - case kIROp_UInt8Type: case kIROp_UInt16Type: + case kIROp_Int16Type: + case kIROp_UInt8Type: case kIROp_UIntType: case kIROp_UInt64Type: case kIROp_Int8Type: - case kIROp_Int16Type: case kIROp_IntType: case kIROp_Int64Type: { const IntInfo i = getIntTypeInfo(as<IRType>(inst)); + if (i.width == 16) + requireSPIRVCapability(SpvCapabilityInt16); + else if (i.width == 64) + requireSPIRVCapability(SpvCapabilityInt64); + else if (i.width == 8) + requireSPIRVCapability(SpvCapabilityInt8); return emitOpTypeInt( inst, SpvLiteralInteger::from32(int32_t(i.width)), @@ -1395,6 +1401,8 @@ struct SPIRVEmitContext return emitGlobalParam(as<IRGlobalParam>(inst)); case kIROp_GlobalVar: return emitGlobalVar(as<IRGlobalVar>(inst)); + case kIROp_Var: + return emitVar(getSection(SpvLogicalSectionID::GlobalVariables), inst); // ... case kIROp_Specialize: @@ -1436,6 +1444,8 @@ struct SPIRVEmitContext } return result; } + case kIROp_GetStringHash: + return emitGetStringHash(inst); default: { String e = "Unhandled global inst in spirv-emit:\n" @@ -1468,6 +1478,8 @@ struct SPIRVEmitContext void emitVarLayout(SpvInst* varInst, IRVarLayout* layout) { + bool needDefaultSetBindingDecoration = false; + bool hasExplicitSetBinding = false; for (auto rr : layout->getOffsetAttrs()) { UInt index = rr->getOffset(); @@ -1498,15 +1510,6 @@ struct SPIRVEmitContext varInst, SpvLiteralInteger::from32(int32_t(index)) ); - if (space) - { - emitOpDecorateIndex( - getSection(SpvLogicalSectionID::Annotations), - nullptr, - varInst, - SpvLiteralInteger::from32(int32_t(space)) - ); - } break; case LayoutResourceKind::SpecializationConstant: @@ -1527,24 +1530,44 @@ struct SPIRVEmitContext getSection(SpvLogicalSectionID::Annotations), nullptr, varInst, - SpvLiteralInteger::from32(int32_t(index)) - ); + SpvLiteralInteger::from32(int32_t(index))); + if (space) + { + emitOpDecorateDescriptorSet( + getSection(SpvLogicalSectionID::Annotations), + nullptr, + varInst, + SpvLiteralInteger::from32(int32_t(space))); + } + else + { + needDefaultSetBindingDecoration = true; + } + break; + case LayoutResourceKind::RegisterSpace: emitOpDecorateDescriptorSet( getSection(SpvLogicalSectionID::Annotations), nullptr, varInst, - SpvLiteralInteger::from32(int32_t(space)) - ); + SpvLiteralInteger::from32(int32_t(index))); + hasExplicitSetBinding = true; break; default: break; } } + if (needDefaultSetBindingDecoration && !hasExplicitSetBinding) + { + emitOpDecorateDescriptorSet( + getSection(SpvLogicalSectionID::Annotations), + nullptr, + varInst, + SpvLiteralInteger::from32(int32_t(0))); + } } /// Emit a global parameter definition. SpvInst* emitGlobalParam(IRGlobalParam* param) { - auto layout = getVarLayout(param); auto storageClass = SpvStorageClassUniform; if (auto ptrType = as<IRPtrTypeBase>(param->getDataType())) { @@ -1562,7 +1585,8 @@ struct SPIRVEmitContext param->getDataType(), storageClass ); - emitVarLayout(varInst, layout); + if (auto layout = getVarLayout(param)) + emitVarLayout(varInst, layout); return varInst; } @@ -1615,6 +1639,8 @@ struct SPIRVEmitContext /// Emit a declaration for the given `irFunc` SpvInst* emitFuncDeclaration(IRFunc* irFunc) { + if (irFunc->findDecorationImpl(kIROp_SPIRVOpDecoration)) + return nullptr; // For now we aren't handling function declarations; // we expect to deal only with fully linked modules. // @@ -1852,6 +1878,10 @@ struct SPIRVEmitContext return emitLoad(parent, as<IRLoad>(inst)); case kIROp_Store: return emitStore(parent, as<IRStore>(inst)); + case kIROp_SwizzledStore: + return emitSwizzledStore(parent, as<IRSwizzledStore>(inst)); + case kIROp_swizzleSet: + return emitSwizzleSet(parent, as<IRSwizzleSet>(inst)); case kIROp_RWStructuredBufferGetElementPtr: return emitStructuredBufferGetElementPtr(parent, inst); case kIROp_StructuredBufferGetDimensions: @@ -1866,10 +1896,6 @@ struct SPIRVEmitContext return emitIntToFloatCast(parent, as<IRCastIntToFloat>(inst)); case kIROp_CastFloatToInt: return emitFloatToIntCast(parent, as<IRCastFloatToInt>(inst)); - case kIROp_MatrixReshape: - case kIROp_VectorReshape: - // TODO: break emitConstruct into separate functions for each opcode. - return emitConstruct(parent, inst); case kIROp_BitCast: return emitOpBitcast( parent, @@ -1991,6 +2017,28 @@ struct SPIRVEmitContext return emitInst(parent, inst, SpvOpSelect, inst->getFullType(), kResultID, OperandsOf(inst)); case kIROp_DebugLine: return emitDebugLine(parent, as<IRDebugLine>(inst)); + case kIROp_GetStringHash: + return emitGetStringHash(inst); + + } + } + + SpvInst* emitGetStringHash(IRInst* inst) + { + auto getStringHashInst = as<IRGetStringHash>(inst); + auto stringLit = getStringHashInst->getStringLit(); + + if (stringLit) + { + auto slice = stringLit->getStringSlice(); + return emitIntConstant(getStableHashCode32(slice.begin(), slice.getLength()).hash, inst->getDataType()); + } + else + { + // Couldn't handle + String e = "Unhandled local inst in spirv-emit:\n" + + dumpIRToString(inst, { IRDumpOptions::Mode::Detailed, 0 }); + SLANG_UNIMPLEMENTED_X(e.getBuffer()); } } @@ -2236,7 +2284,14 @@ struct SPIRVEmitContext for (auto field : structType->getFields()) { IRIntegerValue offset = 0; - getOffset(IRTypeLayoutRules::get(layoutRuleName), field, &offset); + if (auto offsetDecor = field->getKey()->findDecoration<IRPackOffsetDecoration>()) + { + offset = (getIntVal(offsetDecor->getRegisterOffset()) * 4 + getIntVal(offsetDecor->getComponentOffset())) * 4; + } + else + { + getOffset(IRTypeLayoutRules::get(layoutRuleName), field, &offset); + } emitOpMemberDecorateOffset( getSection(SpvLogicalSectionID::Annotations), nullptr, @@ -2368,14 +2423,168 @@ struct SPIRVEmitContext { String semanticName = systemValueAttr->getName(); semanticName = semanticName.toLower(); - if (semanticName == "sv_dispatchthreadid") + if (semanticName == "sv_position") + { + auto importDecor = inst->findDecoration<IRImportDecoration>(); + if (importDecor->getMangledName() == "gl_FragCoord") + return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInFragCoord); + else + return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInPosition); + } + else if (semanticName == "sv_target") + { + // Note: we do *not* need to generate some kind of `gl_` + // builtin for fragment-shader outputs: they are just + // ordinary `out` variables, with ordinary `location`s, + // as far as GLSL is concerned. + return nullptr; + } + else if (semanticName == "sv_clipdistance") + { + return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInClipDistance); + } + else if (semanticName == "sv_culldistance") + { + requireSPIRVCapability(SpvCapabilityCullDistance); + return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInCullDistance); + } + else if (semanticName == "sv_coverage") + { + return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInSampleMask); + } + else if (semanticName == "sv_innercoverage") + { + requireSPIRVCapability(SpvCapabilityFragmentFullyCoveredEXT); + return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInFullyCoveredEXT); + } + else if (semanticName == "sv_depth") + { + return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInFragDepth); + } + else if (semanticName == "sv_depthgreaterequal") + { + return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInFragDepth); + } + else if (semanticName == "sv_depthlessequal") + { + return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInFragDepth); + } + else if (semanticName == "sv_dispatchthreadid") { return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInGlobalInvocationId); } + else if (semanticName == "sv_domainlocation") + { + return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInTessCoord); + } + else if (semanticName == "sv_groupid") + { + return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInWorkgroupId); + } else if (semanticName == "sv_groupindex") { return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInLocalInvocationIndex); } + else if (semanticName == "sv_groupthreadid") + { + return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInLocalInvocationId); + } + else if (semanticName == "sv_gsinstanceid") + { + return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInInvocationId); + } + else if (semanticName == "sv_instanceid") + { + return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInInstanceIndex); + } + else if (semanticName == "sv_isfrontface") + { + return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInFrontFacing); + } + else if (semanticName == "sv_outputcontrolpointid") + { + return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInInvocationId); + } + else if (semanticName == "sv_pointsize") + { + // float in hlsl & glsl + return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInPointSize); + } + else if (semanticName == "sv_primitiveid") + { + return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInPrimitiveId); + } + else if (semanticName == "sv_rendertargetarrayindex") + { + requireSPIRVCapability(SpvCapabilityShaderLayer); + return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInLayer); + } + else if (semanticName == "sv_sampleindex") + { + requireSPIRVCapability(SpvCapabilitySampleRateShading); + return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInSampleId); + } + else if (semanticName == "sv_stencilref") + { + requireSPIRVCapability(SpvCapabilityStencilExportEXT); + return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInFragStencilRefEXT); + } + else if (semanticName == "sv_tessfactor") + { + requireSPIRVCapability(SpvCapabilityTessellation); + return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInTessLevelOuter); + } + else if (semanticName == "sv_vertexid") + { + return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInVertexId); + } + else if (semanticName == "sv_viewid") + { + requireSPIRVCapability(SpvCapabilityMultiView); + return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInViewIndex); + } + else if (semanticName == "sv_viewportarrayindex") + { + requireSPIRVCapability(SpvCapabilityShaderViewportIndex); + return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInViewportIndex); + } + else if (semanticName == "nv_x_right") + { + SLANG_UNIMPLEMENTED_X("spirv emit for nv_x_right"); + } + else if (semanticName == "nv_viewport_mask") + { + requireSPIRVCapability(SpvCapabilityPerViewAttributesNV); + return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInViewportMaskPerViewNV); + } + else if (semanticName == "sv_barycentrics") + { + if (m_targetRequest->getTargetCaps().implies(CapabilityAtom::GL_NV_fragment_shader_barycentric)) + { + requireSPIRVCapability(SpvCapabilityFragmentBarycentricNV); + return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInBaryCoordNV); + } + else + { + requireSPIRVCapability(SpvCapabilityFragmentBarycentricKHR); + return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInBaryCoordKHR); + } + + // TODO: There is also the `gl_BaryCoordNoPerspNV` builtin, which + // we ought to use if the `noperspective` modifier has been + // applied to this varying input. + } + else if (semanticName == "sv_cullprimitive") + { + requireSPIRVCapability(SpvCapabilityMeshShadingEXT); + return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInCullPrimitiveEXT); + } + else if (semanticName == "sv_shadingrate") + { + requireSPIRVCapability(SpvCapabilityFragmentShadingRateKHR); + return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInPrimitiveShadingRateKHR); + } + SLANG_UNREACHABLE("Unimplemented system value in spirv emit."); } } return nullptr; @@ -2571,6 +2780,14 @@ struct SPIRVEmitContext { return emitIntrinsicCallExpr(parent, static_cast<IRCall*>(inst), targetIntrinsic); } + else if (auto spvOpDecor = funcValue->findDecorationImpl(kIROp_SPIRVOpDecoration)) + { + SpvOp op = (SpvOp)getIntVal(spvOpDecor->getOperand(0)); + List<IRInst*> args; + for (UInt i = 0; i < inst->getArgCount(); i++) + args.add(inst->getArg(i)); + return emitInst(parent, inst, op, inst->getFullType(), kResultID, args); + } else { return emitOpFunctionCall( @@ -2658,6 +2875,9 @@ struct SPIRVEmitContext case SpvSnippet::ASMType::Int: result = emitIntConstant((IRIntegerValue)constant.intValues[0], builder.getIntType()); break; + case SpvSnippet::ASMType::UInt16: + result = emitIntConstant((IRIntegerValue)constant.intValues[0], builder.getType(kIROp_UInt16Type)); + break; case SpvSnippet::ASMType::UInt2: { auto uintType = builder.getType(kIROp_UIntType); @@ -2686,12 +2906,18 @@ struct SPIRVEmitContext case SpvSnippet::ASMType::Float: irType = builder.getType(kIROp_FloatType); break; + case SpvSnippet::ASMType::Half: + irType = builder.getType(kIROp_HalfType); + break; case SpvSnippet::ASMType::Int: irType = builder.getIntType(); break; case SpvSnippet::ASMType::UInt: irType = builder.getUIntType(); break; + case SpvSnippet::ASMType::UInt16: + irType = builder.getType(kIROp_UInt16Type); + break; case SpvSnippet::ASMType::Float2: irType = builder.getVectorType( builder.getType(kIROp_FloatType), builder.getIntValue(builder.getIntType(), 2)); @@ -2969,6 +3195,46 @@ struct SPIRVEmitContext return emitOpStore(parent, inst, inst->getPtr(), inst->getVal()); } + SpvInst* emitSwizzledStore(SpvInstParent* parent, IRSwizzledStore* inst) + { + auto sourceVectorType = as<IRVectorType>(inst->getSource()->getDataType()); + SLANG_ASSERT(sourceVectorType); + auto sourceElementType = sourceVectorType->getElementType(); + SLANG_ASSERT(getIntVal(sourceVectorType->getElementCount()) == (IRIntegerValue)inst->getElementCount()); + SpvInst* result = nullptr; + IRBuilder builder(inst); + builder.setInsertBefore(inst); + auto destPtrType = as<IRPtrTypeBase>(inst->getDest()->getDataType()); + SpvStorageClass addrSpace = SpvStorageClassFunction; + if (destPtrType->hasAddressSpace()) + addrSpace = (SpvStorageClass)destPtrType->getAddressSpace(); + auto ptrElementType = builder.getPtrType(kIROp_PtrType, sourceElementType, addrSpace); + for (UInt i = 0; i < inst->getElementCount(); i++) + { + auto index = inst->getElementIndex(i); + auto addr = emitOpAccessChain(parent, nullptr, ptrElementType, inst->getDest(), makeArray(index)); + auto val = emitOpCompositeExtract(parent, nullptr, sourceElementType, inst->getSource(), makeArray(SpvLiteralInteger::from32((int32_t)i))); + result = emitOpStore(parent, (i == inst->getElementCount() - 1 ? inst : nullptr), addr, val); + } + return result; + } + + SpvInst* emitSwizzleSet(SpvInstParent* parent, IRSwizzleSet* inst) + { + auto resultVectorType = as<IRVectorType>(inst->getDataType()); + List<SpvLiteralInteger> shuffleIndices; + shuffleIndices.setCount((Index)getIntVal(resultVectorType->getElementCount())); + for (Index i = 0; i < shuffleIndices.getCount(); i++) + shuffleIndices[i] = SpvLiteralInteger::from32((int32_t)i); + for (UInt i = 0; i < inst->getElementCount(); i++) + { + auto destIndex = (int32_t)getIntVal(inst->getElementIndex(i)); + SLANG_ASSERT(destIndex < shuffleIndices.getCount()); + shuffleIndices[destIndex] = SpvLiteralInteger::from32((int32_t)(i + shuffleIndices.getCount())); + } + return emitOpVectorShuffle(parent, inst, inst->getFullType(), inst->getBase(), inst->getSource(), shuffleIndices.getArrayView()); + } + SpvInst* emitStructuredBufferGetElementPtr(SpvInstParent* parent, IRInst* inst) { //"%addr = OpAccessChain resultType*StorageBuffer resultId _0 const(int, 0) _1;" @@ -3035,6 +3301,40 @@ struct SPIRVEmitContext SLANG_ASSERT(!as<IRVectorType>(fromTypeV) == !as<IRVectorType>(toTypeV)); const auto fromType = dropVector(fromTypeV); const auto toType = dropVector(toTypeV); + + if (as<IRBoolType>(fromType)) + { + // Cast from bool to int. + IRBuilder builder(inst); + builder.setInsertBefore(inst); + auto zero = builder.getIntValue(toType, 0); + auto one = builder.getIntValue(toType, 1); + if (auto vecType = as<IRVectorType>(toTypeV)) + { + auto zeroV = emitSplat(parent, nullptr, zero, getIntVal(vecType->getElementCount())); + auto oneV = emitSplat(parent, nullptr, one, getIntVal(vecType->getElementCount())); + return emitInst(parent, inst, SpvOpSelect, inst->getFullType(), kResultID, inst->getOperand(0), + oneV, zeroV); + } + return emitInst(parent, inst, SpvOpSelect, inst->getFullType(), kResultID, inst->getOperand(0), one, zero); + } + else if (as<IRBoolType>(toType)) + { + // Cast from int to bool. + IRBuilder builder(inst); + builder.setInsertBefore(inst); + auto zero = builder.getIntValue(fromType, 0); + if (auto vecType = as<IRVectorType>(toTypeV)) + { + auto zeroV = emitSplat(parent, nullptr, zero, getIntVal(vecType->getElementCount())); + return emitOpINotEqual(parent, inst, inst->getFullType(), inst->getOperand(0), zeroV); + } + else + { + return emitOpINotEqual(parent, inst, inst->getFullType(), inst->getOperand(0), zero); + } + } + SLANG_ASSERT(isIntegralType(fromType)); SLANG_ASSERT(isIntegralType(toType)); @@ -3100,6 +3400,24 @@ struct SPIRVEmitContext const auto fromType = dropVector(fromTypeV); const auto toType = dropVector(toTypeV); SLANG_ASSERT(isFloatingType(fromType)); + + if (as<IRBoolType>(toType)) + { + // Float to bool cast. + IRBuilder builder(inst); + builder.setInsertBefore(inst); + auto zero = builder.getIntValue(fromType, 0); + if (auto vecType = as<IRVectorType>(toTypeV)) + { + auto zeroV = emitSplat(parent, nullptr, zero, getIntVal(vecType->getElementCount())); + return emitInst(parent, inst, SpvOpFUnordNotEqual, inst->getFullType(), kResultID, inst->getOperand(0), zeroV); + } + else + { + return emitInst(parent, inst, SpvOpFUnordNotEqual, inst->getFullType(), kResultID, inst->getOperand(0), zero); + } + } + SLANG_ASSERT(isIntegralType(toType)); const auto toInfo = getIntTypeInfo(toType); @@ -3279,14 +3597,9 @@ struct SPIRVEmitContext } } - SpvInst* emitArithmetic(SpvInstParent* parent, IRInst* inst) + SpvInst* emitVectorOrScalarArithmetic(SpvInstParent* parent, IRInst* instToRegister, IRInst* type, IROp op, UInt operandCount, ArrayView<IRInst*> operands) { - IRType* elementType = dropVector(inst->getOperand(0)->getDataType()); - if (const auto matrixType = as<IRMatrixType>(inst->getDataType())) - { - //TODO: implement. - SLANG_ASSERT(!"unimplemented: matrix arithemetic"); - } + IRType* elementType = dropVector(operands[0]->getDataType()); IRBasicType* basicType = as<IRBasicType>(elementType); bool isFloatingPoint = false; bool isBool = false; @@ -3305,7 +3618,7 @@ struct SPIRVEmitContext } SpvOp opCode = SpvOpUndef; bool isSigned = isSignedType(basicType); - switch (inst->getOp()) + switch (op) { case kIROp_Add: opCode = isFloatingPoint ? SpvOpFAdd : SpvOpIAdd; @@ -3323,30 +3636,30 @@ struct SPIRVEmitContext opCode = isSigned ? SpvOpSRem : SpvOpUMod; break; case kIROp_FRem: - opCode = SpvOpFRem; + opCode = SpvOpFMod; break; case kIROp_Less: opCode = isFloatingPoint ? SpvOpFOrdLessThan - : isSigned ? SpvOpSLessThan : SpvOpULessThan; + : isSigned ? SpvOpSLessThan : SpvOpULessThan; break; case kIROp_Leq: opCode = isFloatingPoint ? SpvOpFOrdLessThanEqual - : isSigned ? SpvOpSLessThanEqual : SpvOpULessThanEqual; + : isSigned ? SpvOpSLessThanEqual : SpvOpULessThanEqual; break; case kIROp_Eql: opCode = isFloatingPoint ? SpvOpFOrdEqual : isBool ? SpvOpLogicalEqual : SpvOpIEqual; break; case kIROp_Neq: opCode = isFloatingPoint ? SpvOpFOrdNotEqual - : isBool ? SpvOpLogicalNotEqual : SpvOpINotEqual; + : isBool ? SpvOpLogicalNotEqual : SpvOpINotEqual; break; case kIROp_Geq: opCode = isFloatingPoint ? SpvOpFOrdGreaterThanEqual - : isSigned ? SpvOpSGreaterThanEqual : SpvOpUGreaterThanEqual; + : isSigned ? SpvOpSGreaterThanEqual : SpvOpUGreaterThanEqual; break; case kIROp_Greater: opCode = isFloatingPoint ? SpvOpFOrdGreaterThan - : isSigned ? SpvOpSGreaterThan : SpvOpUGreaterThan; + : isSigned ? SpvOpSGreaterThan : SpvOpUGreaterThan; break; case kIROp_Neg: opCode = isFloatingPoint ? SpvOpFNegate : SpvOpSNegate; @@ -3361,16 +3674,28 @@ struct SPIRVEmitContext opCode = SpvOpLogicalNot; break; case kIROp_BitAnd: - opCode = SpvOpBitwiseAnd; + if (isBool) + opCode = SpvOpLogicalAnd; + else + opCode = SpvOpBitwiseAnd; break; case kIROp_BitOr: - opCode = SpvOpBitwiseOr; + if (isBool) + opCode = SpvOpLogicalOr; + else + opCode = SpvOpBitwiseOr; break; case kIROp_BitXor: - opCode = SpvOpBitwiseXor; + if (isBool) + opCode = SpvOpLogicalNotEqual; + else + opCode = SpvOpBitwiseXor; break; case kIROp_BitNot: - opCode = SpvOpBitReverse; + if (isBool) + opCode = SpvOpLogicalNot; + else + opCode = SpvOpBitReverse; break; case kIROp_Rsh: opCode = isSigned ? SpvOpShiftRightArithmetic : SpvOpShiftRightLogical; @@ -3382,20 +3707,20 @@ struct SPIRVEmitContext SLANG_ASSERT(!"unknown arithmetic opcode"); break; } - if(inst->getOperandCount() == 1) + if (operandCount == 1) { - return emitInst(parent, inst, opCode, inst->getDataType(), kResultID, OperandsOf(inst)); + return emitInst(parent, instToRegister, opCode, type, kResultID, operands); } - else if(inst->getOperandCount() == 2) + else if (operandCount == 2) { - auto l = inst->getOperand(0); + auto l = operands[0]; const auto lVec = as<IRVectorType>(l->getDataType()); - auto r = inst->getOperand(1); + auto r = operands[1]; const auto rVec = as<IRVectorType>(r->getDataType()); - const auto go = [&](const auto l, const auto r){ - return emitInst(parent, inst, opCode, inst->getDataType(), kResultID, l, r); + const auto go = [&](const auto l, const auto r) { + return emitInst(parent, instToRegister, opCode, type, kResultID, l, r); }; - if(lVec && !rVec) + if (lVec && !rVec) { const auto len = as<IRIntLit>(lVec->getElementCount()); SLANG_ASSERT(len); @@ -3412,6 +3737,45 @@ struct SPIRVEmitContext SLANG_UNREACHABLE("Arithmetic op with 0 or more than 2 operands"); } + + SpvInst* emitArithmetic(SpvInstParent* parent, IRInst* inst) + { + if (const auto matrixType = as<IRMatrixType>(inst->getDataType())) + { + auto rowCount = getIntVal(matrixType->getRowCount()); + auto colCount = getIntVal(matrixType->getColumnCount()); + IRBuilder builder(inst); + builder.setInsertBefore(inst); + auto rowVectorType = builder.getVectorType(matrixType->getElementType(), colCount); + List<SpvInst*> rows; + for (IRIntegerValue i = 0; i < rowCount; i++) + { + List<IRInst*> operands; + for (UInt j = 0; j < inst->getOperandCount(); j++) + { + auto originalOperand = inst->getOperand(j); + if (as<IRMatrixType>(originalOperand->getDataType())) + { + auto operand = builder.emitElementExtract(originalOperand, i); + emitLocalInst(parent, operand); + operands.add(operand); + } + else + { + operands.add(originalOperand); + } + } + rows.add(emitVectorOrScalarArithmetic(parent, nullptr, rowVectorType, inst->getOp(), inst->getOperandCount(), operands.getArrayView())); + } + return emitCompositeConstruct(parent, inst, inst->getDataType(), rows); + } + + Array<IRInst*, 4> operands; + for (UInt i = 0; i < inst->getOperandCount(); i++) + operands.add(inst->getOperand(i)); + return emitVectorOrScalarArithmetic(parent, inst, inst->getDataType(), inst->getOp(), inst->getOperandCount(), operands.getView()); + } + SpvInst* emitDebugLine(SpvInstParent* parent, IRDebugLine* debugLine) { auto scope = findDebugScope(debugLine); diff --git a/source/slang/slang-ir-any-value-marshalling.cpp b/source/slang/slang-ir-any-value-marshalling.cpp index ed7818dbf..200f8bc55 100644 --- a/source/slang/slang-ir-any-value-marshalling.cpp +++ b/source/slang/slang-ir-any-value-marshalling.cpp @@ -241,7 +241,6 @@ namespace Slang { case kIROp_IntType: case kIROp_FloatType: - case kIROp_BoolType: #if SLANG_PTR_IS_32 case kIROp_IntPtrType: #endif @@ -260,6 +259,23 @@ namespace Slang advanceOffset(4); break; } + case kIROp_BoolType: + { + ensureOffsetAt4ByteBoundary(); + if (fieldOffset < static_cast<uint32_t>(anyValInfo->fieldKeys.getCount())) + { + auto srcVal = builder->emitLoad(concreteVar); + IRInst* args[] = {srcVal, builder->getIntValue(builder->getUIntType(), 1), builder->getIntValue(builder->getUIntType(), 0) }; + auto dstVal = builder->emitIntrinsicInst(builder->getUIntType(), kIROp_Select, 3, args); + auto dstAddr = builder->emitFieldAddress( + uintPtrType, + anyValueVar, + anyValInfo->fieldKeys[fieldOffset]); + builder->emitStore(dstAddr, dstVal); + } + advanceOffset(4); + break; + } case kIROp_UIntType: #if SLANG_PTR_IS_32 case kIROp_UIntPtrType: @@ -416,7 +432,6 @@ namespace Slang { case kIROp_IntType: case kIROp_FloatType: - case kIROp_BoolType: { ensureOffsetAt4ByteBoundary(); if (fieldOffset < static_cast<uint32_t>(anyValInfo->fieldKeys.getCount())) @@ -432,6 +447,22 @@ namespace Slang advanceOffset(4); break; } + case kIROp_BoolType: + { + ensureOffsetAt4ByteBoundary(); + if (fieldOffset < static_cast<uint32_t>(anyValInfo->fieldKeys.getCount())) + { + auto srcAddr = builder->emitFieldAddress( + uintPtrType, + anyValueVar, + anyValInfo->fieldKeys[fieldOffset]); + auto srcVal = builder->emitLoad(srcAddr); + srcVal = builder->emitNeq(srcVal, builder->getIntValue(builder->getUIntType(), 0)); + builder->emitStore(concreteVar, srcVal); + } + advanceOffset(4); + break; + } case kIROp_UIntType: { ensureOffsetAt4ByteBoundary(); diff --git a/source/slang/slang-ir-bind-existentials.cpp b/source/slang/slang-ir-bind-existentials.cpp index 35c07452a..f4fbf6e20 100644 --- a/source/slang/slang-ir-bind-existentials.cpp +++ b/source/slang/slang-ir-bind-existentials.cpp @@ -351,8 +351,7 @@ struct BindExistentialSlots inst, slotOperandCount, slotOperands.getBuffer()); - - use->set(newVal); + builder.replaceOperand(use, newVal); } } }; diff --git a/source/slang/slang-ir-glsl-legalize.cpp b/source/slang/slang-ir-glsl-legalize.cpp index 5c4e1d037..4c5b58882 100644 --- a/source/slang/slang-ir-glsl-legalize.cpp +++ b/source/slang/slang-ir-glsl-legalize.cpp @@ -1488,8 +1488,24 @@ ScalarizedVal adaptType( IRBuilder* builder, IRInst* val, IRType* toType, - IRType* /*fromType*/) + IRType* fromType) { + if (auto fromVector = as<IRVectorType>(fromType)) + { + if (auto toVector = as<IRVectorType>(toType)) + { + if (fromVector->getElementCount() != toVector->getElementCount()) + { + fromType = builder->getVectorType(fromVector->getElementType(), toVector->getElementCount()); + val = builder->emitVectorReshape(fromType, val); + } + } + else if (auto toBasicType = as<IRBasicType>(toType)) + { + UInt index = 0; + val = builder->emitSwizzle(fromVector->getElementType(), val, 1, &index); + } + } // TODO: actually consider what needs to go on here... return ScalarizedVal::value(builder->emitCast( toType, diff --git a/source/slang/slang-ir-inline.cpp b/source/slang/slang-ir-inline.cpp index e878c56f7..df245f555 100644 --- a/source/slang/slang-ir-inline.cpp +++ b/source/slang/slang-ir-inline.cpp @@ -84,6 +84,10 @@ struct InliningPassBase { changed = considerCallSiteInFunc(func); } + else if (auto call = as<IRCall>(inst)) + { + considerCallSite(call); + } // Recursively consider the children of inst. for (auto child : inst->getModifiableChildren()) diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index 22d8c5538..0b2cc790d 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -99,7 +99,9 @@ struct IRTargetSpecificDecoration : IRDecoration IRType* getTypeScrutinee() { SLANG_ASSERT(getOperandCount() == 4); - const auto t = as<IRType>(getOperand(3)); + // Note: cannot use as<IRType> here because the operand can be + // an `IRParam` representing a generic type. + const auto t = (IRType*)(getOperand(3)); SLANG_ASSERT(t); return t; } diff --git a/source/slang/slang-ir-layout.cpp b/source/slang/slang-ir-layout.cpp index 244ff2039..cba6894a9 100644 --- a/source/slang/slang-ir-layout.cpp +++ b/source/slang/slang-ir-layout.cpp @@ -404,9 +404,10 @@ struct Std430LayoutRules : IRTypeLayoutRules } virtual IRSizeAndAlignment getVectorSizeAndAlignment(IRSizeAndAlignment element, IRIntegerValue count) { + IRIntegerValue countForAlignment = count; if (count == 3) - count = 4; - return IRSizeAndAlignment((int)(element.size * count), (int)(element.size * count)); + countForAlignment = 4; + return IRSizeAndAlignment((int)(element.size * count), (int)(element.size * countForAlignment)); } }; diff --git a/source/slang/slang-ir-legalize-types.cpp b/source/slang/slang-ir-legalize-types.cpp index c2d7ba95f..cb1cf3db3 100644 --- a/source/slang/slang-ir-legalize-types.cpp +++ b/source/slang/slang-ir-legalize-types.cpp @@ -1954,13 +1954,6 @@ static LegalVal legalizeInst( } } -IRVarLayout* findVarLayout(IRInst* value) -{ - if (auto layoutDecoration = value->findDecoration<IRLayoutDecoration>()) - return as<IRVarLayout>(layoutDecoration->getLayout()); - return nullptr; -} - static UnownedStringSlice findNameHint(IRInst* inst) { if( auto nameHintDecoration = inst->findDecoration<IRNameHintDecoration>() ) diff --git a/source/slang/slang-ir-lower-buffer-element-type.cpp b/source/slang/slang-ir-lower-buffer-element-type.cpp index 909ffea83..bf87d72fe 100644 --- a/source/slang/slang-ir-lower-buffer-element-type.cpp +++ b/source/slang/slang-ir-lower-buffer-element-type.cpp @@ -224,10 +224,18 @@ namespace Slang if (auto matrixType = as<IRMatrixType>(type)) { - if (getIntVal(matrixType->getLayout()) == defaultMatrixLayout) + // For spirv, we always want to lower all matrix types, because matrix types + // are considered abstract types. + if (!target->shouldEmitSPIRVDirectly()) { - info.loweredType = type; - return info; + // For other targets, we only lower the matrix types if they differ from the default + // matrix layout. + if (getIntVal(matrixType->getLayout()) == defaultMatrixLayout && + rules->ruleName == IRTypeLayoutRuleName::Natural) + { + info.loweredType = type; + return info; + } } auto loweredType = builder.createStructType(); @@ -264,7 +272,7 @@ namespace Slang else if (auto arrayType = as<IRArrayType>(type)) { auto loweredInnerTypeInfo = getLoweredTypeInfo(arrayType->getElementType(), rules); - if (!loweredInnerTypeInfo.convertLoweredToOriginal && rules->ruleName == IRTypeLayoutRuleName::Natural) + if (!loweredInnerTypeInfo.convertLoweredToOriginal) { info.loweredType = type; return info; @@ -378,6 +386,45 @@ namespace Slang return info; } + switch (target->getTarget()) + { + case CodeGenTarget::SPIRV: + case CodeGenTarget::SPIRVAssembly: + if (as<IRBoolType>(type)) + { + // Bool is an abstract type in SPIRV, so we need to lower them into an int. + info.loweredType = builder.getIntType(); + // Create unpack func. + { + builder.setInsertAfter(type); + info.convertLoweredToOriginal = builder.createFunc(); + builder.setInsertInto(info.convertLoweredToOriginal); + builder.addNameHintDecoration(info.convertLoweredToOriginal, UnownedStringSlice("unpackStorage")); + info.convertLoweredToOriginal->setFullType(builder.getFuncType(1, (IRType**)&info.loweredType, type)); + builder.emitBlock(); + auto loweredParam = builder.emitParam(info.loweredType); + auto result = builder.emitCast(type, loweredParam); + builder.emitReturn(result); + } + + // Create pack func. + { + builder.setInsertAfter(info.convertLoweredToOriginal); + info.convertOriginalToLowered = builder.createFunc(); + builder.setInsertInto(info.convertOriginalToLowered); + builder.addNameHintDecoration(info.convertOriginalToLowered, UnownedStringSlice("packStorage")); + info.convertOriginalToLowered->setFullType(builder.getFuncType(1, (IRType**)&type, info.loweredType)); + builder.emitBlock(); + auto param = builder.emitParam(type); + auto result = builder.emitCast(info.loweredType, param); + builder.emitReturn(result); + } + } + break; + default: + break; + } + info.loweredType = type; return info; } @@ -506,7 +553,7 @@ namespace Slang builder.setInsertBefore(user); auto newLoad = cloneInst(&cloneEnv, &builder, user); newLoad->setFullType(loweredElementTypeInfo.loweredType); - auto unpackedVal = builder.emitCallInst(elementType, loweredElementTypeInfo.convertLoweredToOriginal, 1, &newLoad); + auto unpackedVal = builder.emitCallInst((IRType*)originalElementType, loweredElementTypeInfo.convertLoweredToOriginal, 1, &newLoad); user->replaceUsesWith(unpackedVal); user->removeAndDeallocate(); break; diff --git a/source/slang/slang-ir-peephole.cpp b/source/slang/slang-ir-peephole.cpp index 600361b2f..d83c6dccd 100644 --- a/source/slang/slang-ir-peephole.cpp +++ b/source/slang/slang-ir-peephole.cpp @@ -683,6 +683,83 @@ struct PeepholeContext : InstPassBase } } break; + case kIROp_VectorReshape: + { + auto fromType = as<IRVectorType>(inst->getOperand(0)->getDataType()); + auto resultType = as<IRVectorType>(inst->getDataType()); + if (!resultType) + { + if (!fromType) + { + inst->replaceUsesWith(inst->getOperand(0)); + maybeRemoveOldInst(inst); + changed = true; + break; + } + IRBuilder builder(inst); + builder.setInsertBefore(inst); + UInt index = 0; + auto newInst = builder.emitSwizzle(resultType, inst->getOperand(0), 1, &index); + inst->replaceUsesWith(newInst); + maybeRemoveOldInst(inst); + changed = true; + break; + } + auto fromCount = as<IRIntLit>(fromType->getElementCount()); + if (!fromCount) + break; + auto toCount = as<IRIntLit>(resultType->getElementCount()); + if (!toCount) + break; + IRBuilder builder(inst); + builder.setInsertBefore(inst); + auto newInst = builder.emitVectorReshape(resultType, inst->getOperand(0)); + if (newInst != inst) + { + inst->replaceUsesWith(newInst); + maybeRemoveOldInst(inst); + changed = true; + } + } + break; + case kIROp_MatrixReshape: + { + auto fromType = as<IRMatrixType>(inst->getOperand(0)->getDataType()); + auto resultType = as<IRMatrixType>(inst->getDataType()); + SLANG_ASSERT(fromType && resultType); + auto fromRows = as<IRIntLit>(fromType->getRowCount()); + if (!fromRows) break; + auto fromCols = as<IRIntLit>(fromType->getColumnCount()); + if (!fromCols) break; + auto toRows = as<IRIntLit>(resultType->getRowCount()); + if (!toRows) break; + auto toCols = as<IRIntLit>(resultType->getColumnCount()); + if (!toCols) break; + List<IRInst*> rows; + IRBuilder builder(inst); + builder.setInsertBefore(inst); + auto toRowType = builder.getVectorType(resultType->getElementType(), resultType->getColumnCount()); + for (IRIntegerValue i = 0; i < toRows->getValue(); i++) + { + if (i < fromRows->getValue()) + { + auto originalRow = builder.emitElementExtract(inst->getOperand(0), i); + auto resizedRow = builder.emitVectorReshape(toRowType, originalRow); + rows.add(resizedRow); + } + else + { + auto zero = builder.emitDefaultConstruct(resultType->getElementType()); + auto row = builder.emitMakeVectorFromScalar(toRowType, zero); + rows.add(row); + } + } + auto newInst = builder.emitMakeMatrix(resultType, (UInt)rows.getCount(), rows.getBuffer()); + inst->replaceUsesWith(newInst); + maybeRemoveOldInst(inst); + changed = true; + } + break; case kIROp_Add: case kIROp_Mul: case kIROp_Sub: diff --git a/source/slang/slang-ir-simplify-cfg.cpp b/source/slang/slang-ir-simplify-cfg.cpp index 63c40b16f..e848d11c1 100644 --- a/source/slang/slang-ir-simplify-cfg.cpp +++ b/source/slang/slang-ir-simplify-cfg.cpp @@ -473,17 +473,74 @@ static bool trySimplifyIfElse(IRBuilder& builder, IRIfElse* ifElseInst) static bool trySimplifySwitch(IRBuilder& builder, IRSwitch* switchInst) { + // First, we fuse switch case blocks that is a trivial branch. + // If we see: + // ``` + // someBlock: + // switch(..., case_block_A, ...) + // case_block_A: + // branch blockB; + // ``` + // Then we fold blockB into the switch case operand: + // ``` + // someBlock: + // switch(..., blockB, ...) + // ``` + // We can do this if `blockB` is not a merge block. + // + bool changed = false; + auto fuseSwitchCaseBlock = [&](IRUse* targetUse) + { + for (;;) + { + auto block = as<IRBlock>(targetUse->get()); + if (block->getFirstInst()->getOp() != kIROp_unconditionalBranch) + return; + auto branch = as<IRUnconditionalBranch>(block->getFirstInst()); + // We can't fuse the block if there are phi arguments. + if (branch->getArgCount() != 0) + return; + auto target = branch->getTargetBlock(); + if (target == switchInst->getBreakLabel()) + return; + // target must not be used as a merge block of other control flow constructs. + for (auto use = target->firstUse; use; use = use->nextUse) + { + if (use->getUser() == switchInst || use->getUser() == branch) + continue; + switch (use->getUser()->getOp()) + { + case kIROp_loop: + case kIROp_ifElse: + case kIROp_Switch: + // If the target block is used by a special control flow inst, + // it is likely a merge block and we can't fuse it. + return; + default: + break; + } + } + targetUse->set(target); + changed = true; + } + }; + + fuseSwitchCaseBlock(&switchInst->defaultLabel); + for (UInt i = 0; i < switchInst->getCaseCount(); i++) + fuseSwitchCaseBlock(switchInst->getCaseLabelUse(i)); + + // Next, we check if all switch cases are jumping to the same target. if (!isTrivialSwitch(switchInst)) - return false; + return changed; if (switchInst->getCaseCount() == 0) - return false; + return changed; auto termInst = as<IRUnconditionalBranch>(switchInst->getCaseLabel(0)->getTerminator()); if (!termInst) - return false; + return changed; if (!arePhiArgsEquivalentInBranches(switchInst)) - return false; + return changed; List<IRInst*> args; for (UInt i = 0; i < termInst->getArgCount(); i++) @@ -663,7 +720,7 @@ static bool removeTrivialPhiParams(IRBlock* block) } if (targetVal) { - params[i]->replaceUsesWith(args[i].knownValue); + params[i]->replaceUsesWith(targetVal); params[i]->removeAndDeallocate(); for (auto termInst : termInsts) termInst->removeArgument((UInt)i); @@ -709,6 +766,7 @@ static bool processFunc(IRGlobalValueWithCode* func) { loop->continueBlock.set(loop->getTargetBlock()); continueBlock->removeAndDeallocate(); + changed = true; } // If there isn't any actual back jumps into loop target and there is a trivial diff --git a/source/slang/slang-ir-spirv-legalize.cpp b/source/slang/slang-ir-spirv-legalize.cpp index 36fdbd56a..fd23d1b5e 100644 --- a/source/slang/slang-ir-spirv-legalize.cpp +++ b/source/slang/slang-ir-spirv-legalize.cpp @@ -97,23 +97,15 @@ struct SPIRVLegalizationContext : public SourceEmitterBase // Figure out storage class based on var layout. if (auto layout = getVarLayout(inst)) { - if (auto systemValueAttr = layout->findAttr<IRSystemValueSemanticAttr>()) + auto cls = getGlobalParamStorageClass(layout); + if (cls != SpvStorageClassMax) + storageClass = cls; + else if (auto systemValueAttr = layout->findAttr<IRSystemValueSemanticAttr>()) { String semanticName = systemValueAttr->getName(); semanticName = semanticName.toLower(); - if (semanticName == "sv_dispatchthreadid") - { - storageClass = SpvStorageClassInput; - } - else if (semanticName == "sv_groupindex") - { + if (semanticName == "sv_pointsize") storageClass = SpvStorageClassInput; - } - } - else if(const auto parameterGroupTypeLayout = - as<IRParameterGroupTypeLayout>(layout->getTypeLayout())) - { - storageClass = SpvStorageClassUniform; } } @@ -121,10 +113,13 @@ struct SPIRVLegalizationContext : public SourceEmitterBase IRBuilder builder(m_sharedContext->m_irModule); bool needLoad = true; auto innerType = inst->getFullType(); - if (as<IRConstantBufferType>(innerType) || as<IRParameterBlockType>(innerType)) + auto cbufferType = as<IRConstantBufferType>(innerType); + auto paramBlockType = as<IRParameterBlockType>(innerType); + if (cbufferType || paramBlockType) { innerType = as<IRUniformParameterGroupType>(innerType)->getElementType(); - storageClass = SpvStorageClassUniform; + if (storageClass == SpvStorageClassPrivate) + storageClass = SpvStorageClassUniform; // Constant buffer is already treated like a pointer type, and // we are not adding another layer of indirection when replacing it // with a pointer type. Therefore we don't need to insert a load at @@ -137,6 +132,38 @@ struct SPIRVLegalizationContext : public SourceEmitterBase innerType = wrapConstantBufferElement(inst); } builder.addDecoration(innerType, kIROp_SPIRVBlockDecoration); + + auto varLayoutInst = inst->findDecoration<IRLayoutDecoration>(); + if (paramBlockType) + { + // A parameter block typed global parameter will have a VarLayout + // that contains an OffsetAttr(RegisterSpace, spaceId). + // We need to turn this VarLayout into a standard cbuffer VarLayout + // in the form of OffsetAttr(ConstantBuffer, 0, spaceId). + builder.setInsertBefore(inst); + IRVarLayout* varLayout = nullptr; + if (varLayoutInst) + varLayout = as<IRVarLayout>(varLayoutInst->getLayout()); + if (varLayout) + { + auto registerSpaceOffsetAttr = varLayout->findOffsetAttr(LayoutResourceKind::RegisterSpace); + if (registerSpaceOffsetAttr) + { + List<IRInst*> operands; + for (UInt i = 0; i < varLayout->getOperandCount(); i++) + operands.add(varLayout->getOperand(i)); + operands.add(builder.getVarOffsetAttr(LayoutResourceKind::ConstantBuffer, 0, registerSpaceOffsetAttr->getOffset())); + auto newLayout = builder.getVarLayout(operands); + varLayoutInst->setOperand(0, newLayout); + varLayout->removeAndDeallocate(); + } + } + } + else if (storageClass == SpvStorageClassPushConstant) + { + // Push constant params does not need a VarLayout. + varLayoutInst->removeAndDeallocate(); + } } // Make a pointer type of storageClass. @@ -162,6 +189,70 @@ struct SPIRVLegalizationContext : public SourceEmitterBase processGlobalVar(inst); } + SpvStorageClass getStorageClassFromGlobalParamResourceKind(LayoutResourceKind kind) + { + SpvStorageClass storageClass = SpvStorageClassMax; + switch (kind) + { + case LayoutResourceKind::Uniform: + case LayoutResourceKind::DescriptorTableSlot: + case LayoutResourceKind::ConstantBuffer: + storageClass = SpvStorageClassUniform; + break; + case LayoutResourceKind::VaryingInput: + storageClass = SpvStorageClassInput; + break; + case LayoutResourceKind::VaryingOutput: + storageClass = SpvStorageClassOutput; + break; + case LayoutResourceKind::ShaderResource: + case LayoutResourceKind::UnorderedAccess: + storageClass = SpvStorageClassStorageBuffer; + break; + case LayoutResourceKind::PushConstantBuffer: + storageClass = SpvStorageClassPushConstant; + break; + case LayoutResourceKind::RayPayload: + storageClass = SpvStorageClassRayPayloadKHR; + break; + case LayoutResourceKind::CallablePayload: + storageClass = SpvStorageClassCallableDataKHR; + break; + case LayoutResourceKind::HitAttributes: + storageClass = SpvStorageClassHitAttributeKHR; + break; + case LayoutResourceKind::ShaderRecord: + storageClass = SpvStorageClassShaderRecordBufferKHR; + break; + default: + break; + } + return storageClass; + } + + SpvStorageClass getGlobalParamStorageClass(IRVarLayout* varLayout) + { + SpvStorageClass result = SpvStorageClassMax; + for (auto rr : varLayout->getOffsetAttrs()) + { + auto storageClass = getStorageClassFromGlobalParamResourceKind(rr->getResourceKind()); + // If we haven't inferred a storage class yet, use the one we just found. + if (result == SpvStorageClassMax) + result = storageClass; + else if (result != storageClass) + { + // If we have inferred a storage class, and it is different from the one we just found, + // then we have conflicting uses of the resource, and we cannot infer a storage class. + // An exception is that a uniform storage class can be further specialized by PushConstants. + if (result == SpvStorageClassUniform) + result = storageClass; + else + SLANG_UNEXPECTED("Var layout contains conflicting resource uses, cannot resolve a storage class."); + } + } + return result; + } + void processGlobalVar(IRInst* inst) { auto oldPtrType = as<IRPtrTypeBase>(inst->getDataType()); @@ -184,31 +275,9 @@ struct SPIRVLegalizationContext : public SourceEmitterBase } else if (const auto varLayout = getVarLayout(inst)) { - for (auto rr : varLayout->getOffsetAttrs()) - { - switch (rr->getResourceKind()) - { - case LayoutResourceKind::Uniform: - case LayoutResourceKind::ShaderResource: - case LayoutResourceKind::DescriptorTableSlot: - storageClass = SpvStorageClassUniform; - break; - case LayoutResourceKind::VaryingInput: - storageClass = SpvStorageClassInput; - break; - case LayoutResourceKind::VaryingOutput: - storageClass = SpvStorageClassOutput; - break; - case LayoutResourceKind::UnorderedAccess: - storageClass = SpvStorageClassStorageBuffer; - break; - case LayoutResourceKind::PushConstantBuffer: - storageClass = SpvStorageClassPushConstant; - break; - default: - break; - } - } + auto cls = getGlobalParamStorageClass(varLayout); + if (cls != SpvStorageClassMax) + storageClass = cls; } IRBuilder builder(m_sharedContext->m_irModule); @@ -312,15 +381,18 @@ struct SPIRVLegalizationContext : public SourceEmitterBase writeBacks.add(WriteBackPair{ arg, tempVar }); } SLANG_ASSERT((UInt)newArgs.getCount() == inst->getArgCount()); - auto newCall = builder.emitCallInst(inst->getFullType(), inst->getCallee(), newArgs); - for (auto wb : writeBacks) + if (writeBacks.getCount()) { - auto newVal = builder.emitLoad(wb.tempVar); - builder.emitStore(wb.originalAddrArg, newVal); + auto newCall = builder.emitCallInst(inst->getFullType(), inst->getCallee(), newArgs); + for (auto wb : writeBacks) + { + auto newVal = builder.emitLoad(wb.tempVar); + builder.emitStore(wb.originalAddrArg, newVal); + } + inst->replaceUsesWith(newCall); + inst->removeAndDeallocate(); + addUsersToWorkList(newCall); } - inst->replaceUsesWith(newCall); - inst->removeAndDeallocate(); - addUsersToWorkList(newCall); } Dictionary<IRInst*, IRInst*> m_mapArrayValueToVar; @@ -346,12 +418,17 @@ struct SPIRVLegalizationContext : public SourceEmitterBase IRBuilder builder(m_sharedContext->m_irModule); IRInst* y = nullptr; + builder.setInsertBefore(inst); if (!m_mapArrayValueToVar.tryGetValue(x, y)) { - setInsertAfterOrdinaryInst(&builder, x); + if (x->getParent()->getOp() == kIROp_Module) + builder.setInsertBefore(inst); + else + setInsertAfterOrdinaryInst(&builder, x); y = builder.emitVar(x->getDataType(), SpvStorageClassFunction); builder.emitStore(y, x); - m_mapArrayValueToVar.set(x, y); + if (x->getParent()->getOp() != kIROp_Module) + m_mapArrayValueToVar.set(x, y); } builder.setInsertBefore(inst); for(Index i = indices.getCount() - 1; i >= 0; --i) diff --git a/source/slang/slang-ir-spirv-snippet.cpp b/source/slang/slang-ir-spirv-snippet.cpp index c19d50ece..98466e8ff 100644 --- a/source/slang/slang-ir-spirv-snippet.cpp +++ b/source/slang/slang-ir-spirv-snippet.cpp @@ -28,6 +28,8 @@ SpvSnippet::ASMType parseASMType(Slang::Misc::TokenReader& tokenReader) return SpvSnippet::ASMType::Double; else if (word == "uint2") return SpvSnippet::ASMType::UInt2; + else if (word == "uint16_t") + return SpvSnippet::ASMType::UInt16; else if (word == "float2") return SpvSnippet::ASMType::Float2; else if (word == "int") @@ -36,6 +38,8 @@ SpvSnippet::ASMType parseASMType(Slang::Misc::TokenReader& tokenReader) return SpvSnippet::ASMType::UInt; else if (word == "_p") return SpvSnippet::ASMType::FloatOrDouble; + else if (word == "half") + return SpvSnippet::ASMType::Half; return SpvSnippet::ASMType::None; } diff --git a/source/slang/slang-ir-spirv-snippet.h b/source/slang/slang-ir-spirv-snippet.h index c1497496d..4a509a230 100644 --- a/source/slang/slang-ir-spirv-snippet.h +++ b/source/slang/slang-ir-spirv-snippet.h @@ -64,6 +64,8 @@ struct SpvSnippet : public RefObject None, Int, UInt, + UInt16, + Half, Float, Double, FloatOrDouble, // Float or double type, depending on the result type of the intrinsic. @@ -83,6 +85,7 @@ struct SpvSnippet : public RefObject { switch (type) { + case ASMType::Half: case ASMType::Float: case ASMType::Double: case ASMType::Float2: @@ -102,6 +105,7 @@ struct SpvSnippet : public RefObject return false; switch (type) { + case ASMType::Half: case ASMType::Float: case ASMType::Double: case ASMType::FloatOrDouble: @@ -112,6 +116,7 @@ struct SpvSnippet : public RefObject case ASMType::Int: return intValues[0] == other.intValues[0]; case ASMType::UInt: + case ASMType::UInt16: return intValues[0] == other.intValues[0]; case ASMType::UInt2: return intValues[0] == other.intValues[0] && intValues[1] == other.intValues[1]; diff --git a/source/slang/slang-ir-util.cpp b/source/slang/slang-ir-util.cpp index f96cc174c..467580c83 100644 --- a/source/slang/slang-ir-util.cpp +++ b/source/slang/slang-ir-util.cpp @@ -998,6 +998,13 @@ void resetScratchDataBit(IRInst* inst, int bitIndex) } } +IRVarLayout* findVarLayout(IRInst* value) +{ + if (auto layoutDecoration = value->findDecoration<IRLayoutDecoration>()) + return as<IRVarLayout>(layoutDecoration->getLayout()); + return nullptr; +} + UnownedStringSlice getBasicTypeNameHint(IRType* basicType) { switch (basicType->getOp()) diff --git a/source/slang/slang-ir-util.h b/source/slang/slang-ir-util.h index 57a6c7c92..c107ec24a 100644 --- a/source/slang/slang-ir-util.h +++ b/source/slang/slang-ir-util.h @@ -224,6 +224,8 @@ bool isOne(IRInst* inst); void initializeScratchData(IRInst* inst); void resetScratchDataBit(IRInst* inst, int bitIndex); +IRVarLayout* findVarLayout(IRInst* value); + // Run an operation over every block in a module template<typename F> static void overAllBlocks(IRModule* module, F f) diff --git a/source/slang/slang-ir-validate.cpp b/source/slang/slang-ir-validate.cpp index bf1ce1956..d0948fbb1 100644 --- a/source/slang/slang-ir-validate.cpp +++ b/source/slang/slang-ir-validate.cpp @@ -336,16 +336,18 @@ namespace Slang validateIRInstOperands(context, inst); context->seenInsts.add(inst); + if (auto code = as<IRGlobalValueWithCode>(inst)) + { + context->domTree = computeDominatorTree(code); + validateCodeBody(context, code); + } + // If `inst` is itself a parent instruction, then we need to recursively // validate its children. validateIRInstChildren(context, inst); if (auto code = as<IRGlobalValueWithCode>(inst)) - { - context->domTree = computeDominatorTree(code); - validateCodeBody(context, code); context->domTree = nullptr; - } } void validateIRInst(IRInst* inst) diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index 8d36c2e86..181970632 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -3771,6 +3771,32 @@ namespace Slang } if (targetVectorType->getElementCount() != sourceVectorType->getElementCount()) { + auto fromCount = as<IRIntLit>(sourceVectorType->getElementCount()); + auto toCount = as<IRIntLit>(targetVectorType->getElementCount()); + if (fromCount && toCount) + { + if (toCount->getValue() < fromCount->getValue()) + { + List<UInt> indices; + for (UInt i = 0; i < (UInt)toCount->getValue(); i++) + indices.add(i); + return emitSwizzle(targetVectorType, value, (UInt)indices.getCount(), indices.getBuffer()); + } + else if (toCount->getValue() > fromCount->getValue()) + { + List<IRInst*> args; + for (UInt i = 0; i < (UInt)fromCount->getValue(); i++) + { + auto element = emitSwizzle(sourceVectorType->getElementType(), value , 1, &i); + args.add(element); + } + for (IRIntegerValue i = fromCount->getValue(); i < toCount->getValue(); i++) + { + args.add(emitDefaultConstruct(targetVectorType->getElementType())); + } + return emitMakeVector(targetVectorType, args); + } + } auto reshape = emitIntrinsicInst( getVectorType( sourceVectorType->getElementType(), targetVectorType->getElementCount()), diff --git a/source/slang/slang-language-server-completion.cpp b/source/slang/slang-language-server-completion.cpp index 4ec0ac64f..ae1a0e9c2 100644 --- a/source/slang/slang-language-server-completion.cpp +++ b/source/slang/slang-language-server-completion.cpp @@ -60,6 +60,7 @@ static const char* hlslSemanticNames[] = { "SV_IsFrontFace", "SV_OutputControlPointID", "SV_Position", + "SV_PointSize", "SV_PrimitiveID", "SV_RenderTargetArrayIndex", "SV_SampleIndex", |
