summaryrefslogtreecommitdiff
path: root/source/slang/slang-ir.cpp
diff options
context:
space:
mode:
authorJay Kwak <82421531+jkwak-work@users.noreply.github.com>2025-01-30 00:59:49 -0800
committerGitHub <noreply@github.com>2025-01-30 00:59:49 -0800
commitba9b2785c69c1b8c6d2b4103267b5281815f9f23 (patch)
treee4ba4ca76c6592b90764a0a7ac32502639dc93aa /source/slang/slang-ir.cpp
parent2ae194d51e15c064c3d905e628f7335de7504e32 (diff)
Support cooperative vector (#6223)
* Support cooperative vector without Vulkan-header update Adding a Slang support for cooperative vector. But this commit doesn't have Vulkan-header update.
Diffstat (limited to 'source/slang/slang-ir.cpp')
-rw-r--r--source/slang/slang-ir.cpp50
1 files changed, 50 insertions, 0 deletions
diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp
index 32acc7baa..6a7564e67 100644
--- a/source/slang/slang-ir.cpp
+++ b/source/slang/slang-ir.cpp
@@ -2964,6 +2964,13 @@ IRVectorType* IRBuilder::getVectorType(IRType* elementType, IRIntegerValue eleme
return getVectorType(elementType, getIntValue(getIntType(), elementCount));
}
+IRCoopVectorType* IRBuilder::getCoopVectorType(IRType* elementType, IRInst* elementCount)
+{
+ IRInst* operands[] = {elementType, elementCount};
+ return (IRCoopVectorType*)
+ getType(kIROp_CoopVectorType, sizeof(operands) / sizeof(operands[0]), operands);
+}
+
IRMatrixType* IRBuilder::getMatrixType(
IRType* elementType,
IRInst* rowCount,
@@ -3887,6 +3894,26 @@ IRInst* IRBuilder::emitDefaultConstruct(IRType* type, bool fallback)
return nullptr;
return emitIntrinsicInst(type, kIROp_MakeVectorFromScalar, 1, &inner);
}
+ case kIROp_CoopVectorType:
+ {
+ auto coopVecType = as<IRCoopVectorType>(actualType);
+ if (auto count = as<IRIntLit>(coopVecType->getElementCount()))
+ {
+ auto element = emitDefaultConstruct(coopVecType->getElementType(), fallback);
+ if (!element)
+ return nullptr;
+ List<IRInst*> elements;
+ constexpr int maxCount = 4096;
+ if (count->getValue() > maxCount)
+ break;
+ for (IRIntegerValue i = 0; i < count->getValue(); i++)
+ {
+ elements.add(element);
+ }
+ return emitMakeCoopVector(type, elements.getCount(), elements.getBuffer());
+ }
+ break;
+ }
case kIROp_MatrixType:
{
auto inner =
@@ -4171,6 +4198,18 @@ IRInst* IRBuilder::emitGetNativeString(IRInst* str)
return emitIntrinsicInst(getNativeStringType(), kIROp_getNativeStr, 1, &str);
}
+IRInst* IRBuilder::emitGetElement(IRType* type, IRInst* arrayLikeType, IRIntegerValue element)
+{
+ IRInst* args[] = {arrayLikeType, getIntValue(getIntType(), element)};
+ return emitIntrinsicInst(type, kIROp_GetElement, 2, args);
+}
+
+IRInst* IRBuilder::emitGetElementPtr(IRType* type, IRInst* arrayLikeType, IRIntegerValue element)
+{
+ IRInst* args[] = {arrayLikeType, getIntValue(getIntType(), element)};
+ return emitIntrinsicInst(type, kIROp_GetElementPtr, 2, args);
+}
+
IRInst* IRBuilder::emitGetTupleElement(IRType* type, IRInst* tuple, IRInst* element)
{
IRInst* args[] = {tuple, element};
@@ -4345,6 +4384,11 @@ IRInst* IRBuilder::emitMakeMatrixFromScalar(IRType* type, IRInst* scalarValue)
return emitIntrinsicInst(type, kIROp_MakeMatrixFromScalar, 1, &scalarValue);
}
+IRInst* IRBuilder::emitMakeCoopVector(IRType* type, UInt argCount, IRInst* const* args)
+{
+ return emitIntrinsicInst(type, kIROp_MakeCoopVector, argCount, args);
+}
+
IRInst* IRBuilder::emitMakeArray(IRType* type, UInt argCount, IRInst* const* args)
{
return emitIntrinsicInst(type, kIROp_MakeArray, argCount, args);
@@ -5187,6 +5231,10 @@ IRInst* IRBuilder::emitElementAddress(IRInst* basePtr, IRInst* index)
{
type = vectorType->getElementType();
}
+ else if (auto coopVecType = as<IRCoopVectorType>(valueType))
+ {
+ type = coopVecType->getElementType();
+ }
else if (auto matrixType = as<IRMatrixType>(valueType))
{
type = getVectorType(matrixType->getElementType(), matrixType->getColumnCount());
@@ -8143,6 +8191,7 @@ bool IRInst::mightHaveSideEffects(SideEffectAnalysisOptions options)
case kIROp_GetAddr:
case kIROp_GetValueFromBoundInterface:
case kIROp_MakeUInt64:
+ case kIROp_MakeCoopVector:
case kIROp_MakeVector:
case kIROp_MakeMatrix:
case kIROp_MakeMatrixFromScalar:
@@ -8627,6 +8676,7 @@ bool isMovableInst(IRInst* inst)
switch (inst->getOp())
{
+ case kIROp_MakeCoopVector:
case kIROp_Add:
case kIROp_Sub:
case kIROp_Mul: