summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorJay Kwak <82421531+jkwak-work@users.noreply.github.com>2025-04-15 03:33:41 -0700
committerGitHub <noreply@github.com>2025-04-15 10:33:41 +0000
commita6174ff9443507dece534aa193f8c45e8f0ce7db (patch)
treed488be53789c67a8b190067bc5711d5565ddecc2
parent5902ee7f3822178c601edcb128a846c502bca3de (diff)
Document CoopVec functions (#6777)
Documenting CoopVec related functions. This commit also fixes a few warning printed from the doc generation tool. Some of comments are removed or converted from /// to //, because the overloading functions can have /// style comment only once.
-rw-r--r--source/slang/hlsl.meta.slang306
1 files changed, 241 insertions, 65 deletions
diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang
index ae1f6da98..bdaa2bad0 100644
--- a/source/slang/hlsl.meta.slang
+++ b/source/slang/hlsl.meta.slang
@@ -11964,6 +11964,7 @@ T NonUniformResourceIndex(T index);
/// HLSL allows NonUniformResourceIndex around non int/uint types.
/// It's effect is presumably to ignore it, which the following implementation does.
/// We should also look to add a warning for this scenario.
+/// @deprecated
[__unsafeForceInlineEarly]
[deprecated("NonUniformResourceIndex on a type other than uint/int is deprecated and has no effect")]
T NonUniformResourceIndex<T>(T value) { return value; }
@@ -20536,11 +20537,11 @@ void ReorderThread( HitObject HitOrMiss, uint CoherenceHint, uint NumCoherenceHi
}
}
- /// Is equivalent to
- /// ```
- /// void ReorderThread( HitObject HitOrMiss, uint CoherenceHint, uint NumCoherenceHintBitsFromLSB );
- /// ```
- /// With CoherenceHint and NumCoherenceHintBitsFromLSB as 0, meaning they are ignored.
+ // Is equivalent to
+ // ```
+ // void ReorderThread( HitObject HitOrMiss, uint CoherenceHint, uint NumCoherenceHintBitsFromLSB );
+ // ```
+ // With CoherenceHint and NumCoherenceHintBitsFromLSB as 0, meaning they are ignored.
[__requiresNVAPI]
__glsl_extension(GL_EXT_ray_tracing)
@@ -22011,6 +22012,11 @@ extension<T, L : IBufferDataLayout> RasterizerOrderedStructuredBuffer<T, L> : IR
// Cooperative Vector
//
+/// Represents a Cooperative Vector type that is for matrix-vector multiplication that
+/// can take an advantage of the hardware acceleration. It can be used for evaluations
+/// of neural network in graphics and compute pipeline.
+/// @param T The element type of the CoopVec.
+/// @param N The vector size.
__intrinsic_type($(kIROp_CoopVectorType))
[require(cooperative_vector)]
struct CoopVec<T : __BuiltinArithmeticType, let N : int> : IArray<T>, IArithmetic
@@ -22025,12 +22031,14 @@ struct CoopVec<T : __BuiltinArithmeticType, let N : int> : IArray<T>, IArithmeti
{
this = CoopVec<T, N>(T(0));
}
+
[ForceInline]
[require(cooperative_vector)]
__init(T t)
{
this.fill(t);
}
+
[ForceInline]
[require(cooperative_vector)]
__init<U : __BuiltinArithmeticType>(CoopVec<U, N> other)
@@ -22045,12 +22053,14 @@ struct CoopVec<T : __BuiltinArithmeticType, let N : int> : IArray<T>, IArithmeti
static_assert(countof(U) == N, "number of arguments to CoopVec constructor must match number of elements");
this = __makeCoopVec<T, N>(expand (__arithmetic_cast<T>(each args)));
}
+
[OverloadRank(-10)]
[ForceInline]
__init(int i)
{
this = CoopVec<T, N>(T(i));
}
+
[ForceInline]
__init(This x)
{
@@ -22061,6 +22071,9 @@ struct CoopVec<T : __BuiltinArithmeticType, let N : int> : IArray<T>, IArithmeti
// Simple setters
//
+ /// Copy values from another CoopVec instance into this one. The source CoopVec can have a different element type,
+ /// in which case appropriate type conversion will be performed.
+ /// @param other The source CoopVec to copy from.
[require(hlsl)]
[mutating]
[ForceInline]
@@ -22081,6 +22094,8 @@ struct CoopVec<T : __BuiltinArithmeticType, let N : int> : IArray<T>, IArithmeti
}
}
+ /// Fill all elements of this CoopVec with the specified value.
+ /// @param t The value to fill all elements with.
[require(cooperative_vector)]
[mutating]
[ForceInline]
@@ -22106,6 +22121,9 @@ struct CoopVec<T : __BuiltinArithmeticType, let N : int> : IArray<T>, IArithmeti
// Loading and storing
//
+ /// Store all elements of this CoopVec into a buffer at a specified offset.
+ /// @param buffer The destination buffer to store the values into.
+ /// @param byteOffset16ByteAligned The byte offset from the start of the buffer where the data will be stored. Must be 16-byte aligned.
[ForceInline]
[require(cooperative_vector)]
void store(RWByteAddressBuffer buffer, int32_t byteOffset16ByteAligned = 0)
@@ -22119,9 +22137,10 @@ struct CoopVec<T : __BuiltinArithmeticType, let N : int> : IArray<T>, IArithmeti
// TODO: Should this be a byte offset
OpCooperativeVectorStoreNV $ptr $byteOffset16ByteAligned $this None;
};
- // Not supported
- // case hlsl:
- // this.__Store(buffer, byteOffset16ByteAligned);
+#ifdef NOT_SUPPORTED_YET
+ case hlsl:
+ this.__Store(buffer, byteOffset16ByteAligned);
+#endif
default:
for(int i = 0; i < N; ++i)
buffer.StoreByteOffset(byteOffset16ByteAligned + __elemToByteOffset<T>(i), this[i]);
@@ -22141,9 +22160,10 @@ struct CoopVec<T : __BuiltinArithmeticType, let N : int> : IArray<T>, IArithmeti
// TODO: Should this be a byte offset
OpCooperativeVectorStoreNV $ptr $byteOffset16ByteAligned $this None;
};
- // Not supported
- // case hlsl:
- // this.__Store(buffer, byteOffset16ByteAligned);
+#ifdef NOT_SUPPORTED_YET
+ case hlsl:
+ this.__Store(buffer, byteOffset16ByteAligned);
+#endif
default:
for(int i = 0; i < N; ++i)
buffer[i + __byteToElemOffset<T>(byteOffset16ByteAligned)] = this[i];
@@ -22168,7 +22188,9 @@ struct CoopVec<T : __BuiltinArithmeticType, let N : int> : IArray<T>, IArithmeti
}
}
- /// spirv only storing to a groupshared array of any type
+ /// Store the value to a groupshared array of any type. This method is only available when targeting SPIR-V.
+ /// @param data The destination array where the data will be stored. The array element type can be different from the CoopVec element type.
+ /// @param byteOffset16ByteAligned The byte offset from the start of `data`. Must be a multiple of 16 bytes.
[ForceInline]
[require(spirv, cooperative_vector)]
void storeAny<U, let M : int>(__ref groupshared U[M] data, int32_t byteOffset16ByteAligned = 0)
@@ -22181,6 +22203,7 @@ struct CoopVec<T : __BuiltinArithmeticType, let N : int> : IArray<T>, IArithmeti
};
}
}
+
[ForceInline]
[require(spirv, cooperative_vector)]
void storeAny<U, let M : int, let L : int>(__ref groupshared vector<U, L>[M] data, int32_t byteOffset16ByteAligned = 0)
@@ -22194,6 +22217,10 @@ struct CoopVec<T : __BuiltinArithmeticType, let N : int> : IArray<T>, IArithmeti
}
}
+ /// Load values from a byte-addressable buffer into a cooperative vector.
+ /// @param buffer The source buffer to load data from.
+ /// @param byteOffset16ByteAligned The byte offset from the start of the buffer. Must be 16-byte aligned.
+ /// @return A new cooperative vector containing the loaded values.
[ForceInline]
[__NoSideEffect]
[require(cooperative_vector)]
@@ -22259,11 +22286,12 @@ struct CoopVec<T : __BuiltinArithmeticType, let N : int> : IArray<T>, IArithmeti
{
result:$$CoopVec<T, N> = OpCooperativeVectorLoadNV $ptr $byteOffset16ByteAligned None;
};
- // Not supported
- // case hlsl:
- // CoopVec<T, N> ret;
- // ret.__Load(buffer, byteOffset16ByteAligned);
- // return ret;
+#ifdef NOT_SUPPORTED_YET
+ case hlsl:
+ CoopVec<T, N> ret;
+ ret.__Load(buffer, byteOffset16ByteAligned);
+ return ret;
+#endif
default:
var vec = CoopVec<T, N>();
for(int i = 0; i < N; ++i)
@@ -22286,11 +22314,12 @@ struct CoopVec<T : __BuiltinArithmeticType, let N : int> : IArray<T>, IArithmeti
{
result:$$CoopVec<T, N> = OpCooperativeVectorLoadNV $ptr $byteOffset16ByteAligned None;
};
- // Not supported
- // case hlsl:
- // CoopVec<T, N> ret;
- // ret.__Load(buffer, byteOffset16ByteAligned);
- // return ret;
+#ifdef NOT_SUPPORTED_YET
+ case hlsl:
+ CoopVec<T, N> ret;
+ ret.__Load(buffer, byteOffset16ByteAligned);
+ return ret;
+#endif
default:
var vec = CoopVec<T, N>();
for(int i = 0; i < N; ++i)
@@ -22323,7 +22352,11 @@ struct CoopVec<T : __BuiltinArithmeticType, let N : int> : IArray<T>, IArithmeti
}
}
- /// spirv only loading from a groupshared array of any type
+ /// Load values from a groupshared array into a CoopVec, allowing type conversion between source and destination elements.
+ /// This operation is only available when targeting SPIR-V.
+ /// @param data The source groupshared array to load from. The element type U can be different from the CoopVec element type T.
+ /// @param byteOffset16ByteAligned The byte offset from the start of the array. Must be 16-byte aligned.
+ /// @return A new CoopVec containing the loaded and type-converted values.
[ForceInline]
[__NoSideEffect]
[require(spirv, cooperative_vector)]
@@ -22337,6 +22370,7 @@ struct CoopVec<T : __BuiltinArithmeticType, let N : int> : IArray<T>, IArithmeti
};
}
}
+
[ForceInline]
[__NoSideEffect]
[require(spirv, cooperative_vector)]
@@ -22371,8 +22405,10 @@ struct CoopVec<T : __BuiltinArithmeticType, let N : int> : IArray<T>, IArithmeti
return N;
}
+ /// Access an individual element in the Cooperative vector by index.
__subscript(int index) -> T
{
+ [ForceInline]
[__NoSideEffect]
[nonmutating]
get
@@ -22384,6 +22420,7 @@ struct CoopVec<T : __BuiltinArithmeticType, let N : int> : IArray<T>, IArithmeti
}
}
+ [ForceInline]
[mutating]
set
{
@@ -22402,6 +22439,9 @@ struct CoopVec<T : __BuiltinArithmeticType, let N : int> : IArray<T>, IArithmeti
// ref;
}
+ /// Creates a new cooperative vector with all elements initialized to the specified scalar value.
+ /// @param t The scalar value to replicate across all elements.
+ /// @return A new cooperative vector where each element equals the input value.
static CoopVec<T, N> replicate(T t)
{
CoopVec<T, N> ret;
@@ -22410,9 +22450,12 @@ struct CoopVec<T : __BuiltinArithmeticType, let N : int> : IArray<T>, IArithmeti
}
//
- // Equality and ordering
+ // IComparable
//
+ /// Checks if this cooperative vector is equal to another cooperative vector by comparing all elements.
+ /// @param other The cooperative vector to compare against.
+ /// @return True if all corresponding elements are equal, false otherwise.
bool equals(This other)
{
for (int i = 0; i < N; i++)
@@ -22424,6 +22467,13 @@ struct CoopVec<T : __BuiltinArithmeticType, let N : int> : IArray<T>, IArithmeti
}
return true;
}
+
+ /// Compares two cooperative vectors lexicographically.
+ /// @param other The cooperative vector to compare against.
+ /// @return True if this vector is lexicographically less than the other vector.
+ /// @remarks This function exists only to conform to IComparable. For cooperative vectors,
+ /// lexicographical comparison has limited practical use since the vectors are meant for
+ /// parallel computation rather than ordering.
bool lessThan(This other)
{
for (int i = 0; i < N; i++)
@@ -22439,6 +22489,13 @@ struct CoopVec<T : __BuiltinArithmeticType, let N : int> : IArray<T>, IArithmeti
}
return false;
}
+
+ /// Compares two cooperative vectors lexicographically.
+ /// @param other The cooperative vector to compare against.
+ /// @return True if this vector is lexicographically less than or equal to the other vector.
+ /// @remarks This function exists only to conform to IComparable. For cooperative vectors,
+ /// lexicographical comparison has limited practical use since the vectors are meant for
+ /// parallel computation rather than ordering.
bool lessThanOrEquals(This other)
{
for (int i = 0; i < N; i++)
@@ -22472,6 +22529,9 @@ struct CoopVec<T : __BuiltinArithmeticType, let N : int> : IArray<T>, IArithmeti
}
}
+ /// Performs component-wise addition with another cooperative vector.
+ /// @param other The cooperative vector to add to this vector.
+ /// @return A new cooperative vector containing the sum of the two vectors.
// TODO: Why is this ForceInline necessary for hlsl, dxc bug?
[ForceInline]
This add(This other)
@@ -22488,11 +22548,15 @@ struct CoopVec<T : __BuiltinArithmeticType, let N : int> : IArray<T>, IArithmeti
__intrinsic_op($(kIROp_Sub))
This __pureSub(This other);
+
[mutating]
[require(hlsl)]
void __mutSub(This other)
{ __target_switch { case hlsl: __intrinsic_asm ".Subtract"; } }
+ /// Performs component-wise subtraction with another cooperative vector.
+ /// @param other The cooperative vector to subtract from this vector.
+ /// @return A new cooperative vector containing the difference of the two vectors.
[ForceInline]
This sub(This other)
{
@@ -22514,6 +22578,9 @@ struct CoopVec<T : __BuiltinArithmeticType, let N : int> : IArray<T>, IArithmeti
void __mutMul(This other)
{ __target_switch { case hlsl: __intrinsic_asm ".Multiply"; } }
+ /// Performs component-wise multiplication with another cooperative vector.
+ /// @param other The cooperative vector to multiply with this vector.
+ /// @return A new cooperative vector containing the product of the two vectors.
[ForceInline]
This mul(This other)
{
@@ -22535,6 +22602,9 @@ struct CoopVec<T : __BuiltinArithmeticType, let N : int> : IArray<T>, IArithmeti
void __mutDiv(This other)
{ __target_switch { case hlsl: __intrinsic_asm ".Divide"; } }
+ /// Performs component-wise division with another cooperative vector.
+ /// @param other The cooperative vector to divide this vector by.
+ /// @return A new cooperative vector containing the quotient of the two vectors.
[ForceInline]
This div(This other)
{
@@ -22553,6 +22623,9 @@ struct CoopVec<T : __BuiltinArithmeticType, let N : int> : IArray<T>, IArithmeti
void __mutMod(This other)
{ __target_switch { case hlsl: __intrinsic_asm ".Mod"; } }
+ /// Performs component-wise remainder operation between two cooperative vectors.
+ /// @param other The cooperative vector to compute the remainder with.
+ /// @return A new cooperative vector containing the remainder of the division between corresponding components.
[ForceInline]
This mod(This other)
{
@@ -22573,6 +22646,8 @@ struct CoopVec<T : __BuiltinArithmeticType, let N : int> : IArray<T>, IArithmeti
__intrinsic_op($(kIROp_Neg))
This __pureNeg(This other);
+ /// Returns a new cooperative vector where each component has its sign negated.
+ /// @return A new cooperative vector containing the negated values.
//[ForceInline]
This neg()
{
@@ -22596,10 +22671,12 @@ struct CoopVec<T : __BuiltinArithmeticType, let N : int> : IArray<T>, IArithmeti
[require(hlsl)]
void __mutMin(This other)
{ __target_switch { case hlsl: __intrinsic_asm ".Min"; } }
+
[mutating]
[require(hlsl)]
void __mutMax(This other)
{ __target_switch { case hlsl: __intrinsic_asm ".Max"; } }
+
[mutating]
[require(hlsl)]
void __mutClamp(This minVal, This maxVal)
@@ -22674,6 +22751,19 @@ for(auto buffer : kByteAddressBufferCases) {
}
}
+ /// Multiply the given input Cooperative vector with the given matrix and accumulate the result into this vector.
+ /// @param input The input Cooperative vector to multiply with the matrix.
+ /// @param inputInterpretation Specifies how to interpret the values in the input vector (e.g. as packed values).
+ /// @param k The number of columns in the matrix.
+ /// @param matrix The matrix buffer to multiply with.
+ /// @param matrixOffset Byte offset into the matrix buffer.
+ /// @param matrixInterpretation Specifies how to interpret the values in the matrix.
+ /// @param memoryLayout Specifies the memory layout of the matrix (row-major or column-major).
+ /// @param transpose Whether to transpose the matrix before multiplication.
+ /// @param matrixStride The stride between matrix rows/columns in bytes.
+ /// @remarks Unlike matMulAccum, this function supports packed input interpretations where multiple values
+ /// can be packed into each element of the input vector. The k parameter specifies the actual number of
+ /// values to use from the packed input.
[mutating]
[ForceInline]
void matMulAccumPacked<U : __BuiltinArithmeticType, let PackedK : int>(
@@ -22728,6 +22818,15 @@ for(auto buffer : kByteAddressBufferCases) {
}
}
+ /// Accumulate the result from a matrix multiplication between an input Cooperative vector and a matrix.
+ /// @param input The input Cooperative vector to multiply with the matrix.
+ /// @param inputInterpretation Specifies how to interpret the values in the input vector (e.g. as 8-bit integers, 16-bit floats, etc).
+ /// @param matrix The matrix to multiply with the input vector.
+ /// @param matrixOffset Byte offset into the matrix buffer.
+ /// @param matrixInterpretation Specifies how to interpret the values in the matrix (e.g. as 8-bit integers, 16-bit floats, etc).
+ /// @param memoryLayout Specifies the memory layout of the matrix (row-major or column-major).
+ /// @param transpose Whether to transpose the matrix before multiplication.
+ /// @param matrixStride The stride in bytes between rows/columns of the matrix.
[mutating]
[ForceInline]
void matMulAccum<U : __BuiltinArithmeticType, let K : int>(
@@ -22820,6 +22919,20 @@ for(auto buffer : kByteAddressBufferCases) {
}
}
+ /// Performs matrix multiplication and accumulation with bias: this += input * matrix + bias
+ /// @param input The input vector to multiply with the matrix
+ /// @param inputInterpretation How to interpret the input vector elements (must not be packed)
+ /// @param matrix The matrix buffer to multiply with
+ /// @param matrixOffset Byte offset into the matrix buffer
+ /// @param matrixInterpretation How to interpret the matrix elements
+ /// @param bias The bias buffer to add
+ /// @param biasOffset Byte offset into the bias buffer
+ /// @param biasInterpretation How to interpret the bias elements
+ /// @param memoryLayout Memory layout of the matrix (row or column major)
+ /// @param transpose Whether to transpose the matrix before multiplication
+ /// @param matrixStride Stride between matrix rows/columns in bytes
+ /// @remark The key difference from matMulAddAccumPacked is that this method enforces k must equal the input vector length,
+ /// while matMulAddAccumPacked allows k to be specified independently for packed interpretations.
[mutating]
[ForceInline]
void matMulAddAccum<U : __BuiltinArithmeticType, let K : int>(
@@ -23278,6 +23391,10 @@ extension RWByteAddressBuffer : IRWPhysicalBuffer
// element type for structured buffers and groupshared arrays (and ByteAddressBuffers for consistency
//
+/// Load values from a byte-addressable buffer into a cooperative vector.
+/// @param buffer The source buffer to load data from.
+/// @param byteOffset16ByteAligned The byte offset from the start of the buffer. Must be 16-byte aligned.
+/// @return A new cooperative vector containing the loaded values.
[ForceInline]
[require(cooperative_vector)]
CoopVec<T, N> coopVecLoad<let N : int, T : __BuiltinArithmeticType>(ByteAddressBuffer buffer, int32_t byteOffset16ByteAligned = 0)
@@ -23318,6 +23435,9 @@ CoopVec<T, N> coopVecLoadGroupshared<let N : int, T : __BuiltinArithmeticType, l
// Coop Vector matrix multiplication
//
+/// Specifies the memory layout for matrices used in cooperative vector operations.
+/// @remarks This enum defines different matrix layout options that affect how matrix data is stored and accessed,
+/// including standard row-major and column-major layouts as well as specialized layouts optimized for specific operations.
enum CoopVecMatrixLayout
{
RowMajor,
@@ -23326,6 +23446,9 @@ enum CoopVecMatrixLayout
TrainingOptimal
};
+/// Specifies how to interpret the values in a cooperative vector or matrix.
+/// @remarks This enum defines the various data types that can be used for elements in cooperative vectors and matrices,
+/// including packed formats where multiple values can be stored in a single element.
enum CoopVecComponentType
{
FloatE4M3,
@@ -23528,6 +23651,20 @@ static const struct {
for(auto buffer : kByteAddressBufferCases_) {
}}}}
+/// Multiply a cooperative vector with a matrix and return the result.
+/// @param input The input cooperative vector to multiply with the matrix.
+/// @param inputInterpretation Specifies how to interpret the values in the input vector (e.g. as packed values).
+/// @param k The number of columns in the matrix.
+/// @param matrix The matrix buffer to multiply with.
+/// @param matrixOffset Byte offset into the matrix buffer.
+/// @param matrixInterpretation Specifies how to interpret the values in the matrix.
+/// @param memoryLayout Specifies the memory layout of the matrix (row-major or column-major).
+/// @param transpose Whether to transpose the matrix before multiplication.
+/// @param matrixStride The stride between matrix rows/columns in bytes.
+/// @return A new cooperative vector containing the result of the matrix multiplication.
+/// @remarks Unlike coopVecMatMul, this function supports packed input interpretations where multiple values
+/// can be packed into each element of the input vector. The k parameter specifies the actual number of
+/// values to use from the packed input.
// TODO: Can we ForceInline for just hlsl? the other platforms don't really
// need it
[ForceInline]
@@ -23663,6 +23800,16 @@ CoopVec<T, M> coopVecMatMulPacked(
}
}
+/// Multiply a cooperative vector with a matrix.
+/// @param input The input cooperative vector to multiply with the matrix.
+/// @param inputInterpretation Specifies how to interpret the values in the input vector (e.g. as 8-bit integers, 16-bit floats, etc).
+/// @param matrix The matrix to multiply with the input vector.
+/// @param matrixOffset Byte offset into the matrix buffer.
+/// @param matrixInterpretation Specifies how to interpret the values in the matrix (e.g. as 8-bit integers, 16-bit floats, etc).
+/// @param memoryLayout Specifies the memory layout of the matrix (row-major or column-major).
+/// @param transpose Whether to transpose the matrix before multiplication.
+/// @param matrixStride The stride in bytes between rows/columns of the matrix.
+/// @return A new cooperative vector containing the result of the matrix multiplication.
[ForceInline]
[require(cooperative_vector)]
__generic<T : __BuiltinArithmeticType, let M : int, let K : int, U : __BuiltinArithmeticType>
@@ -23692,6 +23839,23 @@ CoopVec<T, M> coopVecMatMul(
matrixStride);
}
+/// Multiply a cooperative vector with a matrix and add a bias vector.
+/// @param input The input cooperative vector to multiply with the matrix.
+/// @param inputInterpretation Specifies how to interpret the values in the input vector (e.g. as packed values).
+/// @param k The number of columns in the matrix.
+/// @param matrix The matrix buffer to multiply with.
+/// @param matrixOffset Byte offset into the matrix buffer.
+/// @param matrixInterpretation Specifies how to interpret the values in the matrix.
+/// @param bias The bias buffer to add after multiplication.
+/// @param biasOffset Byte offset into the bias buffer.
+/// @param biasInterpretation Specifies how to interpret the values in the bias vector.
+/// @param memoryLayout Specifies the memory layout of the matrix (row-major or column-major).
+/// @param transpose Whether to transpose the matrix before multiplication.
+/// @param matrixStride The stride between matrix rows/columns in bytes.
+/// @return A new cooperative vector containing the result of the matrix multiplication with added bias.
+/// @remarks Unlike coopVecMatMulAdd, this function supports packed input interpretations where multiple values
+/// can be packed into each element of the input vector. The k parameter specifies the actual number of
+/// values to use from the packed input.
[ForceInline]
[require(cooperative_vector)]
CoopVec<T, M> coopVecMatMulAddPacked<T : __BuiltinArithmeticType, let M : int, let PackedK : int, U : __BuiltinArithmeticType>(
@@ -23816,6 +23980,19 @@ CoopVec<T, M> coopVecMatMulAddPacked<T : __BuiltinArithmeticType, let M : int, l
}
}
+/// Multiply a cooperative vector with a matrix and add a bias vector.
+/// @param input The input cooperative vector to multiply with the matrix.
+/// @param inputInterpretation Specifies how to interpret the values in the input vector.
+/// @param matrix The matrix buffer to multiply with.
+/// @param matrixOffset Byte offset into the matrix buffer.
+/// @param matrixInterpretation Specifies how to interpret the values in the matrix.
+/// @param bias The bias buffer to add after multiplication.
+/// @param biasOffset Byte offset into the bias buffer.
+/// @param biasInterpretation Specifies how to interpret the values in the bias vector.
+/// @param memoryLayout Specifies the memory layout of the matrix (row-major or column-major).
+/// @param transpose Whether to transpose the matrix before multiplication.
+/// @param matrixStride The stride between matrix rows/columns in bytes.
+/// @return A new cooperative vector containing the result of the matrix multiplication plus bias.
[ForceInline]
[require(cooperative_vector)]
__generic<T : __BuiltinArithmeticType, let M : int, let K : int, U : __BuiltinArithmeticType>
@@ -23859,6 +24036,14 @@ ${{{{
if(buffer.isRW)
{
}}}}
+/// Accumulate the outer product of two cooperative vectors into a matrix.
+/// @param a The first cooperative vector.
+/// @param b The second cooperative vector.
+/// @param matrix The matrix buffer to accumulate the result into.
+/// @param matrixOffset Byte offset into the matrix buffer.
+/// @param matrixStride The stride between matrix rows/columns in bytes.
+/// @param memoryLayout Specifies the memory layout of the matrix (row-major or column-major).
+/// @param matrixInterpretation Specifies how to interpret the values in the matrix.
[require(cooperative_vector)]
void coopVecOuterProductAccumulate<T : __BuiltinArithmeticType, let M : int, let N : int>(
CoopVec<T, M> a,
@@ -23940,6 +24125,10 @@ void coopVecOuterProductAccumulate<T : __BuiltinArithmeticType, let M : int, let
}
}
+/// Accumulate the sum of a cooperative vector into a buffer at the specified offset.
+/// @param v The cooperative vector to sum.
+/// @param buffer The buffer to accumulate the sum into.
+/// @param offset Byte offset into the buffer.
[require(cooperative_vector)]
void coopVecReduceSumAccumulate<T : __BuiltinArithmeticType, let N : int>(
CoopVec<T, N> v,
@@ -23986,6 +24175,20 @@ static const struct {
for(auto buffer : kStructuredBufferCases_) {
}}}}
+/// Multiply a cooperative vector with a matrix and return the result.
+/// @param input The input cooperative vector to multiply with the matrix.
+/// @param inputInterpretation Specifies how to interpret the values in the input vector (e.g. as packed values).
+/// @param k The number of columns in the matrix.
+/// @param matrix The matrix buffer to multiply with.
+/// @param matrixOffset Byte offset into the matrix buffer.
+/// @param matrixInterpretation Specifies how to interpret the values in the matrix.
+/// @param memoryLayout Specifies the memory layout of the matrix (row-major or column-major).
+/// @param transpose Whether to transpose the matrix before multiplication.
+/// @param matrixStride The stride between matrix rows/columns in bytes.
+/// @return A new cooperative vector containing the result of the matrix multiplication.
+/// @remarks Unlike coopVecMatMul, this function supports packed input interpretations where multiple values
+/// can be packed into each element of the input vector. The k parameter specifies the actual number of
+/// values to use from the packed input.
[require(spirv, cooperative_vector)]
__generic<T : __BuiltinArithmeticType, let M : int, let PackedK : int, U : __BuiltinArithmeticType,IgnoredBufferElementType>
CoopVec<T, M> coopVecMatMulPacked(
@@ -24288,7 +24491,7 @@ uint32_t4 unpack_u8u32(uint8_t4_packed packed)
return unpackUint4x8ToUint32(packed);
}
-/// Pack a vector of 4 unsigned 32 bit integers into a packed value of 4 8-bit integers, dropping unused bits.
+/// Pack a vector of 4 unsigned 32/16 bit integers into a packed value of 4 8-bit integers, dropping unused bits.
[__readNone]
[ForceInline]
[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, pack_vector)]
@@ -24297,7 +24500,7 @@ uint8_t4_packed pack_u8(uint32_t4 unpackedValue)
return packUint4x8(unpackedValue);
}
-/// Pack a vector of 4 signed 32 bit integers into a packed value of 4 8-bit integers, dropping unused bits.
+/// Pack a vector of 4 signed 32/16 bit integers into a packed value of 4 8-bit integers, dropping unused bits.
[__readNone]
[ForceInline]
[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, pack_vector)]
@@ -24306,7 +24509,6 @@ int8_t4_packed pack_s8(int32_t4 unpackedValue)
return packInt4x8(unpackedValue);
}
-/// Pack a vector of 4 unsigned 16 bit integers into a packed value of 4 8-bit integers, dropping unused bits.
[__readNone]
[ForceInline]
[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, pack_vector)]
@@ -24315,7 +24517,6 @@ uint8_t4_packed pack_u8(uint16_t4 unpackedValue)
return packUint4x8(unpackedValue);
}
-/// Pack a vector of 4 signed 16 bit integers into a packed value of 4 8-bit integers, dropping unused bits.
[__readNone]
[ForceInline]
[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, pack_vector)]
@@ -24324,7 +24525,7 @@ int8_t4_packed pack_s8(int16_t4 unpackedValue)
return packInt4x8(unpackedValue);
}
-/// Pack a vector of 4 unsigned 32 bit integers into a packed value of 4 8-bit integers,
+/// Pack a vector of 4 unsigned 32/16 bit integers into a packed value of 4 8-bit integers,
/// clamping each value to the range [0, 255] to ensure it fits within 8 bits.
[__readNone]
[ForceInline]
@@ -24334,7 +24535,7 @@ uint8_t4_packed pack_clamp_u8(int32_t4 unpackedValue)
return packUint4x8Clamp(unpackedValue);
}
-/// Pack a vector of 4 signed 32 bit integers into a packed value of 4 8-bit integers,
+/// Pack a vector of 4 signed 32/16 bit integers into a packed value of 4 8-bit integers,
/// clamping each value to the range [-128, 127] to ensure it fits within 8 bits.
[__readNone]
[ForceInline]
@@ -24344,8 +24545,6 @@ int8_t4_packed pack_clamp_s8(int32_t4 unpackedValue)
return packInt4x8Clamp(unpackedValue);
}
-/// Pack a vector of 4 unsigned 16 bit integers into a packed value of 4 8-bit integers,
-/// clamping each value to the range [0, 255] to ensure it fits within 8 bits.
[__readNone]
[ForceInline]
[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, pack_vector)]
@@ -24354,8 +24553,6 @@ uint8_t4_packed pack_clamp_u8(int16_t4 unpackedValue)
return packUint4x8Clamp(unpackedValue);
}
-/// Pack a vector of 4 signed 16 bit integers into a packed value of 4 8-bit integers,
-/// clamping each value to the range [-128, 127] to ensure it fits within 8 bits.
[__readNone]
[ForceInline]
[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, pack_vector)]
@@ -24570,7 +24767,7 @@ int16_t4 unpackInt4x8ToInt16(uint packedValue)
}
}
-/// Pack a vector of 4 unsigned 32 bit integers into a packed value of 4 8-bit integers, dropping unused bits.
+/// Pack a vector of 4 unsigned 32/16 bit integers into a packed value of 4 8-bit integers, dropping unused bits.
[__readNone]
[ForceInline]
[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, pack_vector)]
@@ -24588,7 +24785,6 @@ uint packUint4x8(uint32_t4 unpackedValue)
}
}
-/// Pack a vector of 4 unsigned 16 bit integers into a packed value of 4 8-bit integers, dropping unused bits.
[__readNone]
[ForceInline]
[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, pack_vector)]
@@ -24602,7 +24798,7 @@ uint packUint4x8(uint16_t4 unpackedValue)
}
}
-/// Pack a vector of 4 signed 32 bit integers into a packed value of 4 8-bit integers, dropping unused bits.
+/// Pack a vector of 4 signed 32/16 bit integers into a packed value of 4 8-bit integers, dropping unused bits.
[__readNone]
[ForceInline]
[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, pack_vector)]
@@ -24617,7 +24813,6 @@ uint packInt4x8(int32_t4 unpackedValue)
}
}
-/// Pack a vector of 4 signed 16 bit integers into a packed value of 4 8-bit integers, dropping unused bits.
[__readNone]
[ForceInline]
[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, pack_vector)]
@@ -24631,8 +24826,8 @@ uint packInt4x8(int16_t4 unpackedValue)
}
}
-/// Pack a vector of 4 signed 32 bit integers into a packed value of 4 8-bit integers,
-/// clamping each value to the range [-128, 127] to ensure it fits within 8 bits.
+/// Pack a vector of 4 signed 32/16 bit integers into a packed value of 4 8-bit integers,
+/// clamping each value to the range [0, 255] to ensure it fits within 8 bits.
[__readNone]
[ForceInline]
[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, pack_vector)]
@@ -24647,8 +24842,6 @@ uint packUint4x8Clamp(int32_t4 unpackedValue)
}
}
-/// Pack a vector of 4 unsigned 16 bit integers into a packed value of 4 8-bit integers,
-/// clamping each value to the range [0, 255] to ensure it fits within 8 bits.
[__readNone]
[ForceInline]
[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, pack_vector)]
@@ -24662,7 +24855,7 @@ uint packUint4x8Clamp(int16_t4 unpackedValue)
}
}
-/// Pack a vector of 4 signed 32 bit integers into a packed value of 4 8-bit integers,
+/// Pack a vector of 4 signed 32/16 bit integers into a packed value of 4 8-bit integers,
/// clamping each value to the range [-128, 127] to ensure it fits within 8 bits.
[__readNone]
[ForceInline]
@@ -24678,8 +24871,6 @@ uint packInt4x8Clamp(int32_t4 unpackedValue)
}
}
-/// Pack a vector of 4 signed 16 bit integers into a packed value of 4 8-bit integers,
-/// clamping each value to the range [-128, 127] to ensure it fits within 8 bits.
[__readNone]
[ForceInline]
[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, pack_vector)]
@@ -24893,7 +25084,7 @@ half2 unpackHalf2x16ToHalf(uint packedValue)
return half2(unpackHalf2x16ToFloat(packedValue));
}
-/// Convert a 4-component vector of normalized unsigned single-precision floating-point
+/// Convert a 4-component vector of normalized unsigned single/half-precision floating-point
/// values to four 8-bit integer values, then pack these 8-bit values into a
/// 32-bit unsigned integer.
[__readNone]
@@ -24917,9 +25108,6 @@ uint packUnorm4x8(float4 unpackedValue)
}
}
-/// Convert a 4-component vector of normalized unsigned half-precision floating-point
-/// values to four 8-bit integer values, then pack these 8-bit values into a
-/// 32-bit unsigned integer.
[__readNone]
[ForceInline]
[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, pack_vector)]
@@ -24933,7 +25121,7 @@ uint packUnorm4x8(half4 unpackedValue)
}
}
-/// Convert a 4-component vector of normalized signed single-precision floating-point
+/// Convert a 4-component vector of normalized signed single/half-precision floating-point
/// values to four 8-bit integer values, then pack these 8-bit values into a
/// 32-bit unsigned integer.
[__readNone]
@@ -24957,9 +25145,6 @@ uint packSnorm4x8(float4 unpackedValue)
}
}
-/// Convert a 4-component vector of normalized signed half-precision floating-point
-/// values to four 8-bit integer values, then pack these 8-bit values into a
-/// 32-bit unsigned integer.
[__readNone]
[ForceInline]
[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, pack_vector)]
@@ -24973,7 +25158,7 @@ uint packSnorm4x8(half4 unpackedValue)
}
}
-/// Convert a 2-component vector of normalized unsigned single-precision floating-point
+/// Convert a 2-component vector of normalized unsigned single/half-precision floating-point
/// values to two 16-bit integer values, then pack these 16-bit values into a
/// 32-bit unsigned integer.
[__readNone]
@@ -24997,9 +25182,6 @@ uint packUnorm2x16(float2 unpackedValue)
}
}
-/// Convert a 2-component vector of normalized unsigned half-precision floating-point
-/// values to two 16-bit integer values, then pack these 16-bit values into a
-/// 32-bit unsigned integer.
[__readNone]
[ForceInline]
[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, pack_vector)]
@@ -25013,7 +25195,7 @@ uint packUnorm2x16(half2 unpackedValue)
}
}
-/// Convert a 2-component vector of normalized signed single-precision floating-point
+/// Convert a 2-component vector of normalized signed single/half-precision floating-point
/// values to two 16-bit integer values, then pack these 16-bit values into a
/// 32-bit unsigned integer.
[__readNone]
@@ -25037,9 +25219,6 @@ uint packSnorm2x16(float2 unpackedValue)
}
}
-/// Convert a 2-component vector of normalized signed half-precision floating-point
-/// values to two 16-bit integer values, then pack these 16-bit values into a
-/// 32-bit unsigned integer.
[__readNone]
[ForceInline]
[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, pack_vector)]
@@ -25053,7 +25232,7 @@ uint packSnorm2x16(half2 unpackedValue)
}
}
-/// Convert a 2-component vector of IEEE-754 binary16 single-precision floating-point
+/// Convert a 2-component vector of IEEE-754 binary16 single/half-precision floating-point
/// values to two 16-bit integer values, then pack these 16-bit values into a
/// 32-bit unsigned integer.
[__readNone]
@@ -25076,9 +25255,6 @@ uint packHalf2x16(float2 unpackedValue)
}
}
-/// Convert a 2-component vector of IEEE-754 binary16 half-precision floating-point
-/// values to two 16-bit integer values, then pack these 16-bit values into a
-/// 32-bit unsigned integer.
[__readNone]
[ForceInline]
[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, pack_vector)]