diff options
| author | Jay Kwak <82421531+jkwak-work@users.noreply.github.com> | 2025-01-30 00:59:49 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-01-30 00:59:49 -0800 |
| commit | ba9b2785c69c1b8c6d2b4103267b5281815f9f23 (patch) | |
| tree | e4ba4ca76c6592b90764a0a7ac32502639dc93aa /source/slang/slang-emit-spirv.cpp | |
| parent | 2ae194d51e15c064c3d905e628f7335de7504e32 (diff) | |
Support cooperative vector (#6223)
* Support cooperative vector without Vulkan-header update
Adding a Slang support for cooperative vector.
But this commit doesn't have Vulkan-header update.
Diffstat (limited to 'source/slang/slang-emit-spirv.cpp')
| -rw-r--r-- | source/slang/slang-emit-spirv.cpp | 67 |
1 files changed, 64 insertions, 3 deletions
diff --git a/source/slang/slang-emit-spirv.cpp b/source/slang/slang-emit-spirv.cpp index 24d8cc0c6..2c36ae5f7 100644 --- a/source/slang/slang-emit-spirv.cpp +++ b/source/slang/slang-emit-spirv.cpp @@ -1628,6 +1628,16 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex static_cast<IRIntLit*>(vectorType->getElementCount())->getValue(), vectorType); } + case kIROp_CoopVectorType: + { + auto coopVecType = static_cast<IRCoopVectorType*>(inst); + requireSPIRVCapability(SpvCapabilityCooperativeVectorNV); + ensureExtensionDeclaration(UnownedStringSlice("SPV_NV_cooperative_vector")); + return ensureCoopVecType( + static_cast<IRBasicType*>(coopVecType->getElementType())->getBaseType(), + static_cast<IRIntLit*>(coopVecType->getElementCount())->getValue(), + coopVecType); + } case kIROp_MatrixType: { auto matrixType = static_cast<IRMatrixType*>(inst); @@ -1778,6 +1788,7 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex numElems->getValue()); } case kIROp_MakeVector: + case kIROp_MakeCoopVector: case kIROp_MakeArray: case kIROp_MakeStruct: return emitCompositeConstruct(getSection(SpvLogicalSectionID::ConstantsAndTypes), inst); @@ -2361,6 +2372,27 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex return result; } + /// Similar to ensureVectorType but for CoopVecType + SpvInst* ensureCoopVecType( + BaseType baseType, + IRIntegerValue elementCount, + IRCoopVectorType* inst) + { + IRBuilder builder(m_irModule); + if (!inst) + { + builder.setInsertInto(m_irModule->getModuleInst()); + inst = builder.getCoopVectorType( + builder.getBasicType(baseType), + builder.getIntValue(builder.getIntType(), elementCount)); + } + auto result = emitOpTypeCoopVec( + inst, + inst->getElementType(), + emitIntConstant(elementCount, builder.getIntType())); + return result; + } + bool _maybeEmitInterpolationModifierDecoration(IRInterpolationMode mode, SpvId varInst) { switch (mode) @@ -3417,6 +3449,10 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex case kIROp_StructuredBufferGetDimensions: result = emitStructuredBufferGetDimensions(parent, inst); break; + case kIROp_GetStructuredBufferPtr: + case kIROp_GetUntypedBufferPtr: + result = emitGetBufferPtr(parent, inst); + break; case kIROp_swizzle: result = emitSwizzle(parent, as<IRSwizzle>(inst)); break; @@ -3711,6 +3747,9 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex result = emitSplat(parent, inst, scalar, numElems->getValue()); } break; + case kIROp_MakeCoopVector: + result = emitConstruct(parent, inst); + break; case kIROp_MakeArray: result = emitConstruct(parent, inst); break; @@ -6079,7 +6118,7 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex const auto baseTy = base->getDataType(); SLANG_ASSERT( as<IRPointerLikeType>(baseTy) || as<IRArrayType>(baseTy) || as<IRVectorType>(baseTy) || - as<IRMatrixType>(baseTy)); + as<IRCoopVectorType>(baseTy) || as<IRMatrixType>(baseTy)); IRBuilder builder(m_irModule); builder.setInsertBefore(inst); @@ -6097,7 +6136,7 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex } else { - SLANG_ASSERT(as<IRVectorType>(baseTy)); + SLANG_ASSERT(as<IRVectorType>(baseTy) || as<IRCoopVectorType>(baseTy)); // SPIRV Only allows dynamic element extract on vector types. return emitOpVectorExtractDynamic( parent, @@ -6307,6 +6346,27 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex return result; } + SpvInst* emitGetBufferPtr(SpvInstParent* parent, IRInst* inst) + { + IRBuilder builder(inst); + auto addressSpace = + isSpirv14OrLater() ? AddressSpace::StorageBuffer : AddressSpace::Uniform; + // The buffer is a global parameter, so it's a pointer + IRPtrTypeBase* bufPtrType = cast<IRPtrTypeBase>(inst->getOperand(0)->getDataType()); + // It's lowered to a struct type.. + IRStructType* bufType = cast<IRStructType>(bufPtrType->getValueType()); + // containing an unsized array, specifically one with an explicit + // stride, which is not expressible in spirv_asm blocks + IRArrayTypeBase* arrayType = + cast<IRArrayTypeBase>(bufType->getFields().getFirst()->getFieldType()); + return emitOpAccessChain( + parent, + inst, + builder.getPtrType(arrayType, addressSpace), + inst->getOperand(0), + makeArray(emitIntConstant(0, builder.getIntType()))); + } + SpvInst* emitSwizzle(SpvInstParent* parent, IRSwizzle* inst) { if (inst->getElementCount() == 1) @@ -6478,7 +6538,8 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex IRType* toType = nullptr; bool isMatrixCast = false; - if (as<IRVectorType>(fromTypeV) || as<IRVectorType>(toTypeV)) + if (as<IRVectorType>(fromTypeV) || as<IRVectorType>(toTypeV) || + as<IRCoopVectorType>(fromTypeV) || as<IRCoopVectorType>(toTypeV)) { fromType = getVectorElementType(fromTypeV); toType = getVectorElementType(toTypeV); |
