summaryrefslogtreecommitdiff
path: root/source/slang/slang-emit-spirv.cpp
diff options
context:
space:
mode:
authorJay Kwak <82421531+jkwak-work@users.noreply.github.com>2025-01-30 00:59:49 -0800
committerGitHub <noreply@github.com>2025-01-30 00:59:49 -0800
commitba9b2785c69c1b8c6d2b4103267b5281815f9f23 (patch)
treee4ba4ca76c6592b90764a0a7ac32502639dc93aa /source/slang/slang-emit-spirv.cpp
parent2ae194d51e15c064c3d905e628f7335de7504e32 (diff)
Support cooperative vector (#6223)
* Support cooperative vector without Vulkan-header update Adding a Slang support for cooperative vector. But this commit doesn't have Vulkan-header update.
Diffstat (limited to 'source/slang/slang-emit-spirv.cpp')
-rw-r--r--source/slang/slang-emit-spirv.cpp67
1 files changed, 64 insertions, 3 deletions
diff --git a/source/slang/slang-emit-spirv.cpp b/source/slang/slang-emit-spirv.cpp
index 24d8cc0c6..2c36ae5f7 100644
--- a/source/slang/slang-emit-spirv.cpp
+++ b/source/slang/slang-emit-spirv.cpp
@@ -1628,6 +1628,16 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
static_cast<IRIntLit*>(vectorType->getElementCount())->getValue(),
vectorType);
}
+ case kIROp_CoopVectorType:
+ {
+ auto coopVecType = static_cast<IRCoopVectorType*>(inst);
+ requireSPIRVCapability(SpvCapabilityCooperativeVectorNV);
+ ensureExtensionDeclaration(UnownedStringSlice("SPV_NV_cooperative_vector"));
+ return ensureCoopVecType(
+ static_cast<IRBasicType*>(coopVecType->getElementType())->getBaseType(),
+ static_cast<IRIntLit*>(coopVecType->getElementCount())->getValue(),
+ coopVecType);
+ }
case kIROp_MatrixType:
{
auto matrixType = static_cast<IRMatrixType*>(inst);
@@ -1778,6 +1788,7 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
numElems->getValue());
}
case kIROp_MakeVector:
+ case kIROp_MakeCoopVector:
case kIROp_MakeArray:
case kIROp_MakeStruct:
return emitCompositeConstruct(getSection(SpvLogicalSectionID::ConstantsAndTypes), inst);
@@ -2361,6 +2372,27 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
return result;
}
+ /// Similar to ensureVectorType but for CoopVecType
+ SpvInst* ensureCoopVecType(
+ BaseType baseType,
+ IRIntegerValue elementCount,
+ IRCoopVectorType* inst)
+ {
+ IRBuilder builder(m_irModule);
+ if (!inst)
+ {
+ builder.setInsertInto(m_irModule->getModuleInst());
+ inst = builder.getCoopVectorType(
+ builder.getBasicType(baseType),
+ builder.getIntValue(builder.getIntType(), elementCount));
+ }
+ auto result = emitOpTypeCoopVec(
+ inst,
+ inst->getElementType(),
+ emitIntConstant(elementCount, builder.getIntType()));
+ return result;
+ }
+
bool _maybeEmitInterpolationModifierDecoration(IRInterpolationMode mode, SpvId varInst)
{
switch (mode)
@@ -3417,6 +3449,10 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
case kIROp_StructuredBufferGetDimensions:
result = emitStructuredBufferGetDimensions(parent, inst);
break;
+ case kIROp_GetStructuredBufferPtr:
+ case kIROp_GetUntypedBufferPtr:
+ result = emitGetBufferPtr(parent, inst);
+ break;
case kIROp_swizzle:
result = emitSwizzle(parent, as<IRSwizzle>(inst));
break;
@@ -3711,6 +3747,9 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
result = emitSplat(parent, inst, scalar, numElems->getValue());
}
break;
+ case kIROp_MakeCoopVector:
+ result = emitConstruct(parent, inst);
+ break;
case kIROp_MakeArray:
result = emitConstruct(parent, inst);
break;
@@ -6079,7 +6118,7 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
const auto baseTy = base->getDataType();
SLANG_ASSERT(
as<IRPointerLikeType>(baseTy) || as<IRArrayType>(baseTy) || as<IRVectorType>(baseTy) ||
- as<IRMatrixType>(baseTy));
+ as<IRCoopVectorType>(baseTy) || as<IRMatrixType>(baseTy));
IRBuilder builder(m_irModule);
builder.setInsertBefore(inst);
@@ -6097,7 +6136,7 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
}
else
{
- SLANG_ASSERT(as<IRVectorType>(baseTy));
+ SLANG_ASSERT(as<IRVectorType>(baseTy) || as<IRCoopVectorType>(baseTy));
// SPIRV Only allows dynamic element extract on vector types.
return emitOpVectorExtractDynamic(
parent,
@@ -6307,6 +6346,27 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
return result;
}
+ SpvInst* emitGetBufferPtr(SpvInstParent* parent, IRInst* inst)
+ {
+ IRBuilder builder(inst);
+ auto addressSpace =
+ isSpirv14OrLater() ? AddressSpace::StorageBuffer : AddressSpace::Uniform;
+ // The buffer is a global parameter, so it's a pointer
+ IRPtrTypeBase* bufPtrType = cast<IRPtrTypeBase>(inst->getOperand(0)->getDataType());
+ // It's lowered to a struct type..
+ IRStructType* bufType = cast<IRStructType>(bufPtrType->getValueType());
+ // containing an unsized array, specifically one with an explicit
+ // stride, which is not expressible in spirv_asm blocks
+ IRArrayTypeBase* arrayType =
+ cast<IRArrayTypeBase>(bufType->getFields().getFirst()->getFieldType());
+ return emitOpAccessChain(
+ parent,
+ inst,
+ builder.getPtrType(arrayType, addressSpace),
+ inst->getOperand(0),
+ makeArray(emitIntConstant(0, builder.getIntType())));
+ }
+
SpvInst* emitSwizzle(SpvInstParent* parent, IRSwizzle* inst)
{
if (inst->getElementCount() == 1)
@@ -6478,7 +6538,8 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
IRType* toType = nullptr;
bool isMatrixCast = false;
- if (as<IRVectorType>(fromTypeV) || as<IRVectorType>(toTypeV))
+ if (as<IRVectorType>(fromTypeV) || as<IRVectorType>(toTypeV) ||
+ as<IRCoopVectorType>(fromTypeV) || as<IRCoopVectorType>(toTypeV))
{
fromType = getVectorElementType(fromTypeV);
toType = getVectorElementType(toTypeV);