diff options
| author | Jay Kwak <82421531+jkwak-work@users.noreply.github.com> | 2025-01-30 00:59:49 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-01-30 00:59:49 -0800 |
| commit | ba9b2785c69c1b8c6d2b4103267b5281815f9f23 (patch) | |
| tree | e4ba4ca76c6592b90764a0a7ac32502639dc93aa /source | |
| parent | 2ae194d51e15c064c3d905e628f7335de7504e32 (diff) | |
Support cooperative vector (#6223)
* Support cooperative vector without Vulkan-header update
Adding a Slang support for cooperative vector.
But this commit doesn't have Vulkan-header update.
Diffstat (limited to 'source')
30 files changed, 3074 insertions, 23 deletions
diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang index 2224b8e82..d4cce037d 100644 --- a/source/slang/core.meta.slang +++ b/source/slang/core.meta.slang @@ -3234,6 +3234,18 @@ int __alignOf_intrinsic<T>() return __alignOf_intrinsic_impl<T>(__default<T>()); } +[__unsafeForceInlineEarly] +int32_t __elemToByteOffset<T>(int32_t elemOffset) +{ + return elemOffset * __naturalStrideOf<T>(); +} + +[__unsafeForceInlineEarly] +int32_t __byteToElemOffset<T>(int32_t byteOffset) +{ + return byteOffset / __naturalStrideOf<T>(); +} + __intrinsic_op($(kIROp_TreatAsDynamicUniform)) T asDynamicUniform<T>(T v); diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang index 1853a82b6..ecab7ff93 100644 --- a/source/slang/hlsl.meta.slang +++ b/source/slang/hlsl.meta.slang @@ -4530,6 +4530,22 @@ __intrinsic_op($(kIROp_ByteAddressBufferStore)) [require(cpp_cuda_glsl_hlsl_metal_spirv, byteaddressbuffer_rw)] void __byteAddressBufferStore<T>(RasterizerOrderedByteAddressBuffer buffer, int offset, int alignment, T value); +__intrinsic_op($(kIROp_GetUntypedBufferPtr)) +[require(spirv, byteaddressbuffer)] +Ptr<uint[]> __getByteAddressBufferPtr(ByteAddressBuffer buffer); + +__intrinsic_op($(kIROp_GetUntypedBufferPtr)) +[require(spirv, byteaddressbuffer_rw)] +Ptr<uint[]> __getByteAddressBufferPtr(RWByteAddressBuffer buffer); + +__intrinsic_op($(kIROp_GetStructuredBufferPtr)) +[require(spirv, structuredbuffer)] +Ptr<T[]> __getStructuredBufferPtr<T>(StructuredBuffer<T> buffer); + +__intrinsic_op($(kIROp_GetStructuredBufferPtr)) +[require(spirv, structuredbuffer_rw)] +Ptr<T[]> __getStructuredBufferPtr<T>(RWStructuredBuffer<T> buffer); + /** Represents an opaque handle to a read-only structured buffer allocated in global memory. A structured buffer can be viewed as an array of the specified element type. 
@@ -12414,10 +12430,50 @@ __generic<T:__BuiltinArithmeticType, U:__BuiltinArithmeticType> __intrinsic_op($(kIROp_IntCast)) T __int_cast(U val); +__generic<T:__BuiltinArithmeticType, U:__BuiltinArithmeticType> +__intrinsic_op($(kIROp_FloatCast)) +T __real_cast(U val); + +__generic<T:__BuiltinArithmeticType, U:__BuiltinArithmeticType> +__intrinsic_op($(kIROp_CastIntToFloat)) +T __int_to_float_cast(U val); + +__generic<T:__BuiltinArithmeticType, U:__BuiltinArithmeticType> +__intrinsic_op($(kIROp_CastFloatToInt)) +T __float_to_int_cast(U val); + +__generic<T:__BuiltinArithmeticType, U:__BuiltinArithmeticType> +[__unsafeForceInlineEarly] +T __arithmetic_cast(U val) +{ + if (__isFloat<T>() && __isInt<U>()) + return __int_to_float_cast<T>(val); + else if (__isInt<T>() && __isFloat<U>()) + return __float_to_int_cast<T>(val); + else if (__isFloat<T>() && __isFloat<U>()) + return __real_cast<T>(val); + else if (__isInt<T>() && __isInt<U>()) + return __int_cast<T>(val); + return T(0); +} + __generic<T:__BuiltinArithmeticType, U:__BuiltinArithmeticType, let N : int> __intrinsic_op($(kIROp_IntCast)) vector<T,N> __int_cast(vector<U,N> val); +__generic<T:__BuiltinArithmeticType, U:__BuiltinArithmeticType, let N : int> +__intrinsic_op($(kIROp_FloatCast)) +vector<T,N> __real_cast(vector<U,N> val); + +__generic<T:__BuiltinArithmeticType, U:__BuiltinArithmeticType, let N : int> +__intrinsic_op($(kIROp_CastIntToFloat)) +vector<T,N> __int_to_float_cast(vector<U,N> val); + +__generic<T:__BuiltinArithmeticType, U:__BuiltinArithmeticType, let N : int> +__intrinsic_op($(kIROp_CastFloatToInt)) +vector<T,N> __float_to_int_cast(vector<U,N> val); + + /// Extract sign of value. /// @param x The value to extract the sign of. /// @return -1 if `x` is negative, 0 if `x` is zero, and 1 if `x` is positive. @@ -21477,6 +21533,7 @@ ${{{{ } }}}} + /// Represents a bindless handle to a descriptor. A descriptor handle is always an ordinary data type and can be /// declared in any memory location. 
/// @remarks Opaque descriptor types such as textures(`Texture2D` etc.), `SamplerState` and buffers (e.g. `StructuredBuffer`) @@ -21579,8 +21636,6 @@ extern T getDescriptorFromHandle<T:IOpaqueDescriptor>(DescriptorHandle<T> handle __intrinsic_op($(kIROp_NonUniformResourceIndex)) DescriptorHandle<T> nonuniform<T:IOpaqueDescriptor>(DescriptorHandle<T> ptr); -//@hidden: - __glsl_version(450) __glsl_extension(GL_ARB_shader_clock) [require(glsl_spirv, GL_ARB_shader_clock)] @@ -21633,6 +21688,2216 @@ extension<T, L : IBufferDataLayout> RasterizerOrderedStructuredBuffer<T, L> : IR int getCount() { uint count; uint stride; this.GetDimensions(count, stride); return count; } } +// +// Cooperative Vector +// + +__intrinsic_type($(kIROp_CoopVectorType)) +[require(cooperative_vector)] +struct CoopVec<T : __BuiltinArithmeticType, let N : int> : IArray<T>, IArithmetic +{ + // + // Initialization + // + + [ForceInline] + [require(cooperative_vector)] + __init() + { + this = CoopVec<T, N>(T(0)); + } + [ForceInline] + [require(cooperative_vector)] + __init(T t) + { + this.fill(t); + } + [ForceInline] + [require(cooperative_vector)] + __init<U : __BuiltinArithmeticType>(CoopVec<U, N> other) + { + this.copyFrom(other); + } + + [ForceInline] + [require(cooperative_vector)] + __init<each U : __BuiltinArithmeticType>(expand each U args) + { + static_assert(countof(U) == N, "number of arguments to CoopVec constructor must match number of elements"); + this = __makeCoopVec<T, N>(expand (__arithmetic_cast<T>(each args))); + } + [OverloadRank(-10)] + [ForceInline] + __init(int i) + { + this = CoopVec<T, N>(T(i)); + } + [ForceInline] + __init(This x) + { + this = x; + } + + // + // Simple setters + // + + [require(hlsl)] + [mutating] + [ForceInline] + void copyFrom<U : __BuiltinArithmeticType>(CoopVec<U,N> other) + { + __target_switch + { + case hlsl: __intrinsic_asm ".CopyFrom"; + default: + if (__isFloat<T>() && __isInt<U>()) + this = __int_to_float_cast<T>(other); + else if (__isInt<T>() 
&& __isFloat<U>()) + this = __float_to_int_cast<T>(other); + else if (__isFloat<T>() && __isFloat<U>()) + this = __real_cast<T>(other); + else if (__isInt<T>() && __isInt<U>()) + this = __int_cast<T>(other); + } + } + + [require(cooperative_vector)] + [mutating] + [ForceInline] + void fill(T t) + { + __target_switch + { + case spirv: + this = spirv_asm { + OpExtension "SPV_EXT_replicated_composites"; + OpCapability ReplicatedCompositesEXT; + result:$$CoopVec<T, N> = OpCompositeConstructReplicateEXT $t; + }; + case hlsl: + case hlsl: __intrinsic_asm ".Fill"; + default: + for(int i = 0; i < N; ++i) + this[i] = t; + } + } + + // + // Loading and storing + // + + [ForceInline] + [require(cooperative_vector)] + void store(RWByteAddressBuffer buffer, int32_t byteOffset16ByteAligned = 0) + { + __target_switch + { + case spirv: + let ptr = buffer.GetBufferPointer(); + spirv_asm + { + // TODO: Should this be a byte offset + OpCooperativeVectorStoreNV $ptr $byteOffset16ByteAligned $this None; + }; + // Not supported + // case hlsl: + // this.__Store(buffer, byteOffset16ByteAligned); + default: + for(int i = 0; i < N; ++i) + buffer.StoreByteOffset(byteOffset16ByteAligned + __elemToByteOffset<T>(i), this[i]); + } + } + + [ForceInline] + [require(cooperative_vector)] + void store(RWStructuredBuffer<T> buffer, int32_t byteOffset16ByteAligned = 0) + { + __target_switch + { + case spirv: + let ptr = __getStructuredBufferPtr(buffer); + spirv_asm + { + // TODO: Should this be a byte offset + OpCooperativeVectorStoreNV $ptr $byteOffset16ByteAligned $this None; + }; + // Not supported + // case hlsl: + // this.__Store(buffer, byteOffset16ByteAligned); + default: + for(int i = 0; i < N; ++i) + buffer[i + __byteToElemOffset<T>(byteOffset16ByteAligned)] = this[i]; + } + } + + [ForceInline] + [require(cooperative_vector)] + void store<let M : int>(__ref groupshared T[M] data, int32_t byteOffset16ByteAligned = 0) + { + __target_switch + { + case spirv: + spirv_asm{ + 
OpCooperativeVectorStoreNV &data $byteOffset16ByteAligned $this None; + }; + case hlsl: + this.__Store(data, __byteToElemOffset<T>(byteOffset16ByteAligned)); + default: + for(int i = 0; i < N; ++i) + data[i + __byteToElemOffset<T>(byteOffset16ByteAligned)] = this[i]; + } + } + + /// spirv only storing to a groupshared array of any type + [ForceInline] + [require(spirv, cooperative_vector)] + void storeAny<U, let M : int>(__ref groupshared U[M] data, int32_t byteOffset16ByteAligned = 0) + { + __target_switch + { + case spirv: + spirv_asm{ + OpCooperativeVectorStoreNV &data $byteOffset16ByteAligned $this None; + }; + } + } + [ForceInline] + [require(spirv, cooperative_vector)] + void storeAny<U, let M : int, let L : int>(__ref groupshared vector<U, L>[M] data, int32_t byteOffset16ByteAligned = 0) + { + __target_switch + { + case spirv: + spirv_asm{ + OpCooperativeVectorStoreNV &data $byteOffset16ByteAligned $this None; + }; + } + } + + [ForceInline] + [__NoSideEffect] + [require(cooperative_vector)] + static CoopVec<T, N> load(ByteAddressBuffer buffer, int32_t byteOffset16ByteAligned = 0) + { + __target_switch + { + case spirv: + let ptr = buffer.GetBufferPointer(); + return spirv_asm + { + result:$$CoopVec<T, N> = OpCooperativeVectorLoadNV $ptr $byteOffset16ByteAligned None; + }; + case hlsl: + CoopVec<T, N> ret; + ret.__Load(buffer, byteOffset16ByteAligned); + return ret; + default: + var vec = CoopVec<T, N>(); + for(int i = 0; i < N; ++i) + vec[i] = buffer.LoadByteOffset<T>(byteOffset16ByteAligned + __elemToByteOffset<T>(i)); + return vec; + } + return CoopVec<T, N>(); + } + + [ForceInline] + [__NoSideEffect] + [require(cooperative_vector)] + static CoopVec<T, N> load(RWByteAddressBuffer buffer, int32_t byteOffset16ByteAligned = 0) + { + __target_switch + { + case spirv: + let ptr = buffer.GetBufferPointer(); + return spirv_asm + { + result:$$CoopVec<T, N> = OpCooperativeVectorLoadNV $ptr $byteOffset16ByteAligned None; + }; + case hlsl: + CoopVec<T, N> ret; + 
ret.__Load(buffer, byteOffset16ByteAligned); + return ret; + default: + var vec = CoopVec<T, N>(); + for(int i = 0; i < N; ++i) + vec[i] = buffer.LoadByteOffset<T>(byteOffset16ByteAligned + __elemToByteOffset<T>(i)); + return vec; + } + return CoopVec<T, N>(); + } + + [ForceInline] + [__NoSideEffect] + [require(cooperative_vector)] + static CoopVec<T, N> load(StructuredBuffer<T> buffer, int32_t byteOffset16ByteAligned = 0) + { + __target_switch + { + case spirv: + let ptr = __getStructuredBufferPtr(buffer); + return spirv_asm + { + result:$$CoopVec<T, N> = OpCooperativeVectorLoadNV $ptr $byteOffset16ByteAligned None; + }; + // Not supported + // case hlsl: + // CoopVec<T, N> ret; + // ret.__Load(buffer, byteOffset16ByteAligned); + // return ret; + default: + var vec = CoopVec<T, N>(); + for(int i = 0; i < N; ++i) + vec[i] = buffer[__byteToElemOffset<T>(byteOffset16ByteAligned) + i]; + return vec; + } + return CoopVec<T, N>(); + } + + [ForceInline] + [__NoSideEffect] + [require(spirv, cooperative_vector)] + static CoopVec<T, N> load(RWStructuredBuffer<T> buffer, int32_t byteOffset16ByteAligned = 0) + { + __target_switch + { + case spirv: + let ptr = __getStructuredBufferPtr(buffer); + return spirv_asm + { + result:$$CoopVec<T, N> = OpCooperativeVectorLoadNV $ptr $byteOffset16ByteAligned None; + }; + // Not supported + // case hlsl: + // CoopVec<T, N> ret; + // ret.__Load(buffer, byteOffset16ByteAligned); + // return ret; + default: + var vec = CoopVec<T, N>(); + for(int i = 0; i < N; ++i) + vec[i] = buffer[__byteToElemOffset<T>(byteOffset16ByteAligned) + i]; + return vec; + } + } + + // Groupshared + [ForceInline] + [__NoSideEffect] + [require(cooperative_vector)] + static CoopVec<T, N> load<let M : int>(__constref groupshared const T[M] data, int32_t byteOffset16ByteAligned = 0) + { + __target_switch + { + case spirv: + return spirv_asm{ + result:$$CoopVec<T, N> = OpCooperativeVectorLoadNV &data $byteOffset16ByteAligned None + }; + case hlsl: + CoopVec<T, N> ret; + 
ret.__Load(data, __byteToElemOffset<T>(byteOffset16ByteAligned)); + return ret; + default: + CoopVec<T,N> result; + for(int i = 0; i < N; ++i) + result[i] = data[i + __byteToElemOffset<T>(byteOffset16ByteAligned)]; + return result; + } + } + + /// spirv only loading from a groupshared array of any type + [ForceInline] + [__NoSideEffect] + [require(spirv, cooperative_vector)] + static CoopVec<T, N> loadAny<U : __BuiltinArithmeticType, let M : int>(__constref groupshared const U[M] data, int32_t byteOffset16ByteAligned = 0) + { + __target_switch + { + case spirv: + return spirv_asm{ + result:$$CoopVec<T, N> = OpCooperativeVectorLoadNV &data $byteOffset16ByteAligned None + }; + } + } + [ForceInline] + [__NoSideEffect] + [require(spirv, cooperative_vector)] + static CoopVec<T, N> loadAny<U : __BuiltinArithmeticType, let M : int, let L : int>(__constref groupshared const vector<U, L>[M] data, int32_t byteOffset16ByteAligned = 0) + { + __target_switch + { + case spirv: + return spirv_asm{ + result:$$CoopVec<T, N> = OpCooperativeVectorLoadNV &data $byteOffset16ByteAligned None + }; + } + } + + // + // Subscript + // + + __intrinsic_op($(kIROp_GetElement)) + [__NoSideEffect] + T __indexRead(int index); + + __intrinsic_op($(kIROp_GetElementPtr)) + [__ref] + [__NoSideEffect] + Ref<T> __indexRef(int index); + + [ForceInline] + [__NoSideEffect] + int getCount() + { + return N; + } + + __subscript(int index) -> T + { + [__NoSideEffect] + [nonmutating] + get + { + __target_switch + { + case hlsl: __intrinsic_asm ".ReadFromIndex"; + default: return __indexRead(index); + } + } + + [mutating] + set + { + __target_switch + { + case hlsl: __intrinsic_asm ".WriteToIndex"; + default: __indexRef(index) = newValue; + } + } + + // Unavailable on HLSL + // The CoopVector HLSL spec says that indexing with a subscript + // operation can work, but dxc currently crashes with this + // __intrinsic_op($(kIROp_GetElementPtr)) + // [__ref] + // ref; + } + + static CoopVec<T, N> replicate(T t) + { 
+ CoopVec<T, N> ret; + ret.fill(t); + return ret; + } + + // + // Equality and ordering + // + + bool equals(This other) + { + for (int i = 0; i < N; i++) + { + if (this[i] != other[i]) + { + return false; + } + } + return true; + } + bool lessThan(This other) + { + for (int i = 0; i < N; i++) + { + if (this[i] < other[i]) + { + return true; + } + else if (this[i] > other[i]) + { + return false; + } + } + return false; + } + bool lessThanOrEquals(This other) + { + for (int i = 0; i < N; i++) + { + if (this[i] < other[i]) + { + return true; + } + else if (this[i] > other[i]) + { + return false; + } + } + return true; + } + + // + // Arithmetic + // + + __intrinsic_op($(kIROp_Add)) + This __pureAdd(This other); + + [mutating] + [require(hlsl)] + void __mutAdd(This other) + { + __target_switch + { + case hlsl: __intrinsic_asm ".Add"; + } + } + + // TODO: Why is this ForceInline necessary for hlsl, dxc bug? + [ForceInline] + This add(This other) + { + __target_switch + { + case hlsl: + This ret = this; + ret.__mutAdd(other); + return ret; + default: return __pureAdd(other); + } + } + + __intrinsic_op($(kIROp_Sub)) + This __pureSub(This other); + [mutating] + [require(hlsl)] + void __mutSub(This other) + { __target_switch { case hlsl: __intrinsic_asm ".Subtract"; } } + + [ForceInline] + This sub(This other) + { + __target_switch + { + case hlsl: + This ret = this; + ret.__mutSub(other); + return ret; + default: return __pureSub(other); + } + } + + __intrinsic_op($(kIROp_Mul)) + This __pureMul(This other); + + [mutating] + [require(hlsl)] + void __mutMul(This other) + { __target_switch { case hlsl: __intrinsic_asm ".Multiply"; } } + + [ForceInline] + This mul(This other) + { + __target_switch + { + case hlsl: + This ret = this; + ret.__mutMul(other); + return ret; + default: return __pureMul(other); + } + } + + __intrinsic_op($(kIROp_Div)) + This __pureDiv(This other); + + [mutating] + [require(hlsl)] + void __mutDiv(This other) + { __target_switch { case hlsl: 
__intrinsic_asm ".Divide"; } } + + [ForceInline] + This div(This other) + { + __target_switch + { + case hlsl: + This ret = this; + ret.__mutDiv(other); + return ret; + default: return __pureDiv(other); + } + } + + [mutating] + [require(hlsl)] + void __mutMod(This other) + { __target_switch { case hlsl: __intrinsic_asm ".Mod"; } } + + [ForceInline] + This mod(This other) + { + __target_switch + { + case hlsl: + This ret = this; + ret.__mutMod(other); + return ret; + default: + This ret; + for(int i = 0; i < N; ++i) + ret[i] = this[i] % other[i]; + return ret; + } + } + + __intrinsic_op($(kIROp_Neg)) + This __pureNeg(This other); + + //[ForceInline] + This neg() + { + __target_switch + { + case hlsl: + This ret = this; + for(int i = 0; i < N; ++i) + ret[i] = -this[i]; + return ret; + default: return __pureNeg(this); + } + } + + [mutating] + [require(hlsl)] + void __mutScalarMul(T t) + { __target_switch { case hlsl: __intrinsic_asm ".ScalarMultiply"; } } + + [mutating] + [require(hlsl)] + void __mutMin(This other) + { __target_switch { case hlsl: __intrinsic_asm ".Min"; } } + [mutating] + [require(hlsl)] + void __mutMax(This other) + { __target_switch { case hlsl: __intrinsic_asm ".Max"; } } + [mutating] + [require(hlsl)] + void __mutClamp(This minVal, This maxVal) + { __target_switch { case hlsl: __intrinsic_asm ".Clamp"; } } + + // + // Internal utilities for loading and storing + // + + [mutating] + [require(hlsl, byteaddressbuffer)] + void __Load(const ByteAddressBuffer buffer, uint byteOffset, uint alignment = 0) + { __target_switch { case hlsl: __intrinsic_asm ".Load"; } } + + [mutating] + [require(hlsl, byteaddressbuffer_rw)] + void __Load(const RWByteAddressBuffer buffer, uint byteOffset, uint alignment = 0) + { __target_switch { case hlsl: __intrinsic_asm ".Load"; } } + + __generic<let M : int> + [mutating] + // Careful, this takes the offset in elements + [require(hlsl)] + void __Load(__constref groupshared T buffer[M], uint elemOffset) + { __target_switch 
{ case hlsl: __intrinsic_asm ".Load"; } } + + [require(hlsl, byteaddressbuffer_rw)] + void __Store(RWByteAddressBuffer buffer, uint byteOffset, uint alignment = 0) + { __target_switch { case hlsl: __intrinsic_asm ".Store"; } } + + __generic<let M : int> + [require(hlsl)] + // Careful, this takes the offset in elements + void __Store(__ref groupshared T buffer[M], uint elemOffset) + { __target_switch { case hlsl: __intrinsic_asm ".Store"; } } + +${{{{ +static const struct { + bool isRW; + char const* type; +} kByteAddressBufferCases[] = +{ + {true, "RWByteAddressBuffer"}, + {false, "ByteAddressBuffer"} +}; +for(auto buffer : kByteAddressBufferCases) { +}}}} + [mutating] + [require(hlsl, byteaddressbuffer_rw)] + void __mutMatMul<U : __BuiltinArithmeticType, let K : int>( + CoopVec<U, K> input, uint inputInterpretationHLSL, + $(buffer.type) matrix, uint matrixOffset, uint matrixInterpretationHLSL, + uint m, uint k, uint memoryLayoutHLSL, bool transpose, uint matrixStride) + { + __target_switch + { + case hlsl: __intrinsic_asm ".MatMul"; + } + } + + [mutating] + [require(hlsl, byteaddressbuffer_rw)] + void __mutMatMulAdd<U : __BuiltinArithmeticType, let K : int>( + CoopVec<U, K> input, uint inputInterpretationHLSL, + $(buffer.type) matrix, uint matrixOffset, uint matrixInterpretationHLSL, + $(buffer.type) bias, uint biasOffset, uint biasInterpretationHLSL, + uint m, uint k, uint memoryLayoutHLSL, bool transpose, uint matrixStride) + { + __target_switch + { + case hlsl: __intrinsic_asm ".MatMulAdd"; + } + } + + [mutating] + [ForceInline] + void matMulAccumPacked<U : __BuiltinArithmeticType, let PackedK : int>( + CoopVec<U, PackedK> input, + constexpr CoopVecComponentType inputInterpretation, + constexpr int k, + $(buffer.type) matrix, + int32_t matrixOffset, + constexpr CoopVecComponentType matrixInterpretation, + constexpr CoopVecMatrixLayout memoryLayout, + constexpr bool transpose, + constexpr uint matrixStride + ) + { + 
static_assert(__isPackedInputInterpretation(inputInterpretation) || k == PackedK + , "for non-packed inputInterpretation values k must be equal to the input vector length"); + static_assert(!__isPackedInputInterpretation(inputInterpretation) + || k <= __inputInterpretationPackingFactor(inputInterpretation)*PackedK + , "for packed inputInterpretation values k must be less than or equal to the input vector length times the packing factor"); + + __target_switch + { + case hlsl: + let inputInterpretationHLSL = __getHLSLCoopVecComponentType(inputInterpretation); + let matrixInterpretationHLSL = __getHLSLCoopVecComponentType(matrixInterpretation); + let memoryLayoutHLSL = __getHLSLCoopVecMatrixLayout(memoryLayout); + This temp = this; + temp.__mutMatMul( + input, + inputInterpretationHLSL, + matrix, + matrixOffset, + matrixInterpretationHLSL, + N, + k, + memoryLayoutHLSL, + transpose, + matrixStride + ); + this.__mutAdd(temp); + default: this = this + coopVecMatMulPacked<T, N, PackedK, U>( + input, + inputInterpretation, + k, + matrix, + matrixOffset, + matrixInterpretation, + memoryLayout, + transpose, + matrixStride + ); + } + } + + [mutating] + [ForceInline] + void matMulAccum<U : __BuiltinArithmeticType, let K : int>( + CoopVec<U, K> input, + constexpr CoopVecComponentType inputInterpretation, + $(buffer.type) matrix, + int32_t matrixOffset, + constexpr CoopVecComponentType matrixInterpretation, + constexpr CoopVecMatrixLayout memoryLayout, + constexpr bool transpose, + constexpr uint matrixStride + ) + { + static_assert(!__isPackedInputInterpretation(inputInterpretation) + , "for packed inputInterpretation values please use coopVecMatMulPacked and specify k manually"); + this.matMulAccumPacked<U, K>( + input, + inputInterpretation, + K, + matrix, + matrixOffset, + matrixInterpretation, + memoryLayout, + transpose, + matrixStride + ); + } + + [mutating] + [ForceInline] + void matMulAddAccumPacked<U : __BuiltinArithmeticType, let PackedK : int>( + CoopVec<U, PackedK> 
input, + constexpr CoopVecComponentType inputInterpretation, + constexpr int k, + $(buffer.type) matrix, + int32_t matrixOffset, + constexpr CoopVecComponentType matrixInterpretation, + $(buffer.type) bias, + int32_t biasOffset, + constexpr CoopVecComponentType biasInterpretation, + constexpr CoopVecMatrixLayout memoryLayout, + constexpr bool transpose, + constexpr uint matrixStride + ) + { + static_assert(__isPackedInputInterpretation(inputInterpretation) || k == PackedK + , "for non-packed inputInterpretation values k must be equal to the input vector length"); + static_assert(!__isPackedInputInterpretation(inputInterpretation) + || k <= __inputInterpretationPackingFactor(inputInterpretation)*PackedK + , "for packed inputInterpretation values k must be less than or equal to the input vector length times the packing factor"); + + __target_switch + { + case hlsl: + let inputInterpretationHLSL = __getHLSLCoopVecComponentType(inputInterpretation); + let matrixInterpretationHLSL = __getHLSLCoopVecComponentType(matrixInterpretation); + let biasInterpretationHLSL = __getHLSLCoopVecComponentType(biasInterpretation); + let memoryLayoutHLSL = __getHLSLCoopVecMatrixLayout(memoryLayout); + This temp = this; + temp.__mutMatMulAdd( + input, + inputInterpretationHLSL, + matrix, + matrixOffset, + matrixInterpretationHLSL, + bias, + biasOffset, + biasInterpretationHLSL, + N, + k, + memoryLayoutHLSL, + transpose, + matrixStride + ); + this.__mutAdd(temp); + default: this = this + coopVecMatMulAddPacked<T, N, PackedK, U>( + input, + inputInterpretation, + k, + matrix, + matrixOffset, + matrixInterpretation, + bias, + biasOffset, + biasInterpretation, + memoryLayout, + transpose, + matrixStride + ); + } + } + + [mutating] + [ForceInline] + void matMulAddAccum<U : __BuiltinArithmeticType, let K : int>( + CoopVec<U, K> input, + constexpr CoopVecComponentType inputInterpretation, + $(buffer.type) matrix, + int32_t matrixOffset, + constexpr CoopVecComponentType matrixInterpretation, + 
$(buffer.type) bias, + int32_t biasOffset, + constexpr CoopVecComponentType biasInterpretation, + constexpr CoopVecMatrixLayout memoryLayout, + constexpr bool transpose, + constexpr uint matrixStride + ) + { + static_assert(!__isPackedInputInterpretation(inputInterpretation) + , "for packed inputInterpretation values please use coopVecMatMulPacked and specify k manually"); + this.matMulAddAccumPacked<U, K>( + input, + inputInterpretation, + K, + matrix, + matrixOffset, + matrixInterpretation, + bias, + biasOffset, + biasInterpretation, + memoryLayout, + transpose, + matrixStride + ); + } + + +${{{{ +} +}}}} +} + +__intrinsic_op($(kIROp_MakeCoopVectorFromValuePack)) +CoopVec<T, N> __makeCoopVec<T : __BuiltinArithmeticType, let N : int, each U>(expand each U args); + +__generic<T:__BuiltinArithmeticType, U:__BuiltinArithmeticType, let N : int> +__intrinsic_op($(kIROp_IntCast)) +[require(cooperative_vector)] +CoopVec<T,N> __int_cast(CoopVec<U,N> val); + +__generic<T:__BuiltinArithmeticType, U:__BuiltinArithmeticType, let N : int> +__intrinsic_op($(kIROp_FloatCast)) +[require(cooperative_vector)] +CoopVec<T,N> __real_cast(CoopVec<U,N> val); + +__generic<T:__BuiltinArithmeticType, U:__BuiltinArithmeticType, let N : int> +__intrinsic_op($(kIROp_CastIntToFloat)) +[require(cooperative_vector)] +CoopVec<T,N> __int_to_float_cast(CoopVec<U,N> val); + +__generic<T:__BuiltinArithmeticType, U:__BuiltinArithmeticType, let N : int> +__intrinsic_op($(kIROp_CastFloatToInt)) +[require(cooperative_vector)] +CoopVec<T,N> __float_to_int_cast(CoopVec<U,N> val); + +__generic<T : __BuiltinArithmeticType, let N : int> +[ForceInline] +[require(cooperative_vector)] +CoopVec<T, N> operator *(CoopVec<T, N> lhs, const T rhs) +{ + __target_switch + { + case spirv: + if (__isFloat<T>()) + { + return spirv_asm + { + result:$$CoopVec<T, N> = OpVectorTimesScalar $lhs $rhs; + }; + } + else + { + for (int i = 0; i < N; ++i) + { + lhs[i] *= rhs; + } + return lhs; + } + case hlsl: + CoopVec<T, N> ret = 
lhs; + ret.__mutScalarMul(rhs); + return ret; + default: + for (int i = 0; i < N; ++i) + { + lhs[i] *= rhs; + } + return lhs; + } +} + +__generic<T : __BuiltinArithmeticType, let N : int> +[ForceInline] +[require(cooperative_vector)] +CoopVec<T, N> operator *(const T lhs, CoopVec<T, N> rhs) +{ + return rhs * lhs; +} + +[ForceInline] +[require(cooperative_vector)] +CoopVec<T, N> min<T : __BuiltinFloatingPointType, let N : int>(CoopVec<T, N> x, CoopVec<T, N> y) +{ + __target_switch + { + case spirv: + return spirv_asm + { + result:$$CoopVec<T, N> = OpExtInst glsl450 FMin $x $y; + }; + case hlsl: + CoopVec<T, N> ret = x; + ret.__mutMin(y); + return ret; + default: + CoopVec<T, N> ret; + for(int i = 0; i < N; ++i) + ret[i] = min(x[i], y[i]); + + return ret; + } +} + +[ForceInline] +[require(cooperative_vector)] +CoopVec<T, N> max<T : __BuiltinFloatingPointType, let N : int>(CoopVec<T, N> x, CoopVec<T, N> y) +{ + __target_switch + { + case spirv: + return spirv_asm + { + result:$$CoopVec<T, N> = OpExtInst glsl450 FMax $x $y; + }; + case hlsl: + CoopVec<T, N> ret = x; + ret.__mutMax(y); + return ret; + default: + CoopVec<T, N> ret; + for(int i = 0; i < N; ++i) + ret[i] = max(x[i], y[i]); + return ret; + } +} + +[ForceInline] +[require(cooperative_vector)] +CoopVec<T, N> clamp<T : __BuiltinFloatingPointType, let N : int>(CoopVec<T, N> x, CoopVec<T, N> minVal, CoopVec<T, N> maxVal) +{ + __target_switch + { + case spirv: + return spirv_asm + { + result:$$CoopVec<T, N> = OpExtInst glsl450 FClamp $x $minVal $maxVal; + }; + case hlsl: + CoopVec<T, N> ret = x; + ret.__mutClamp(minVal, maxVal); + return ret; + default: + CoopVec<T, N> ret; + for(int i = 0; i < N; ++i) + ret[i] = clamp(x[i], minVal[i], maxVal[i]); + return ret; + } +} + +[ForceInline] +[require(cooperative_vector)] +CoopVec<T, N> min<T : __BuiltinIntegerType, let N : int>(CoopVec<T, N> x, CoopVec<T, N> y) +{ + __target_switch + { + case spirv: + if (__isSignedInt<T>()) + { + return spirv_asm + { + 
result:$$CoopVec<T, N> = OpExtInst glsl450 SMin $x $y + }; + } + else + { + return spirv_asm + { + result:$$CoopVec<T, N> = OpExtInst glsl450 UMin $x $y + }; + } + case hlsl: + CoopVec<T, N> ret = x; + ret.__mutMin(y); + return ret; + default: + CoopVec<T, N> ret; + for(int i = 0; i < N; ++i) + ret[i] = min(x[i], y[i]); + + return ret; + } +} + +// [ForceInline] +[require(cooperative_vector)] +CoopVec<T, N> max<T : __BuiltinIntegerType, let N : int>(CoopVec<T, N> x, CoopVec<T, N> y) +{ + __target_switch + { + case spirv: + if (__isSignedInt<T>()) + { + return spirv_asm + { + result:$$CoopVec<T, N> = OpExtInst glsl450 SMax $x $y + }; + } + else + { + return spirv_asm + { + result:$$CoopVec<T, N> = OpExtInst glsl450 UMax $x $y + }; + } + case hlsl: + CoopVec<T, N> ret = x; + ret.__mutMax(y); + return ret; + default: + CoopVec<T, N> ret; + for(int i = 0; i < N; ++i) + ret[i] = max(x[i], y[i]); + return ret; + } +} + +// [ForceInline] +[require(cooperative_vector)] +CoopVec<T, N> clamp<T : __BuiltinIntegerType, let N : int>(CoopVec<T, N> x, CoopVec<T, N> minVal, CoopVec<T, N> maxVal) +{ + __target_switch + { + case spirv: + if (__isSignedInt<T>()) + { + return spirv_asm + { + result:$$CoopVec<T, N> = OpExtInst glsl450 SClamp $x $minVal $maxVal + }; + } + else + { + return spirv_asm + { + result:$$CoopVec<T, N> = OpExtInst glsl450 UClamp $x $minVal $maxVal + }; + } + case hlsl: + CoopVec<T, N> ret = x; + ret.__mutClamp(minVal, maxVal); + return ret; + default: + CoopVec<T, N> ret; + for(int i = 0; i < N; ++i) + ret[i] = clamp(x[i], minVal[i], maxVal[i]); + return ret; + } +} + +// [ForceInline] +[require(cooperative_vector)] +CoopVec<T, N> step<T : __BuiltinFloatingPointType, let N : int>(CoopVec<T, N> edge, CoopVec<T, N> x) +{ + __target_switch + { + case spirv: + return spirv_asm + { + result:$$CoopVec<T, N> = OpExtInst glsl450 Step $edge $x; + }; + default: + CoopVec<T, N> ret; + for(int i = 0; i < N; ++i) + ret[i] = step(edge[i], x[i]); + return ret; + } +} + +// 
[ForceInline] +[require(cooperative_vector)] +CoopVec<T, N> exp<T : __BuiltinFloatingPointType, let N : int>(CoopVec<T, N> x) +{ + __target_switch + { + case spirv: + return spirv_asm + { + result:$$CoopVec<T, N> = OpExtInst glsl450 Exp $x; + }; + default: + CoopVec<T, N> ret; + for(int i = 0; i < N; ++i) + ret[i] = exp(x[i]); + return ret; + } +} + +// [ForceInline] +[require(cooperative_vector)] +CoopVec<T, N> log<T : __BuiltinFloatingPointType, let N : int>(CoopVec<T, N> x) +{ + __target_switch + { + case spirv: + return spirv_asm + { + result:$$CoopVec<T, N> = OpExtInst glsl450 Log $x; + }; + default: + CoopVec<T, N> ret; + for(int i = 0; i < N; ++i) + ret[i] = log(x[i]); + return ret; + } +} + +// [ForceInline] +[require(cooperative_vector)] +CoopVec<T, N> tanh<T : __BuiltinFloatingPointType, let N : int>(CoopVec<T, N> x) +{ + __target_switch + { + case spirv: + return spirv_asm + { + result:$$CoopVec<T, N> = OpExtInst glsl450 Tanh $x; + }; + default: + CoopVec<T, N> ret; + for(int i = 0; i < N; ++i) + ret[i] = tanh(x[i]); + return ret; + } +} + +// TODO: Why does this fail when inlined on HLSL, +// We generate some really weird code... +// [ForceInline] +[require(cooperative_vector)] +CoopVec<T, N> atan<T : __BuiltinFloatingPointType, let N : int>(CoopVec<T, N> yOverX) +{ + __target_switch + { + case spirv: + return spirv_asm + { + result:$$CoopVec<T, N> = OpExtInst glsl450 Atan $yOverX; + }; + default: + CoopVec<T, N> ret; + for(int i = 0; i < N; ++i) + ret[i] = atan(yOverX[i]); + return ret; + } +} + +// [ForceInline] +[require(cooperative_vector)] +CoopVec<T, N> fma<T : __BuiltinFloatingPointType, let N : int>(CoopVec<T, N> a, CoopVec<T, N> b, CoopVec<T, N> c) +{ + // TODO: Investigate, why does this fail if it's not inlined + // replacing fma with mad below also fixes things... 
+ // dxc generated substantially different code + __target_switch + { + case spirv: + return spirv_asm + { + result:$$CoopVec<T, N> = OpExtInst glsl450 Fma $a $b $c; + }; + default: + CoopVec<T, N> ret; + for(int i = 0; i < N; ++i) + ret[i] = mad(a[i], b[i], c[i]); + return ret; + } +} + +// Buffers from which values of arbitrary type can be loaded from byte offsets +interface IPhysicalBuffer +{ + [__unsafeForceInlineEarly] + T LoadByteOffset<T>(int offset); + + [__unsafeForceInlineEarly] + Ptr<uint32_t[]> GetBufferPointer(); +} + +// Buffers to which values of arbitrary type can be stored at byte offsets +interface IRWPhysicalBuffer : IPhysicalBuffer +{ + [__unsafeForceInlineEarly] + void StoreByteOffset<T>(int offset, T element); +} + +extension ByteAddressBuffer : IPhysicalBuffer +{ + [__unsafeForceInlineEarly] + Ptr<uint32_t[]> GetBufferPointer() + { + return __getStructuredBufferPtr(__getEquivalentStructuredBuffer<uint32_t>(this)); + } + + [__unsafeForceInlineEarly] + T LoadByteOffset<T>(int offset) + { + return this.Load<T>(offset); + } +} + +extension RWByteAddressBuffer : IPhysicalBuffer +{ + [__unsafeForceInlineEarly] + Ptr<uint32_t[]> GetBufferPointer() + { + return __getStructuredBufferPtr(__getEquivalentStructuredBuffer<uint32_t>(this)); + } + + [__unsafeForceInlineEarly] + T LoadByteOffset<T>(int offset) + { + return this.Load<T>(offset); + } +} + +extension RWByteAddressBuffer : IRWPhysicalBuffer +{ + [__unsafeForceInlineEarly] + void StoreByteOffset<T>(int offset, T element) + { + return this.Store<T>(offset, element); + } +} + + +// +// Convenience loading functions for cooperative vectors which infer the +// element type for structured buffers and groupshared arrays (and ByteAddressBuffers for consistency +// + +[ForceInline] +[require(cooperative_vector)] +CoopVec<T, N> coopVecLoad<let N : int, T : __BuiltinArithmeticType>(ByteAddressBuffer buffer, int32_t byteOffset16ByteAligned = 0) +{ + return CoopVec<T, N>.load(buffer, 
byteOffset16ByteAligned); +} + +[ForceInline] +[require(cooperative_vector)] +CoopVec<T, N> coopVecLoad<let N : int, T : __BuiltinArithmeticType>(RWByteAddressBuffer buffer, int32_t byteOffset16ByteAligned = 0) +{ + return CoopVec<T, N>.load(buffer, byteOffset16ByteAligned); +} + +[ForceInline] +[require(cooperative_vector)] +CoopVec<T, N> coopVecLoad<let N : int, T : __BuiltinArithmeticType>(StructuredBuffer<T> buffer, int32_t byteOffset16ByteAligned = 0) +{ + return CoopVec<T, N>.load(buffer, byteOffset16ByteAligned); +} + +[ForceInline] +[require(cooperative_vector)] +CoopVec<T, N> coopVecLoad<let N : int, T : __BuiltinArithmeticType>(RWStructuredBuffer<T> buffer, int32_t byteOffset16ByteAligned = 0) +{ + return CoopVec<T, N>.load(buffer, byteOffset16ByteAligned); +} + +// Groupshared +[ForceInline] +[require(cooperative_vector)] +CoopVec<T, N> coopVecLoadGroupshared<let N : int, T : __BuiltinArithmeticType, let M : int>(__constref groupshared const T[M] data, int32_t byteOffset16ByteAligned = 0) +{ + return CoopVec<T, N>.load(data, byteOffset16ByteAligned); +} + +// +// Coop Vector matrix multiplication +// + +enum CoopVecMatrixLayout +{ + RowMajor, + ColumnMajor, + InferencingOptimal, + TrainingOptimal +}; + +enum CoopVecComponentType +{ + FloatE4M3, + FloatE5M2, + Float16, + Float32, + Float64, + SignedInt8, + SignedInt16, + SignedInt32, + SignedInt64, + SignedInt8Packed, + UnsignedInt8, + UnsignedInt16, + UnsignedInt32, + UnsignedInt64, + UnsignedInt8Packed +}; + +[ForceInline] +int __inputInterpretationPackingFactor(CoopVecComponentType componentType) +{ + switch (componentType) + { + case CoopVecComponentType::SignedInt8Packed: + case CoopVecComponentType::UnsignedInt8Packed: + return 4; + } + return 1; +} + +[ForceInline] +bool __isPackedInputInterpretation(CoopVecComponentType componentType) +{ + return __inputInterpretationPackingFactor(componentType) != 1; +} + +// TODO: We might consider some way of specifying these from our lookup tables 
+/* Maps a CoopVecMatrixLayout value to the integer layout operand used on the SPIR-V path. Values without a case hit the static_assert; 0xffffffff is the fall-through sentinel. */ +[ForceInline] +uint32_t __getSpvCoopVecMatrixLayout(CoopVecMatrixLayout layout) +{ + switch (layout) + { + case CoopVecMatrixLayout::RowMajor: + return 0; + case CoopVecMatrixLayout::ColumnMajor: + return 1; + case CoopVecMatrixLayout::InferencingOptimal: + return 2; + case CoopVecMatrixLayout::TrainingOptimal: + return 3; + default: + static_assert(false, "unsupported layout value"); + } + return 0xffffffff; +} + +/* HLSL-path counterpart of __getSpvCoopVecMatrixLayout. The numbering currently mirrors the SPIR-V table case for case. */ +[ForceInline] +uint32_t __getHLSLCoopVecMatrixLayout(CoopVecMatrixLayout layout) +{ + switch (layout) + { + // TODO: Check these are the same + case CoopVecMatrixLayout::RowMajor: + return 0; + case CoopVecMatrixLayout::ColumnMajor: + return 1; + case CoopVecMatrixLayout::InferencingOptimal: + return 2; + case CoopVecMatrixLayout::TrainingOptimal: + return 3; + default: + static_assert(false, "unsupported layout value"); + } + return 0xffffffff; +} + +/* Maps a CoopVecComponentType to the integer component-type operand used on the SPIR-V path. NOTE(review): the packed int8 and FP8 entries use 1000491000..1000491003, presumably extension-reserved enumerants; confirm against the SPV_NV_cooperative_vector specification. */ +[ForceInline] +uint32_t __getSpvCoopVecComponentType(CoopVecComponentType componentType) +{ + switch (componentType) + { + case CoopVecComponentType::Float16: + return 0; + case CoopVecComponentType::Float32: + return 1; + case CoopVecComponentType::Float64: + return 2; + case CoopVecComponentType::SignedInt8: + return 3; + case CoopVecComponentType::SignedInt16: + return 4; + case CoopVecComponentType::SignedInt32: + return 5; + case CoopVecComponentType::SignedInt8Packed: + return 1000491000; + case CoopVecComponentType::SignedInt64: + return 6; + case CoopVecComponentType::UnsignedInt8: + return 7; + case CoopVecComponentType::UnsignedInt16: + return 8; + case CoopVecComponentType::UnsignedInt32: + return 9; + case CoopVecComponentType::UnsignedInt8Packed: + return 1000491001; + case CoopVecComponentType::UnsignedInt64: + return 10; + case CoopVecComponentType::FloatE4M3: + return 1000491002; + case CoopVecComponentType::FloatE5M2: + return 1000491003; + default: + static_assert(false, "unsupported componentType value"); + } + return 0xffffffff; +} + +/* Maps a CoopVecComponentType to the integer component-type operand used on the HLSL path. There are no cases for Float64 or the 64-bit integer types, so those reach the default static_assert. NOTE(review): the fall-through return is 32 here, whereas the sibling helpers return 0xffffffff; confirm that the difference is intentional. */ +[ForceInline] +uint32_t 
__getHLSLCoopVecComponentType(CoopVecComponentType componentType) +{ + switch (componentType) + { + case CoopVecComponentType::Float16: + return 0; + case CoopVecComponentType::Float32: + return 1; + case CoopVecComponentType::UnsignedInt8: + return 2; + case CoopVecComponentType::UnsignedInt16: + return 3; + case CoopVecComponentType::UnsignedInt32: + return 4; + case CoopVecComponentType::SignedInt8: + return 5; + case CoopVecComponentType::SignedInt16: + return 6; + case CoopVecComponentType::SignedInt32: + return 7; + case CoopVecComponentType::SignedInt8Packed: + return 8; + case CoopVecComponentType::UnsignedInt8Packed: + return 9; + case CoopVecComponentType::FloatE4M3: + return 10; + case CoopVecComponentType::FloatE5M2: + return 11; + default: + static_assert(false, "unsupported componentType value"); + } + return 32; +} + +/* Byte stride of one element of the given component type; the fallback loops multiply it by an element index to form byte offsets. The packed int8 types report 4 (NOTE(review): presumably the size of the whole packed word; confirm). FloatE4M3 and FloatE5M2 have no case and reach the default static_assert. */ +[ForceInline] +uint32_t __coopVecComponentTypeStride(CoopVecComponentType componentType) +{ + switch (componentType) + { + case CoopVecComponentType::Float16: + return 2; + case CoopVecComponentType::Float32: + return 4; + case CoopVecComponentType::Float64: + return 8; + case CoopVecComponentType::SignedInt8: + return 1; + case CoopVecComponentType::SignedInt16: + return 2; + case CoopVecComponentType::SignedInt32: + return 4; + case CoopVecComponentType::SignedInt8Packed: + return 4; + case CoopVecComponentType::SignedInt64: + return 8; + case CoopVecComponentType::UnsignedInt8: + return 1; + case CoopVecComponentType::UnsignedInt16: + return 2; + case CoopVecComponentType::UnsignedInt32: + return 4; + case CoopVecComponentType::UnsignedInt8Packed: + return 4; + case CoopVecComponentType::UnsignedInt64: + return 8; + default: + static_assert(false, "unsupported componentType value"); + } + return 0xffffffff; +} + +/* Meta-code: opens a loop over kByteAddressBufferCases_; the declarations that follow use $(buffer.type) and buffer.isRW from the current entry. */ +${{{{ +static const struct { + bool isRW; + char const* type; +} kByteAddressBufferCases_[] = +{ + {true, "RWByteAddressBuffer"}, + {false, "ByteAddressBuffer"}, +}; +for(auto buffer : kByteAddressBufferCases_) { +}}}} + +// 
TODO: Can we ForceInline for just hlsl? the other platforms don't really +// need it +// +// Multiplies a matrix read from `matrix` by `input` and returns the M results. +// `k` is the logical input length: it must equal PackedK for ordinary +// interpretations, and may be up to 4 * PackedK when inputInterpretation is +// one of the packed int8 types (four 8-bit values per input element). +// matrixOffset and matrixStride are byte quantities. +// spirv and hlsl lower to the native operations; every other target runs the +// scalar loop in `default`, which requires k == 4 * PackedK for packed input. +[ForceInline] +[require(cooperative_vector)] +__generic<T : __BuiltinArithmeticType, let M : int, let PackedK : int, U : __BuiltinArithmeticType> +CoopVec<T, M> coopVecMatMulPacked( + CoopVec<U, PackedK> input, + constexpr CoopVecComponentType inputInterpretation, + constexpr int k, + $(buffer.type) matrix, + int32_t matrixOffset, + constexpr CoopVecComponentType matrixInterpretation, + constexpr CoopVecMatrixLayout memoryLayout, + constexpr bool transpose, + constexpr uint matrixStride +) +{ + static_assert(__isPackedInputInterpretation(inputInterpretation) || k == PackedK + , "for non-packed inputInterpretation values k must be equal to the input vector length"); + static_assert(!__isPackedInputInterpretation(inputInterpretation) + || k <= __inputInterpretationPackingFactor(inputInterpretation)*PackedK + , "for packed inputInterpretation values k must be less than or equal to the input vector length times the packing factor"); + + __target_switch + { + case spirv: + let m : int32_t = M; + let matrixInterpretationSpirv : int32_t = __getSpvCoopVecComponentType(matrixInterpretation); + let inputInterpretationSpirv : int32_t = __getSpvCoopVecComponentType(inputInterpretation); + let memoryLayoutSpirv : int32_t = __getSpvCoopVecMatrixLayout(memoryLayout); + let matrixPtr = matrix.GetBufferPointer(); + return spirv_asm + { + result:$$CoopVec<T, M> = OpCooperativeVectorMatrixMulNV $input $inputInterpretationSpirv $matrixPtr $matrixOffset $matrixInterpretationSpirv $m $k $memoryLayoutSpirv $transpose $matrixStride; + }; + + case hlsl: + var ret = CoopVec<T, M>(0); + let inputInterpretationHLSL = __getHLSLCoopVecComponentType(inputInterpretation); + let matrixInterpretationHLSL = __getHLSLCoopVecComponentType(matrixInterpretation); + let memoryLayoutHLSL = __getHLSLCoopVecMatrixLayout(memoryLayout); + ret.__mutMatMul( + input, + inputInterpretationHLSL, + matrix, + matrixOffset, + 
matrixInterpretationHLSL, + M, + k, + memoryLayoutHLSL, + transpose, + matrixStride + ); + return ret; + + default: + var result = CoopVec<T, M>(0); + var v = CoopVec<T, PackedK*4>(); + // TODO: Insert language from the spec to describe this madness + if(k == PackedK) + { + for(int i = 0; i < k; ++i) + v[i] = __arithmetic_cast<T>(input[i]); + } + else + { + static_assert(k == PackedK*4, "K must be 4 * PackedK for the non-spirv coopVecMatMulPacked backend"); + static_assert(inputInterpretation == CoopVecComponentType::SignedInt8Packed || + inputInterpretation == CoopVecComponentType::UnsignedInt8Packed, + "Packing is only supported for 4*int8 or 4*uint8 vectors"); + for(int i = 0; i < k; ++i) + { + let n = __arithmetic_cast<int32_t>(input[i/4]); + // Pull out byte (i % 4) of the packed word. Sign-extend it only for the + // signed interpretation; unsigned bytes must stay in the range 0..255. + let packedByte = (n >> ((i % 4) * 8)) & 0xff; + if (inputInterpretation == CoopVecComponentType::SignedInt8Packed) + v[i] = T(int8_t(packedByte)); + else + v[i] = T(uint8_t(packedByte)); + } + } + + for (int i = 0; i < M; ++i) + { + for (int j = 0; j < k; ++j) + { + int row = (transpose ^ memoryLayout == CoopVecMatrixLayout::ColumnMajor) ? j : i; + int col = (transpose ^ memoryLayout == CoopVecMatrixLayout::ColumnMajor) ? 
i : j; + int offset = matrixOffset + (row * matrixStride + col * __coopVecComponentTypeStride(matrixInterpretation)); + + switch (matrixInterpretation) + { + case CoopVecComponentType::Float16: + result[i] += __arithmetic_cast<T>(matrix.LoadByteOffset<half>(offset)) * v[j]; + break; + case CoopVecComponentType::Float32: + result[i] += __arithmetic_cast<T>(matrix.LoadByteOffset<float>(offset)) * v[j]; + break; + case CoopVecComponentType::Float64: + result[i] += __arithmetic_cast<T>(matrix.LoadByteOffset<double>(offset)) * v[j]; + break; + case CoopVecComponentType::SignedInt8: + result[i] += __arithmetic_cast<T>(matrix.LoadByteOffset<int8_t>(offset)) * v[j]; + break; + case CoopVecComponentType::SignedInt16: + result[i] += __arithmetic_cast<T>(matrix.LoadByteOffset<int16_t>(offset)) * v[j]; + break; + case CoopVecComponentType::SignedInt32: + result[i] += __arithmetic_cast<T>(matrix.LoadByteOffset<int32_t>(offset)) * v[j]; + break; + case CoopVecComponentType::SignedInt64: + result[i] += __arithmetic_cast<T>(matrix.LoadByteOffset<int64_t>(offset)) * v[j]; + break; + case CoopVecComponentType::SignedInt8Packed: + result[i] += __arithmetic_cast<T>(matrix.LoadByteOffset<int8_t>(offset)) * v[j]; + break; + case CoopVecComponentType::UnsignedInt8: + result[i] += __arithmetic_cast<T>(matrix.LoadByteOffset<uint8_t>(offset)) * v[j]; + break; + case CoopVecComponentType::UnsignedInt16: + result[i] += __arithmetic_cast<T>(matrix.LoadByteOffset<uint16_t>(offset)) * v[j]; + break; + case CoopVecComponentType::UnsignedInt32: + result[i] += __arithmetic_cast<T>(matrix.LoadByteOffset<uint32_t>(offset)) * v[j]; + break; + case CoopVecComponentType::UnsignedInt64: + result[i] += __arithmetic_cast<T>(matrix.LoadByteOffset<uint64_t>(offset)) * v[j]; + break; + case CoopVecComponentType::UnsignedInt8Packed: + result[i] += __arithmetic_cast<T>(matrix.LoadByteOffset<uint8_t>(offset)) * v[j]; + break; + } + } + } + + return result; + } +} + +[ForceInline] +[require(cooperative_vector)] 
+__generic<T : __BuiltinArithmeticType, let M : int, let K : int, U : __BuiltinArithmeticType> +CoopVec<T, M> coopVecMatMul( + CoopVec<U, K> input, + constexpr CoopVecComponentType inputInterpretation, + $(buffer.type) matrix, + int32_t matrixOffset, + constexpr CoopVecComponentType matrixInterpretation, + constexpr CoopVecMatrixLayout memoryLayout, + constexpr bool transpose, + constexpr uint matrixStride +) +{ + static_assert(!__isPackedInputInterpretation(inputInterpretation) + , "for packed inputInterpretation values please use coopVecMatMulPacked and specify k manually"); + return coopVecMatMulPacked< + T, M, K, U>( + input, + inputInterpretation, + K, + matrix, + matrixOffset, + matrixInterpretation, + memoryLayout, + transpose, + matrixStride); +} + +[ForceInline] +[require(cooperative_vector)] +CoopVec<T, M> coopVecMatMulAddPacked<T : __BuiltinArithmeticType, let M : int, let PackedK : int, U : __BuiltinArithmeticType>( + CoopVec<U, PackedK> input, + constexpr CoopVecComponentType inputInterpretation, + constexpr int k, + $(buffer.type) matrix, + int32_t matrixOffset, + constexpr CoopVecComponentType matrixInterpretation, + $(buffer.type) bias, + int32_t biasOffset, + constexpr CoopVecComponentType biasInterpretation, + constexpr CoopVecMatrixLayout memoryLayout, + constexpr bool transpose, + constexpr uint matrixStride +) +{ + static_assert(__isPackedInputInterpretation(inputInterpretation) || k == PackedK + , "for non-packed inputInterpretation values k must be equal to the input vector length"); + static_assert(!__isPackedInputInterpretation(inputInterpretation) + || k <= __inputInterpretationPackingFactor(inputInterpretation)*PackedK + , "for packed inputInterpretation values k must be less than or equal to the input vector length times the packing factor"); + + __target_switch + { + case spirv: + let m : int32_t = M; + let matrixInterpretationSpirv : int32_t = __getSpvCoopVecComponentType(matrixInterpretation); + let biasInterpretationSpirv : int32_t = 
__getSpvCoopVecComponentType(biasInterpretation); + let inputInterpretationSpirv : int32_t = __getSpvCoopVecComponentType(inputInterpretation); + let memoryLayoutSpirv : int32_t = __getSpvCoopVecMatrixLayout(memoryLayout); + let matrixPtr = matrix.GetBufferPointer(); + let biasPtr = bias.GetBufferPointer(); + return spirv_asm + { + result:$$CoopVec<T, M> = OpCooperativeVectorMatrixMulAddNV $input $inputInterpretationSpirv $matrixPtr $matrixOffset $matrixInterpretationSpirv $biasPtr $biasOffset $biasInterpretationSpirv $m $k $memoryLayoutSpirv $transpose $matrixStride; + }; + + case hlsl: + var ret = CoopVec<T, M>(0); + let inputInterpretationHLSL = __getHLSLCoopVecComponentType(inputInterpretation); + let matrixInterpretationHLSL = __getHLSLCoopVecComponentType(matrixInterpretation); + let biasInterpretationHLSL = __getHLSLCoopVecComponentType(biasInterpretation); + let memoryLayoutHLSL = __getHLSLCoopVecMatrixLayout(memoryLayout); + ret.__mutMatMulAdd( + input, + inputInterpretationHLSL, + matrix, + matrixOffset, + matrixInterpretationHLSL, + bias, + biasOffset, + biasInterpretationHLSL, + M, + k, + memoryLayoutHLSL, + transpose, + matrixStride + ); + return ret; + + default: + var result = coopVecMatMulPacked<T, M, PackedK, U>( + input, + inputInterpretation, + k, + matrix, + matrixOffset, + matrixInterpretation, + memoryLayout, + transpose, + matrixStride); + + for (int i = 0; i < M; ++i) + { + int b = biasOffset + i * __coopVecComponentTypeStride(biasInterpretation); + switch (biasInterpretation) + { + case CoopVecComponentType::Float16: + result[i] += __arithmetic_cast<T>(bias.LoadByteOffset<half>(b)); + break; + case CoopVecComponentType::Float32: + result[i] += __arithmetic_cast<T>(bias.LoadByteOffset<float>(b)); + break; + case CoopVecComponentType::Float64: + result[i] += __arithmetic_cast<T>(bias.LoadByteOffset<double>(b)); + break; + case CoopVecComponentType::SignedInt8: + result[i] += __arithmetic_cast<T>(bias.LoadByteOffset<int8_t>(b)); + break; + 
case CoopVecComponentType::SignedInt16: + result[i] += __arithmetic_cast<T>(bias.LoadByteOffset<int16_t>(b)); + break; + case CoopVecComponentType::SignedInt32: + result[i] += __arithmetic_cast<T>(bias.LoadByteOffset<int32_t>(b)); + break; + case CoopVecComponentType::SignedInt64: + result[i] += __arithmetic_cast<T>(bias.LoadByteOffset<int64_t>(b)); + break; + case CoopVecComponentType::SignedInt8Packed: + result[i] += __arithmetic_cast<T>(bias.LoadByteOffset<int8_t>(b)); + break; + case CoopVecComponentType::UnsignedInt8: + result[i] += __arithmetic_cast<T>(bias.LoadByteOffset<uint8_t>(b)); + break; + case CoopVecComponentType::UnsignedInt16: + result[i] += __arithmetic_cast<T>(bias.LoadByteOffset<uint16_t>(b)); + break; + case CoopVecComponentType::UnsignedInt32: + result[i] += __arithmetic_cast<T>(bias.LoadByteOffset<uint32_t>(b)); + break; + case CoopVecComponentType::UnsignedInt64: + result[i] += __arithmetic_cast<T>(bias.LoadByteOffset<uint64_t>(b)); + break; + case CoopVecComponentType::UnsignedInt8Packed: + result[i] += __arithmetic_cast<T>(bias.LoadByteOffset<uint8_t>(b)); + break; + } + } + + return result; + } +} + +[ForceInline] +[require(cooperative_vector)] +__generic<T : __BuiltinArithmeticType, let M : int, let K : int, U : __BuiltinArithmeticType> +CoopVec<T, M> coopVecMatMulAdd( + CoopVec<U, K> input, + constexpr CoopVecComponentType inputInterpretation, + $(buffer.type) matrix, + int32_t matrixOffset, + constexpr CoopVecComponentType matrixInterpretation, + $(buffer.type) bias, + int32_t biasOffset, + constexpr CoopVecComponentType biasInterpretation, + constexpr CoopVecMatrixLayout memoryLayout, + constexpr bool transpose, + constexpr uint matrixStride +) +{ + static_assert(!__isPackedInputInterpretation(inputInterpretation) + , "for packed inputInterpretation values please use coopVecMatMulPacked and specify k manually"); + return coopVecMatMulAddPacked< + T, M, K, U>( + input, + inputInterpretation, + K, + matrix, + matrixOffset, + 
matrixInterpretation, + bias, + biasOffset, + biasInterpretation, + memoryLayout, + transpose, + matrixStride); +} + +// +// Coop Vector accumulation +// + +${{{{ +if(buffer.isRW) +{ +}}}} +[require(cooperative_vector)] +void coopVecOuterProductAccumulate<T : __BuiltinArithmeticType, let M : int, let N : int>( + CoopVec<T, M> a, + CoopVec<T, N> b, + $(buffer.type) matrix, + int32_t matrixOffset, + constexpr uint matrixStride, + constexpr CoopVecMatrixLayout memoryLayout, + constexpr CoopVecComponentType matrixInterpretation, +) +{ + __target_switch + { + case spirv: + let matrixInterpretationSpirv : int = __getSpvCoopVecComponentType(matrixInterpretation); + let memoryLayoutSpirv : int = __getSpvCoopVecMatrixLayout(memoryLayout); + let matrixPtr = matrix.GetBufferPointer(); + spirv_asm + { + OpCapability CooperativeVectorTrainingNV; + OpCooperativeVectorOuterProductAccumulateNV $matrixPtr $matrixOffset $a $b $memoryLayoutSpirv $matrixInterpretationSpirv $matrixStride; + }; + default: + for (int i = 0; i < M; ++i) + { + for (int j = 0; j < N; ++j) + { + T product = a[i] * b[j]; + int row = (memoryLayout == CoopVecMatrixLayout::ColumnMajor) ? j : i; + int col = (memoryLayout == CoopVecMatrixLayout::ColumnMajor) ? 
i : j; + int offset = matrixOffset + (row * matrixStride + col * __coopVecComponentTypeStride(matrixInterpretation)); + + switch (matrixInterpretation) + { + case CoopVecComponentType::Float16: + matrix.StoreByteOffset<half>(offset, matrix.LoadByteOffset<half>(offset) + __arithmetic_cast<half>(product)); + break; + case CoopVecComponentType::Float32: + matrix.StoreByteOffset<float>(offset, matrix.LoadByteOffset<float>(offset) + __arithmetic_cast<float>(product)); + break; + case CoopVecComponentType::Float64: + matrix.StoreByteOffset<double>(offset, matrix.LoadByteOffset<double>(offset) + __arithmetic_cast<double>(product)); + break; + case CoopVecComponentType::SignedInt8: + matrix.StoreByteOffset<int8_t>(offset, matrix.LoadByteOffset<int8_t>(offset) + __arithmetic_cast<int8_t>(product)); + break; + case CoopVecComponentType::SignedInt16: + matrix.StoreByteOffset<int16_t>(offset, matrix.LoadByteOffset<int16_t>(offset) + __arithmetic_cast<int16_t>(product)); + break; + case CoopVecComponentType::SignedInt32: + matrix.StoreByteOffset<int32_t>(offset, matrix.LoadByteOffset<int32_t>(offset) + __arithmetic_cast<int32_t>(product)); + break; + case CoopVecComponentType::SignedInt64: + matrix.StoreByteOffset<int64_t>(offset, matrix.LoadByteOffset<int64_t>(offset) + __arithmetic_cast<int64_t>(product)); + break; + case CoopVecComponentType::SignedInt8Packed: + matrix.StoreByteOffset<int8_t>(offset, matrix.LoadByteOffset<int8_t>(offset) + __arithmetic_cast<int8_t>(product)); + break; + case CoopVecComponentType::UnsignedInt8: + matrix.StoreByteOffset<uint8_t>(offset, matrix.LoadByteOffset<uint8_t>(offset) + __arithmetic_cast<uint8_t>(product)); + break; + case CoopVecComponentType::UnsignedInt16: + matrix.StoreByteOffset<uint16_t>(offset, matrix.LoadByteOffset<uint16_t>(offset) + __arithmetic_cast<uint16_t>(product)); + break; + case CoopVecComponentType::UnsignedInt32: + matrix.StoreByteOffset<uint32_t>(offset, matrix.LoadByteOffset<uint32_t>(offset) + 
__arithmetic_cast<uint32_t>(product)); + break; + case CoopVecComponentType::UnsignedInt64: + matrix.StoreByteOffset<uint64_t>(offset, matrix.LoadByteOffset<uint64_t>(offset) + __arithmetic_cast<uint64_t>(product)); + break; + case CoopVecComponentType::UnsignedInt8Packed: + matrix.StoreByteOffset<uint8_t>(offset, matrix.LoadByteOffset<uint8_t>(offset) + __arithmetic_cast<uint8_t>(product)); + break; + } + } + } + } +} + +[require(cooperative_vector)] +void coopVecReduceSumAccumulate<T : __BuiltinArithmeticType, let N : int>( + CoopVec<T, N> v, + $(buffer.type) buffer, + int32_t offset +) +{ + __target_switch + { + case spirv: + let bufferPtr = buffer.GetBufferPointer(); + spirv_asm + { + OpCapability CooperativeVectorTrainingNV; + OpCooperativeVectorReduceSumAccumulateNV $bufferPtr $offset $v; + }; + default: + for (int i = 0; i < N; ++i) + { + int byteOffset = offset + i * __naturalStrideOf<T>(); + T currentValue = buffer.LoadByteOffset<T>(byteOffset); + T newValue = currentValue + __arithmetic_cast<T>(v[i]); + buffer.StoreByteOffset(byteOffset, newValue); + } + } +} + +${{{{ +} // if rw +} // buffer type loop +}}}} + +${{{{ +static const struct { + bool isRW; + char const* type; +} kStructuredBufferCases_[] = +{ + {true, "RWStructuredBuffer<IgnoredBufferElementType>"}, + {false, "StructuredBuffer<IgnoredBufferElementType>"}, +}; +for(auto buffer : kStructuredBufferCases_) { +}}}} + +[require(spirv, cooperative_vector)] +__generic<T : __BuiltinArithmeticType, let M : int, let PackedK : int, U : __BuiltinArithmeticType,IgnoredBufferElementType> +CoopVec<T, M> coopVecMatMulPacked( + CoopVec<U, PackedK> input, + constexpr CoopVecComponentType inputInterpretation, + constexpr int k, + $(buffer.type) matrix, + int32_t matrixOffset, + constexpr CoopVecComponentType matrixInterpretation, + constexpr CoopVecMatrixLayout memoryLayout, + constexpr bool transpose, + constexpr uint matrixStride +) +{ + static_assert(__isPackedInputInterpretation(inputInterpretation) || k == 
PackedK + , "for non-packed inputInterpretation values k must be equal to the input vector length"); + static_assert(!__isPackedInputInterpretation(inputInterpretation) + || k <= __inputInterpretationPackingFactor(inputInterpretation)*PackedK + , "for packed inputInterpretation values k must be less than or equal to the input vector length times the packing factor"); + __target_switch + { + case spirv: + let m : int32_t = M; + let matrixInterpretationSpirv : int32_t = __getSpvCoopVecComponentType(matrixInterpretation); + let inputInterpretationSpirv : int32_t = __getSpvCoopVecComponentType(inputInterpretation); + let memoryLayoutSpirv : int32_t = __getSpvCoopVecMatrixLayout(memoryLayout); + let matrixPtr = __getStructuredBufferPtr(matrix); + return spirv_asm + { + result:$$CoopVec<T, M> = OpCooperativeVectorMatrixMulNV $input $inputInterpretationSpirv $matrixPtr $matrixOffset $matrixInterpretationSpirv $m $k $memoryLayoutSpirv $transpose $matrixStride; + }; + } +} + +// specialized coopVecMatMul for non-packed inputs +[require(spirv, cooperative_vector)] +__generic<T : __BuiltinArithmeticType, let M : int, let K : int, U : __BuiltinArithmeticType,IgnoredBufferElementType> +CoopVec<T, M> coopVecMatMul( + CoopVec<U, K> input, + constexpr CoopVecComponentType inputInterpretation, + $(buffer.type) matrix, + int32_t matrixOffset, + constexpr CoopVecComponentType matrixInterpretation, + constexpr CoopVecMatrixLayout memoryLayout, + constexpr bool transpose, + constexpr uint matrixStride +) +{ + static_assert(!__isPackedInputInterpretation(inputInterpretation) + , "for packed inputInterpretation values please use coopVecMatMulPacked and specify k manually"); + return coopVecMatMulPacked< + T, M, K, U, IgnoredBufferElementType>( + input, + inputInterpretation, + K, + matrix, + matrixOffset, + matrixInterpretation, + memoryLayout, + transpose, + matrixStride); +} + +[require(spirv, cooperative_vector)] +CoopVec<T, M> coopVecMatMulAddPacked<T : __BuiltinArithmeticType, let M 
: int, let PackedK : int, U : __BuiltinArithmeticType, IgnoredBufferElementType>( + CoopVec<U, PackedK> input, + constexpr CoopVecComponentType inputInterpretation, + constexpr int k, + $(buffer.type) matrix, + int32_t matrixOffset, + constexpr CoopVecComponentType matrixInterpretation, + $(buffer.type) bias, + int32_t biasOffset, + constexpr CoopVecComponentType biasInterpretation, + constexpr CoopVecMatrixLayout memoryLayout, + constexpr bool transpose, + constexpr uint matrixStride +) +{ + static_assert(__isPackedInputInterpretation(inputInterpretation) || k == PackedK + , "for non-packed inputInterpretation values k must be equal to the input vector length"); + static_assert(!__isPackedInputInterpretation(inputInterpretation) + || k <= __inputInterpretationPackingFactor(inputInterpretation)*PackedK + , "for packed inputInterpretation values k must be less than or equal to the input vector length times the packing factor"); + + __target_switch + { + case spirv: + let m : int32_t = M; + let matrixInterpretationSpirv : int32_t = __getSpvCoopVecComponentType(matrixInterpretation); + let biasInterpretationSpirv : int32_t = __getSpvCoopVecComponentType(biasInterpretation); + let inputInterpretationSpirv : int32_t = __getSpvCoopVecComponentType(inputInterpretation); + let memoryLayoutSpirv : int32_t = __getSpvCoopVecMatrixLayout(memoryLayout); + let matrixPtr = __getStructuredBufferPtr(matrix); + let biasPtr = __getStructuredBufferPtr(bias); + return spirv_asm + { + result:$$CoopVec<T, M> = OpCooperativeVectorMatrixMulAddNV $input $inputInterpretationSpirv $matrixPtr $matrixOffset $matrixInterpretationSpirv $biasPtr $biasOffset $biasInterpretationSpirv $m $k $memoryLayoutSpirv $transpose $matrixStride; + }; + } +} + +[require(spirv, cooperative_vector)] +CoopVec<T, M> coopVecMatMulAdd<T : __BuiltinArithmeticType, let M : int, let K : int, U : __BuiltinArithmeticType, IgnoredBufferElementType>( + CoopVec<U, K> input, + constexpr CoopVecComponentType 
inputInterpretation, + $(buffer.type) matrix, + int32_t matrixOffset, + constexpr CoopVecComponentType matrixInterpretation, + $(buffer.type) bias, + int32_t biasOffset, + constexpr CoopVecComponentType biasInterpretation, + constexpr CoopVecMatrixLayout memoryLayout, + constexpr bool transpose, + constexpr uint matrixStride +) +{ + static_assert(!__isPackedInputInterpretation(inputInterpretation) + , "for packed inputInterpretation values please use coopVecMatMulPacked and specify k manually"); + return coopVecMatMulAddPacked< + T, M, K, U, IgnoredBufferElementType>( + input, + inputInterpretation, + K, + matrix, + matrixOffset, + matrixInterpretation, + bias, + biasOffset, + biasInterpretation, + memoryLayout, + transpose, + matrixStride); +} + +// +// Coop Vector accumulation +// + +${{{{ +if(buffer.isRW) +{ +}}}} +[require(spirv, cooperative_vector_training)] +void coopVecOuterProductAccumulate<T : __BuiltinArithmeticType, let M : int, let N : int, IgnoredBufferElementType>( + CoopVec<T, M> a, + CoopVec<T, N> b, + $(buffer.type) matrix, + int32_t matrixOffset, + constexpr uint matrixStride, + constexpr CoopVecMatrixLayout memoryLayout, + constexpr CoopVecComponentType matrixInterpretation, +) +{ + __target_switch + { + case spirv: + let matrixInterpretationSpirv : int = __getSpvCoopVecComponentType(matrixInterpretation); + let memoryLayoutSpirv : int = __getSpvCoopVecMatrixLayout(memoryLayout); + let matrixPtr = __getStructuredBufferPtr(matrix); + spirv_asm + { + OpCapability CooperativeVectorTrainingNV; + OpCooperativeVectorOuterProductAccumulateNV $matrixPtr $matrixOffset $a $b $memoryLayoutSpirv $matrixInterpretationSpirv $matrixStride; + }; + } +} + +[require(spirv, cooperative_vector_training)] +void coopVecReduceSumAccumulate<T : __BuiltinArithmeticType, let N : int, IgnoredBufferElementType>( + CoopVec<T, N> v, + $(buffer.type) buffer, + int32_t offset +) +{ + __target_switch + { + case spirv: + let bufferPtr = __getStructuredBufferPtr(buffer); + 
spirv_asm + { + OpCapability CooperativeVectorTrainingNV; + OpCooperativeVectorReduceSumAccumulateNV $bufferPtr $offset $v; + }; + } +} + +${{{{ +} // if rw +} // buffer type loop +}}}} + + +[require(spirv, cooperative_vector_training)] +void coopVecOuterProductAccumulate<T : __BuiltinArithmeticType, let M : int, let N : int, U : __BuiltinArithmeticType, let IgnoredBufferSize : int>( + CoopVec<T, M> a, + CoopVec<T, N> b, + __ref groupshared U[IgnoredBufferSize] matrix, + int32_t matrixOffset, + constexpr uint matrixStride, + constexpr CoopVecMatrixLayout memoryLayout, + constexpr CoopVecComponentType matrixInterpretation, +) +{ + __target_switch + { + case spirv: + let matrixInterpretationSpirv : int = __getSpvCoopVecComponentType(matrixInterpretation); + let memoryLayoutSpirv : int = __getSpvCoopVecMatrixLayout(memoryLayout); + spirv_asm + { + OpCapability CooperativeVectorTrainingNV; + OpCooperativeVectorOuterProductAccumulateNV &matrix $matrixOffset $a $b $memoryLayoutSpirv $matrixInterpretationSpirv $matrixStride; + }; + } +} + +[require(spirv, cooperative_vector_training)] +void coopVecReduceSumAccumulate<T : __BuiltinArithmeticType, let N : int, U, let IgnoredBufferSize : int>( + CoopVec<T, N> v, + __ref groupshared U[IgnoredBufferSize] buffer, + int32_t offset +) +{ + __target_switch + { + case spirv: + spirv_asm + { + OpCapability CooperativeVectorTrainingNV; + OpCooperativeVectorReduceSumAccumulateNV &buffer $offset $v; + }; + } +} + //@public: /// Mark a variable as being workgroup uniform. 
diff --git a/source/slang/slang-ast-support-types.h b/source/slang/slang-ast-support-types.h index 56e07d430..2581630dd 100644 --- a/source/slang/slang-ast-support-types.h +++ b/source/slang/slang-ast-support-types.h @@ -170,6 +170,11 @@ enum : ConversionCost kConversionCost_ScalarIntegerToFloatMatrix = kConversionCost_IntegerToFloatConversion + kConversionCost_ScalarToMatrix, + // Additional conversion cost to add when promoting from a scalar to + // a CoopVector (this will be added to the cost, if any, of converting + // the element type of the CoopVector) + kConversionCost_ScalarToCoopVector = 1, + // Additional cost when casting an LValue. kConversionCost_LValueCast = 800, diff --git a/source/slang/slang-ast-type.cpp b/source/slang/slang-ast-type.cpp index 23d89957a..29a52a93a 100644 --- a/source/slang/slang-ast-type.cpp +++ b/source/slang/slang-ast-type.cpp @@ -355,6 +355,29 @@ Type* AtomicType::getElementType() return as<Type>(_getGenericTypeArg(this, 0)); } +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! CoopVectorExpressionType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + +Type* CoopVectorExpressionType::getElementType() +{ + return as<Type>(_getGenericTypeArg(this, 0)); +} + +IntVal* CoopVectorExpressionType::getElementCount() +{ + return as<IntVal>(_getGenericTypeArg(this, 1)); +} + +void CoopVectorExpressionType::_toTextOverride(StringBuilder& out) +{ + out << toSlice("CoopVector<") << getElementType() << toSlice(",") << getElementCount() + << toSlice(">"); +} + +BasicExpressionType* CoopVectorExpressionType::_getScalarTypeOverride() +{ + return as<BasicExpressionType>(getElementType()); +} + // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! TypeType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! 
void TypeType::_toTextOverride(StringBuilder& out) diff --git a/source/slang/slang-ast-type.h b/source/slang/slang-ast-type.h index 9d4297919..f60c0485e 100644 --- a/source/slang/slang-ast-type.h +++ b/source/slang/slang-ast-type.h @@ -475,6 +475,17 @@ class AtomicType : public DeclRefType Type* getElementType(); }; +class CoopVectorExpressionType : public ArithmeticExpressionType +{ + SLANG_AST_CLASS(CoopVectorExpressionType) + + void _toTextOverride(StringBuilder& out); + BasicExpressionType* _getScalarTypeOverride(); + + Type* getElementType(); + IntVal* getElementCount(); +}; + // The "type" of an expression that resolves to a type. // For example, in the expression `float(2)` the sub-expression, // `float` would have the type `TypeType(float)`. diff --git a/source/slang/slang-capabilities.capdef b/source/slang/slang-capabilities.capdef index 3bc54c080..3be09b8d3 100644 --- a/source/slang/slang-capabilities.capdef +++ b/source/slang/slang-capabilities.capdef @@ -208,6 +208,7 @@ def _sm_6_4 : _sm_6_3; def _sm_6_5 : _sm_6_4; def _sm_6_6 : _sm_6_5; def _sm_6_7 : _sm_6_6; +def _sm_6_8 : _sm_6_7; /// Represents HLSL NVAPI support. /// [Version] @@ -524,6 +525,14 @@ def SPV_NV_compute_shader_derivatives : _spirv_1_0; /// [EXT] def SPV_GOOGLE_user_type : _spirv_1_0; +/// Represents the SPIR-V extension for SPV_EXT_replicated_composites. +/// [EXT] +def SPV_EXT_replicated_composites : _spirv_1_0; + +/// Represents the SPIR-V extension for SPV_NV_cooperative_vector. +/// [EXT] +def SPV_NV_cooperative_vector : _spirv_1_6 + SPV_EXT_replicated_composites; + // SPIRV Capabilities. /// Represents the SPIR-V capability for atomic float 32 add operations. 
@@ -662,6 +671,18 @@ def spvDemoteToHelperInvocationEXT : SPV_EXT_demote_to_helper_invocation; /// [EXT] def spvDemoteToHelperInvocation : spvDemoteToHelperInvocationEXT; +/// Represents the SPIR-V capability for replicated composites +/// [EXT] +def spvReplicatedCompositesEXT : SPV_EXT_replicated_composites; + +/// Represents the SPIR-V capability for cooperative vectors +/// [EXT] +def spvCooperativeVectorNV : SPV_NV_cooperative_vector; + +/// Represents the SPIR-V capability for cooperative vector training +/// [EXT] +def spvCooperativeVectorTrainingNV : SPV_NV_cooperative_vector; + /// Represents the SPIR-V capability for maximal reconvergence. /// [EXT] def spvMaximalReconvergenceKHR : SPV_KHR_maximal_reconvergence; @@ -1033,6 +1054,14 @@ alias bufferreference = GL_EXT_buffer_reference; /// Capabilities needed to use GLSL buffer-reference's with int64 /// [Compound] alias bufferreference_int64 = bufferreference + GL_EXT_shader_explicit_arithmetic_types_int64; +/// Capabilities needed to use cooperative vectors +/// Note that cpp and cuda are supported via a fallback non-cooperative implementation +/// No HLSL shader model bound yet +/// [Compound] +alias cooperative_vector = _sm_6_8 | cpp | _cuda_sm_9_0 | spvCooperativeVectorNV; +/// Capabilities needed to train cooperative vectors +/// [Compound] +alias cooperative_vector_training = spvCooperativeVectorTrainingNV; // Non-internal shader stages // @@ -1479,6 +1508,25 @@ alias sm_6_7 = sm_6_7_version | sm_6_6 ; +/// HLSL shader model 6.8 and related capabilities of other targets. +/// Does not include related GLSL/SPIRV extensions. +/// [Version] +alias sm_6_8_version = _sm_6_8 + | _GLSL_460 + | spirv_1_5 + | _cuda_sm_9_0 + | metal + | cpp + ; + +/// HLSL shader model 6.8 and related capabilities of other targets. +/// Includes related GLSL/SPIRV extensions. 
+/// [Version] +alias sm_6_8 = sm_6_8_version + | sm_6_7 + ; + +// Profiles /// Use `sm_4_0` instead /// [Other] alias DX_4_0 = sm_4_0; @@ -1527,6 +1575,10 @@ alias DX_6_6 = sm_6_6; /// [Other] alias DX_6_7 = sm_6_7; +/// Use `sm_6_8` instead +/// [Other] +alias DX_6_8 = sm_6_8; + // GLSL profile capabilities // /// GLSL 130 and related capabilities of other targets. @@ -2042,10 +2094,10 @@ alias ser_motion_raygen = raygen + ser_motion; /// User should not use this capability /// [Other] -alias all = _sm_6_7 + hlsl_nvapi - | sm_6_7 +alias all = _sm_6_8 + hlsl_nvapi + | glsl_spirv_1_5 + sm_6_8 + ser + shaderclock + texturefootprint + fragmentshaderinterlock + _GL_NV_shader_subgroup_partitioned + _GL_NV_ray_tracing_motion_blur + _GL_NV_shader_texture_footprint - | spirv_1_5 + sm_6_7 + | spirv_1_5 + sm_6_8 + ser + shaderclock + texturefootprint + fragmentshaderinterlock + spvGroupNonUniformPartitionedNV + spvRayTracingMotionBlurNV + spvRayTracingMotionBlurNV; diff --git a/source/slang/slang-check-conversion.cpp b/source/slang/slang-check-conversion.cpp index 248c83fe5..db8548dce 100644 --- a/source/slang/slang-check-conversion.cpp +++ b/source/slang/slang-check-conversion.cpp @@ -34,6 +34,8 @@ BuiltinConversionKind SemanticsVisitor::getImplicitConversionBuiltinKind(Decl* d bool SemanticsVisitor::isEffectivelyScalarForInitializerLists(Type* type) { + if (as<CoopVectorExpressionType>(type)) + return false; if (as<ArrayExpressionType>(type)) return false; if (as<VectorExpressionType>(type)) @@ -282,6 +284,50 @@ bool SemanticsVisitor::_readAggregateValueFromInitializerList( } } } + else if (auto toCoopVectorType = as<CoopVectorExpressionType>(toType)) + { + auto toElementCount = toCoopVectorType->getElementCount(); + auto toElementType = toCoopVectorType->getElementType(); + + UInt elementCount = 0; + if (auto constElementCount = as<ConstantIntVal>(toElementCount)) + { + elementCount = (UInt)constElementCount->getValue(); + } + else + { + // We don't know the element 
count statically, + // so what are we supposed to be doing? + // + if (outToExpr) + { + getSink()->diagnose( + fromInitializerListExpr, + Diagnostics::cannotUseInitializerListForCoopVectorOfUnknownSize, + toElementCount); + } + return false; + } + + for (UInt ee = 0; ee < elementCount; ++ee) + { + Expr* coercedArg = nullptr; + bool argResult = _readValueFromInitializerList( + toElementType, + outToExpr ? &coercedArg : nullptr, + fromInitializerListExpr, + ioArgIndex); + + // No point in trying further if any argument fails + if (!argResult) + return false; + + if (coercedArg) + { + coercedArgs.add(coercedArg); + } + } + } else if (auto toArrayType = as<ArrayExpressionType>(toType)) { // TODO(tfoley): If we can compute the size of the array statically, diff --git a/source/slang/slang-compiler.cpp b/source/slang/slang-compiler.cpp index 5ace72708..448534ce8 100644 --- a/source/slang/slang-compiler.cpp +++ b/source/slang/slang-compiler.cpp @@ -871,6 +871,7 @@ String GetHLSLProfileName(Profile profile) CASE(DX_6_5, _6_5); CASE(DX_6_6, _6_6); CASE(DX_6_7, _6_7); + CASE(DX_6_8, _6_8); #undef CASE default: diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h index 345be7a54..682980107 100644 --- a/source/slang/slang-diagnostic-defs.h +++ b/source/slang/slang-diagnostic-defs.h @@ -1415,6 +1415,11 @@ DIAGNOSTIC( Error, cannotUseInitializerListForType, "cannot use initializer list for type '$0'") +DIAGNOSTIC( + 30505, + Error, + cannotUseInitializerListForCoopVectorOfUnknownSize, + "cannot use initializer list for CoopVector of statically unknown size '$0'") // 3062x: variables DIAGNOSTIC( diff --git a/source/slang/slang-emit-c-like.cpp b/source/slang/slang-emit-c-like.cpp index 180ef9909..db2c0150f 100644 --- a/source/slang/slang-emit-c-like.cpp +++ b/source/slang/slang-emit-c-like.cpp @@ -1469,6 +1469,8 @@ bool CLikeSourceEmitter::shouldFoldInstIntoUseSites(IRInst* inst) case kIROp_MakeArray: case kIROp_swizzleSet: case 
kIROp_MakeArrayFromElement: + case kIROp_MakeCoopVector: + return false; } @@ -2362,6 +2364,7 @@ void CLikeSourceEmitter::defaultEmitInstExpr(IRInst* inst, const EmitOpInfo& inO emitSimpleValue(inst); break; + case kIROp_MakeCoopVector: case kIROp_MakeVector: case kIROp_MakeMatrix: case kIROp_VectorReshape: @@ -3096,6 +3099,51 @@ void CLikeSourceEmitter::_emitInst(IRInst* inst) m_writer->advanceToSourceLocation(inst->sourceLoc); } + if (auto coopVecType = as<IRCoopVectorType>(inst->getDataType())) + { + switch (inst->getOp()) + { + case kIROp_MakeCoopVector: + { + emitType(coopVecType, getName(inst)); + m_writer->emit(";\n"); + + auto elemCount = as<IRIntLit>(coopVecType->getOperand(1)); + IRIntegerValue elemCountValue = elemCount->getValue(); + for (IRIntegerValue i = 0; i < elemCountValue; ++i) + { + m_writer->emit(getName(inst)); + m_writer->emit(".WriteToIndex("); + m_writer->emit(i); + m_writer->emit(", "); + emitDereferenceOperand(inst->getOperand(i), getInfo(EmitOp::General)); + m_writer->emit(");\n"); + } + return; + } + case kIROp_Call: + emitType(coopVecType, getName(inst)); + m_writer->emit(";\n"); + + m_writer->emit(getName(inst)); + m_writer->emit(".CopyFrom("); + emitCallExpr((IRCall*)inst, getInfo(EmitOp::General)); + m_writer->emit(");\n"); + return; + case kIROp_Load: + emitType(coopVecType, getName(inst)); + m_writer->emit(";\n"); + + m_writer->emit(getName(inst)); + m_writer->emit(".CopyFrom("); + emitDereferenceOperand(inst->getOperand(0), getInfo(EmitOp::General)); + m_writer->emit(");\n"); + return; + default: + break; + } + } + switch (inst->getOp()) { default: @@ -3342,11 +3390,21 @@ void CLikeSourceEmitter::_emitStoreImpl(IRStore* store) { auto srcVal = store->getVal(); auto dstPtr = store->getPtr(); - auto prec = getInfo(EmitOp::Assign); - emitDereferenceOperand(dstPtr, leftSide(getInfo(EmitOp::General), prec)); - m_writer->emit(" = "); - emitOperand(srcVal, rightSide(prec, getInfo(EmitOp::General))); - m_writer->emit(";\n"); + if 
(isPointerOfType(dstPtr->getDataType(), kIROp_CoopVectorType)) + { + emitDereferenceOperand(dstPtr, getInfo(EmitOp::General)); + m_writer->emit(".CopyFrom("); + emitDereferenceOperand(srcVal, getInfo(EmitOp::General)); + m_writer->emit(");\n"); + } + else + { + auto prec = getInfo(EmitOp::Assign); + emitDereferenceOperand(dstPtr, leftSide(getInfo(EmitOp::General), prec)); + m_writer->emit(" = "); + emitOperand(srcVal, rightSide(prec, getInfo(EmitOp::General))); + m_writer->emit(";\n"); + } } void CLikeSourceEmitter::_emitInstAsDefaultInitializedVar(IRInst* inst, IRType* type) @@ -4627,7 +4685,45 @@ void CLikeSourceEmitter::emitVar(IRVar* varDecl) { if (store->getPtr() == varDecl) { - _emitInstAsVarInitializerImpl(store->getVal()); + const bool isCoopVectorType = varType->getOp() == kIROp_CoopVectorType; + if (isCoopVectorType && store->getVal()->getOp() == kIROp_Load) + { + m_writer->emit(";\n"); + m_writer->emit(getName(varDecl)); + m_writer->emit(".CopyFrom("); + emitDereferenceOperand(store->getVal()->getOperand(0), getInfo(EmitOp::General)); + m_writer->emit(")"); + } + else if (isCoopVectorType && store->getVal()->getOp() == kIROp_Call) + { + m_writer->emit(";\n"); + m_writer->emit(getName(varDecl)); + m_writer->emit(".CopyFrom("); + emitCallExpr((IRCall*)store->getVal(), getInfo(EmitOp::General)); + m_writer->emit(")"); + } + else if (isCoopVectorType && store->getVal()->getOp() == kIROp_MakeCoopVector) + { + auto coopVecType = as<IRCoopVectorType>(store->getVal()->getDataType()); + auto elemCount = as<IRIntLit>(coopVecType->getOperand(1)); + IRIntegerValue elemCountValue = elemCount->getValue(); + for (IRIntegerValue i = 0; i < elemCountValue; ++i) + { + m_writer->emit(";\n"); + m_writer->emit(getName(varDecl)); + m_writer->emit(".WriteToIndex("); + m_writer->emit(i); + m_writer->emit(", "); + emitDereferenceOperand( + store->getVal()->getOperand(i), + getInfo(EmitOp::General)); + m_writer->emit(")"); + } + } + else + { + 
_emitInstAsVarInitializerImpl(store->getVal()); + } } } diff --git a/source/slang/slang-emit-glsl.cpp b/source/slang/slang-emit-glsl.cpp index 61457d98e..1429139f9 100644 --- a/source/slang/slang-emit-glsl.cpp +++ b/source/slang/slang-emit-glsl.cpp @@ -3370,6 +3370,23 @@ void GLSLSourceEmitter::emitSimpleTypeImpl(IRType* type) return; } + else if (auto specializedType = as<IRSpecialize>(type)) + { + // If a `specialize` instruction made it this far, then + // it represents an intrinsic generic type. + // + emitSimpleType((IRType*)getSpecializedValue(specializedType)); + m_writer->emit("<"); + UInt argCount = specializedType->getArgCount(); + for (UInt ii = 0; ii < argCount; ++ii) + { + if (ii != 0) + m_writer->emit(", "); + emitVal(specializedType->getArg(ii), getInfo(EmitOp::General)); + } + m_writer->emit(" >"); + return; + } auto decorated = getResolvedInstForDecorations(type); UnownedStringSlice intrinsicDef; diff --git a/source/slang/slang-emit-hlsl.cpp b/source/slang/slang-emit-hlsl.cpp index 7bd3bb3db..53545b31f 100644 --- a/source/slang/slang-emit-hlsl.cpp +++ b/source/slang/slang-emit-hlsl.cpp @@ -765,6 +765,7 @@ bool HLSLSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOu m_writer->emit("GroupMemoryBatrierWithGroupSync();\n"); return true; } + case kIROp_MakeCoopVector: case kIROp_MakeVector: case kIROp_MakeMatrix: { @@ -1359,6 +1360,16 @@ void HLSLSourceEmitter::emitSimpleTypeImpl(IRType* type) emitSimpleTypeImpl(cast<IRAtomicType>(type)->getElementType()); return; } + case kIROp_CoopVectorType: + { + auto coopVecType = (IRCoopVectorType*)type; + m_writer->emit("CoopVector<"); + emitType(coopVecType->getElementType()); + m_writer->emit(","); + m_writer->emit(getIntVal(coopVecType->getElementCount())); + m_writer->emit(">"); + return; + } case kIROp_ConstRefType: { emitSimpleTypeImpl(as<IRConstRefType>(type)->getValueType()); diff --git a/source/slang/slang-emit-hlsl.h b/source/slang/slang-emit-hlsl.h index 6b99a7f50..07c7af3d2 100644 
--- a/source/slang/slang-emit-hlsl.h +++ b/source/slang/slang-emit-hlsl.h @@ -82,7 +82,7 @@ protected: virtual void emitPostKeywordTypeAttributesImpl(IRInst* inst) SLANG_OVERRIDE; - void _emitPrefixTypeAttr(IRAttr* attr) SLANG_OVERRIDE; + virtual void _emitPrefixTypeAttr(IRAttr* attr) SLANG_OVERRIDE; // Emit a single `register` semantic, as appropriate for a given resource-type-specific layout // info Keyword to use in the uniform case (`register` for globals, `packoffset` inside a diff --git a/source/slang/slang-emit-spirv-ops.h b/source/slang/slang-emit-spirv-ops.h index 1f2996646..8c5316f51 100644 --- a/source/slang/slang-emit-spirv-ops.h +++ b/source/slang/slang-emit-spirv-ops.h @@ -138,6 +138,19 @@ SpvInst* emitOpTypeVector( componentCount); } +template<typename T1, typename T2> +SpvInst* emitOpTypeCoopVec(IRInst* inst, const T1& componentType, const T2& componentCount) +{ + static_assert(isSingular<T1>); + return emitInstMemoized( + getSection(SpvLogicalSectionID::ConstantsAndTypes), + inst, + SpvOpTypeCooperativeVectorNV, + kResultID, + componentType, + componentCount); +} + // https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpTypeMatrix template<typename T> SpvInst* emitOpTypeMatrix(IRInst* inst, const T& columnType, const SpvLiteralInteger& columnCount) diff --git a/source/slang/slang-emit-spirv.cpp b/source/slang/slang-emit-spirv.cpp index 24d8cc0c6..2c36ae5f7 100644 --- a/source/slang/slang-emit-spirv.cpp +++ b/source/slang/slang-emit-spirv.cpp @@ -1628,6 +1628,16 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex static_cast<IRIntLit*>(vectorType->getElementCount())->getValue(), vectorType); } + case kIROp_CoopVectorType: + { + auto coopVecType = static_cast<IRCoopVectorType*>(inst); + requireSPIRVCapability(SpvCapabilityCooperativeVectorNV); + ensureExtensionDeclaration(UnownedStringSlice("SPV_NV_cooperative_vector")); + return ensureCoopVecType( + 
static_cast<IRBasicType*>(coopVecType->getElementType())->getBaseType(), + static_cast<IRIntLit*>(coopVecType->getElementCount())->getValue(), + coopVecType); + } case kIROp_MatrixType: { auto matrixType = static_cast<IRMatrixType*>(inst); @@ -1778,6 +1788,7 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex numElems->getValue()); } case kIROp_MakeVector: + case kIROp_MakeCoopVector: case kIROp_MakeArray: case kIROp_MakeStruct: return emitCompositeConstruct(getSection(SpvLogicalSectionID::ConstantsAndTypes), inst); @@ -2361,6 +2372,27 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex return result; } + /// Similar to ensureVectorType but for CoopVecType + SpvInst* ensureCoopVecType( + BaseType baseType, + IRIntegerValue elementCount, + IRCoopVectorType* inst) + { + IRBuilder builder(m_irModule); + if (!inst) + { + builder.setInsertInto(m_irModule->getModuleInst()); + inst = builder.getCoopVectorType( + builder.getBasicType(baseType), + builder.getIntValue(builder.getIntType(), elementCount)); + } + auto result = emitOpTypeCoopVec( + inst, + inst->getElementType(), + emitIntConstant(elementCount, builder.getIntType())); + return result; + } + bool _maybeEmitInterpolationModifierDecoration(IRInterpolationMode mode, SpvId varInst) { switch (mode) @@ -3417,6 +3449,10 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex case kIROp_StructuredBufferGetDimensions: result = emitStructuredBufferGetDimensions(parent, inst); break; + case kIROp_GetStructuredBufferPtr: + case kIROp_GetUntypedBufferPtr: + result = emitGetBufferPtr(parent, inst); + break; case kIROp_swizzle: result = emitSwizzle(parent, as<IRSwizzle>(inst)); break; @@ -3711,6 +3747,9 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex result = emitSplat(parent, inst, scalar, numElems->getValue()); } break; + case kIROp_MakeCoopVector: + result = emitConstruct(parent, inst); + break; 
case kIROp_MakeArray: result = emitConstruct(parent, inst); break; @@ -6079,7 +6118,7 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex const auto baseTy = base->getDataType(); SLANG_ASSERT( as<IRPointerLikeType>(baseTy) || as<IRArrayType>(baseTy) || as<IRVectorType>(baseTy) || - as<IRMatrixType>(baseTy)); + as<IRCoopVectorType>(baseTy) || as<IRMatrixType>(baseTy)); IRBuilder builder(m_irModule); builder.setInsertBefore(inst); @@ -6097,7 +6136,7 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex } else { - SLANG_ASSERT(as<IRVectorType>(baseTy)); + SLANG_ASSERT(as<IRVectorType>(baseTy) || as<IRCoopVectorType>(baseTy)); // SPIRV Only allows dynamic element extract on vector types. return emitOpVectorExtractDynamic( parent, @@ -6307,6 +6346,27 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex return result; } + SpvInst* emitGetBufferPtr(SpvInstParent* parent, IRInst* inst) + { + IRBuilder builder(inst); + auto addressSpace = + isSpirv14OrLater() ? AddressSpace::StorageBuffer : AddressSpace::Uniform; + // The buffer is a global parameter, so it's a pointer + IRPtrTypeBase* bufPtrType = cast<IRPtrTypeBase>(inst->getOperand(0)->getDataType()); + // It's lowered to a struct type.. 
+ IRStructType* bufType = cast<IRStructType>(bufPtrType->getValueType()); + // containing an unsized array, specifically one with an explicit + // stride, which is not expressible in spirv_asm blocks + IRArrayTypeBase* arrayType = + cast<IRArrayTypeBase>(bufType->getFields().getFirst()->getFieldType()); + return emitOpAccessChain( + parent, + inst, + builder.getPtrType(arrayType, addressSpace), + inst->getOperand(0), + makeArray(emitIntConstant(0, builder.getIntType()))); + } + SpvInst* emitSwizzle(SpvInstParent* parent, IRSwizzle* inst) { if (inst->getElementCount() == 1) @@ -6478,7 +6538,8 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex IRType* toType = nullptr; bool isMatrixCast = false; - if (as<IRVectorType>(fromTypeV) || as<IRVectorType>(toTypeV)) + if (as<IRVectorType>(fromTypeV) || as<IRVectorType>(toTypeV) || + as<IRCoopVectorType>(fromTypeV) || as<IRCoopVectorType>(toTypeV)) { fromType = getVectorElementType(fromTypeV); toType = getVectorElementType(toTypeV); diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp index cf4070df4..d8594ffdd 100644 --- a/source/slang/slang-emit.cpp +++ b/source/slang/slang-emit.cpp @@ -66,6 +66,7 @@ #include "slang-ir-lower-bit-cast.h" #include "slang-ir-lower-buffer-element-type.h" #include "slang-ir-lower-combined-texture-sampler.h" +#include "slang-ir-lower-coopvec.h" #include "slang-ir-lower-dynamic-resource-heap.h" #include "slang-ir-lower-generics.h" #include "slang-ir-lower-glsl-ssbo-types.h" @@ -1016,6 +1017,16 @@ Result linkAndOptimizeIR( #endif validateIRModuleIfEnabled(codeGenContext, irModule); + switch (target) + { + case CodeGenTarget::SPIRV: + case CodeGenTarget::SPIRVAssembly: + case CodeGenTarget::HLSL: + break; + default: + lowerCooperativeVectors(irModule, sink); + } + // Inline calls to any functions marked with [__unsafeInlineEarly] or [ForceInline]. 
performForceInlining(irModule); @@ -1476,14 +1487,15 @@ Result linkAndOptimizeIR( { default: break; + case CodeGenTarget::HLSL: case CodeGenTarget::GLSL: case CodeGenTarget::WGSL: - moveGlobalVarInitializationToEntryPoints(irModule); + moveGlobalVarInitializationToEntryPoints(irModule, targetProgram); break; // For SPIR-V to SROA across 2 entry-points a value must not be a global case CodeGenTarget::SPIRV: case CodeGenTarget::SPIRVAssembly: - moveGlobalVarInitializationToEntryPoints(irModule); + moveGlobalVarInitializationToEntryPoints(irModule, targetProgram); if (targetProgram->getOptionSet().getBoolOption( CompilerOptionName::EnableExperimentalPasses)) introduceExplicitGlobalContext(irModule, target); @@ -1494,7 +1506,7 @@ Result linkAndOptimizeIR( case CodeGenTarget::Metal: case CodeGenTarget::CPPSource: case CodeGenTarget::CUDASource: - moveGlobalVarInitializationToEntryPoints(irModule); + moveGlobalVarInitializationToEntryPoints(irModule, targetProgram); introduceExplicitGlobalContext(irModule, target); if (target == CodeGenTarget::CPPSource) { diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp index 7507e2fac..d65a22e77 100644 --- a/source/slang/slang-ir-autodiff.cpp +++ b/source/slang/slang-ir-autodiff.cpp @@ -2619,6 +2619,7 @@ bool canTypeBeStored(IRInst* type) case kIROp_ClassType: case kIROp_FloatType: case kIROp_VectorType: + case kIROp_CoopVectorType: case kIROp_MatrixType: case kIROp_BackwardDiffIntermediateContextType: return true; diff --git a/source/slang/slang-ir-constexpr.cpp b/source/slang/slang-ir-constexpr.cpp index 6c297fd30..ff6d64319 100644 --- a/source/slang/slang-ir-constexpr.cpp +++ b/source/slang/slang-ir-constexpr.cpp @@ -104,6 +104,7 @@ bool opCanBeConstExpr(IROp op) case kIROp_MakeMatrix: case kIROp_MakeMatrixFromScalar: case kIROp_MatrixReshape: + case kIROp_MakeCoopVector: case kIROp_VectorReshape: case kIROp_CastFloatToInt: case kIROp_CastIntToFloat: diff --git 
a/source/slang/slang-ir-explicit-global-init.cpp b/source/slang/slang-ir-explicit-global-init.cpp index 97a04f48a..8945f0c5f 100644 --- a/source/slang/slang-ir-explicit-global-init.cpp +++ b/source/slang/slang-ir-explicit-global-init.cpp @@ -39,6 +39,7 @@ namespace Slang struct MoveGlobalVarInitializationToEntryPointsPass { IRModule* m_module; + TargetProgram* m_targetProgram; // In the Slang IR, a global variable represents a pointer // to the storage for the variable but it *also* encodes @@ -66,9 +67,10 @@ struct MoveGlobalVarInitializationToEntryPointsPass }; List<GlobalVarInfo> m_globalVarsWithInit; - void processModule(IRModule* module) + void processModule(IRModule* module, TargetProgram* targetProgram) { m_module = module; + m_targetProgram = targetProgram; // We start by looking for global variables with // initialization logic in the IR, and processing @@ -113,8 +115,32 @@ struct MoveGlobalVarInitializationToEntryPointsPass } } + bool shouldMoveGlobalVarInitialization(IRGlobalVar* globalVar) + { + // Currently CoopVector for DXC cannot be created from + // constructors with arguments. When CoopVector is used as a + // global variable, its initialization has to happen at the + // beginning of the entry point. + // + // At the same time, we don't want to apply + // "moveGlobalVarInitializationToEntryPoints" to the rest of + // the global variables when targeting HLSL. 
+ // + if (isD3DTarget(m_targetProgram->getTargetReq())) + { + auto valueType = globalVar->getDataType()->getValueType(); + if (as<IRCoopVectorType>(valueType)) + return true; + return false; + } + return true; + } + void processGlobalVarWithInit(IRGlobalVar* globalVar, IRBlock* firstBlock) { + if (!shouldMoveGlobalVarInitialization(globalVar)) + return; + IRBuilder builder(m_module); builder.setInsertBefore(globalVar); @@ -216,10 +242,10 @@ struct MoveGlobalVarInitializationToEntryPointsPass }; /// Move initialization logic off of global variables and onto each entry point -void moveGlobalVarInitializationToEntryPoints(IRModule* module) +void moveGlobalVarInitializationToEntryPoints(IRModule* module, TargetProgram* targetProgram) { MoveGlobalVarInitializationToEntryPointsPass pass; - pass.processModule(module); + pass.processModule(module, targetProgram); } } // namespace Slang diff --git a/source/slang/slang-ir-explicit-global-init.h b/source/slang/slang-ir-explicit-global-init.h index 3244d90a2..8554a7e3b 100644 --- a/source/slang/slang-ir-explicit-global-init.h +++ b/source/slang/slang-ir-explicit-global-init.h @@ -4,7 +4,8 @@ namespace Slang { struct IRModule; +class TargetProgram; /// Move initialization logic off of global variables and onto each entry point -void moveGlobalVarInitializationToEntryPoints(IRModule* module); +void moveGlobalVarInitializationToEntryPoints(IRModule* module, TargetProgram* targetProgram); } // namespace Slang diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h index 591ffabf0..f1e9624f3 100644 --- a/source/slang/slang-ir-inst-defs.h +++ b/source/slang/slang-ir-inst-defs.h @@ -234,6 +234,7 @@ INST(Nop, nop, 0, 0) INST(RayQueryType, RayQuery, 1, HOISTABLE) INST(HitObjectType, HitObject, 0, HOISTABLE) +INST(CoopVectorType, CoopVectorType, 2, HOISTABLE) // Opaque type that can be dynamically cast to other resource types. 
INST(DynamicResourceType, DynamicResource, 0, HOISTABLE) @@ -363,6 +364,8 @@ INST(MatrixReshape, matrixReshape, 1, 0) INST(VectorReshape, vectorReshape, 1, 0) INST(MakeArray, makeArray, 0, 0) INST(MakeArrayFromElement, makeArrayFromElement, 1, 0) +INST(MakeCoopVector, makeCoopVector, 0, 0) +INST(MakeCoopVectorFromValuePack, makeCoopVectorFromValuePack, 1, 0) INST(MakeStruct, makeStruct, 0, 0) INST(MakeTuple, makeTuple, 0, 0) INST(MakeTargetTuple, makeTuple, 0, 0) @@ -1250,6 +1253,11 @@ INST(CudaKernelLaunch, CudaKernelLaunch, 6, 0) // Converts other resources (such as ByteAddressBuffer) to the equivalent StructuredBuffer INST(GetEquivalentStructuredBuffer, getEquivalentStructuredBuffer, 1, 0) +// Gets a T[] pointer to the underlying data of a StructuredBuffer etc... +INST(GetStructuredBufferPtr, getStructuredBufferPtr, 1, 0) +// Gets a uint[] pointer to the underlying data of a ByteAddressBuffer etc... +INST(GetUntypedBufferPtr, getUntypedBufferPtr, 1, 0) + /* Layout */ INST(VarLayout, varLayout, 1, HOISTABLE) diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index caee45ba8..a883172ff 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -3018,6 +3018,11 @@ struct IRMakeVectorFromScalar : IRInst IR_LEAF_ISA(MakeVectorFromScalar) }; +struct IRMakeCoopVector : IRInst +{ + IR_LEAF_ISA(MakeCoopVector) +}; + // An Instruction that creates a differential pair value from a // primal and differential. 
@@ -3735,6 +3740,8 @@ public: IRVectorType* getVectorType(IRType* elementType, IRIntegerValue elementCount); + IRCoopVectorType* getCoopVectorType(IRType* elementType, IRInst* elementCount); + IRMatrixType* getMatrixType( IRType* elementType, IRInst* rowCount, @@ -4056,6 +4063,9 @@ public: IRInst* emitGetTupleElement(IRType* type, IRInst* tuple, UInt element); IRInst* emitGetTupleElement(IRType* type, IRInst* tuple, IRInst* element); + IRInst* emitGetElement(IRType* type, IRInst* arrayLikeType, IRIntegerValue element); + IRInst* emitGetElementPtr(IRType* type, IRInst* arrayLikeType, IRIntegerValue element); + IRInst* emitMakeResultError(IRType* resultType, IRInst* errorVal); IRInst* emitMakeResultValue(IRType* resultType, IRInst* val); IRInst* emitIsResultError(IRInst* result); @@ -4094,6 +4104,8 @@ public: IRInst* emitMakeMatrixFromScalar(IRType* type, IRInst* scalarValue); + IRInst* emitMakeCoopVector(IRType* type, UInt argCount, IRInst* const* args); + IRInst* emitMakeArray(IRType* type, UInt argCount, IRInst* const* args); IRInst* emitMakeArrayList(IRType* type, UInt argCount, IRInst* const* args); diff --git a/source/slang/slang-ir-lower-coopvec.cpp b/source/slang/slang-ir-lower-coopvec.cpp new file mode 100644 index 000000000..c8b24597c --- /dev/null +++ b/source/slang/slang-ir-lower-coopvec.cpp @@ -0,0 +1,234 @@ +#include "slang-ir-lower-coopvec.h" + +#include "slang-ir-insts.h" +#include "slang-ir.h" + +namespace Slang +{ +struct CoopVecLoweringContext +{ + IRModule* module; + DiagnosticSink* sink; + + InstWorkList workList; + InstHashSet workListSet; + + CoopVecLoweringContext(IRModule* inModule) + : module(inModule), workList(inModule), workListSet(inModule) + { + } + + struct LoweredCoopVecInfo : public RefObject + { + IRType* coopvecType; + IRArrayType* arrayType; + }; + Dictionary<IRInst*, RefPtr<LoweredCoopVecInfo>> mapLoweredArrayToCoopVecInfo; + Dictionary<IRInst*, RefPtr<LoweredCoopVecInfo>> loweredCoopVecs; + + IRType* 
maybeLowerCoopVecType(IRBuilder* builder, IRType* type) + { + if (const auto cvt = as<IRCoopVectorType>(type)) + { + if (auto info = getLoweredCoopVecType(builder, cvt)) + return info->arrayType; + } + return type; + } + + LoweredCoopVecInfo* getLoweredCoopVecType(IRBuilder* builder, IRCoopVectorType* type) + { + if (auto loweredInfo = loweredCoopVecs.tryGetValue(type)) + return loweredInfo->Ptr(); + if (auto loweredInfo = mapLoweredArrayToCoopVecInfo.tryGetValue(type)) + return loweredInfo->Ptr(); + + if (!type) + return nullptr; + + RefPtr<LoweredCoopVecInfo> info = new LoweredCoopVecInfo(); + info->coopvecType = (IRType*)type; + info->arrayType = builder->getArrayType(type->getElementType(), type->getElementCount()); + builder->addNameHintDecoration(info->arrayType, UnownedStringSlice("CoopVec")); + + mapLoweredArrayToCoopVecInfo[info->arrayType] = info; + loweredCoopVecs[type] = info; + return info.Ptr(); + } + + void addToWorkList(IRInst* inst) + { + for (auto ii = inst->getParent(); ii; ii = ii->getParent()) + { + if (as<IRGeneric>(ii)) + return; + } + + if (workListSet.contains(inst)) + return; + + workList.add(inst); + workListSet.add(inst); + } + + void processMakeCoopVec(IRInst* inst) + { + IRBuilder builderStorage(module); + auto builder = &builderStorage; + builder->setInsertBefore(inst); + + const auto cvt = as<IRCoopVectorType>(inst->getDataType()); + SLANG_ASSERT(cvt); + auto info = getLoweredCoopVecType(builder, cvt); + List<IRInst*> operands; + operands.setCount(Index(inst->getOperandCount())); + UIndex i = 0; + for (auto operand = inst->getOperands(); i < inst->getOperandCount(); operand++, i++) + operands[Index(i)] = operand->get(); + auto makeArray = + builder->emitMakeArray(info->arrayType, operands.getCount(), operands.begin()); + inst->replaceUsesWith(makeArray); + inst->removeAndDeallocate(); + } + + void processGetCoopVecElement(IRGetElement*) {} + + void processGetElementPtr(IRGetElementPtr*) {} + + void processGetElement(IRGetElement*) {} 
+ + void processCoopVecType(IRCoopVectorType* inst) + { + IRBuilder builderStorage(module); + auto builder = &builderStorage; + builder->setInsertBefore(inst); + + auto loweredCoopVecInfo = getLoweredCoopVecType(builder, inst); + SLANG_ASSERT(loweredCoopVecInfo); + SLANG_UNUSED(loweredCoopVecInfo); + } + + void processUpdateElement(IRUpdateElement*) {} + + void processEntrywiseOp(IRInst* inst) + { + SLANG_ASSERT(inst->getOperandCount()); + if (!as<IRCoopVectorType>(inst->getDataType())) + return; + List<IRInst*> operands; + IRIntegerValue width = 0; + IRType* resultElementType = nullptr; + UIndex opIndex = 0; + for (auto operand = inst->getOperands(); opIndex < inst->getOperandCount(); + operand++, opIndex++) + { + operands.add(operand->get()); + if (const auto cv = as<IRCoopVectorType>(operand->get()->getDataType())) + { + width = getIntVal(cv->getElementCount()); + resultElementType = cv->getElementType(); + } + } + if (width == 0) + return; + IRBuilder builder(module); + IRType* resultElementPtrType = builder.getPtrType(resultElementType); + builder.setInsertBefore(inst); + const auto result = builder.emitVar(inst->getFullType()); + List<IRInst*> entrywiseOperands; + entrywiseOperands.setCount(operands.getCount()); + for (IRIntegerValue i = 0; i < width; ++i) + { + for (int j = 0; j < operands.getCount(); ++j) + { + if (const auto cv = as<IRCoopVectorType>(operands[j]->getDataType())) + { + SLANG_ASSERT(getIntVal(cv->getElementCount()) == width); + const auto elementType = cv->getElementType(); + entrywiseOperands[j] = builder.emitGetElement(elementType, operands[j], i); + } + else + { + entrywiseOperands[j] = operands[j]; + } + } + const auto x = builder.emitIntrinsicInst( + resultElementType, + inst->getOp(), + entrywiseOperands.getCount(), + entrywiseOperands.begin()); + const auto d = builder.emitGetElementPtr(resultElementPtrType, result, i); + builder.emitStore(d, x); + } + const auto v = builder.emitLoad(result); + inst->replaceUsesWith(v); + 
inst->removeAndDeallocate(); + } + + void processInst(IRInst* inst) + { + switch (inst->getOp()) + { + case kIROp_CoopVectorType: + processCoopVecType((IRCoopVectorType*)inst); + break; + case kIROp_MakeCoopVector: + processMakeCoopVec((IRMakeCoopVector*)inst); + break; + case kIROp_GetElement: + processGetElement((IRGetElement*)inst); + break; + case kIROp_GetElementPtr: + processGetElementPtr((IRGetElementPtr*)inst); + break; + case kIROp_UpdateElement: + processUpdateElement((IRUpdateElement*)inst); + break; + case kIROp_Neg: + case kIROp_Add: + case kIROp_Sub: + case kIROp_Mul: + case kIROp_Div: + case kIROp_IntCast: + case kIROp_FloatCast: + case kIROp_CastIntToFloat: + case kIROp_CastFloatToInt: + processEntrywiseOp(inst); + break; + default: + break; + } + } + + void processModule() + { + IRBuilder builder(module); + + addToWorkList(module->getModuleInst()); + + while (workList.getCount() != 0) + { + IRInst* inst = workList.getLast(); + + workList.removeLast(); + workListSet.remove(inst); + + processInst(inst); + + for (auto child = inst->getLastChild(); child; child = child->getPrevInst()) + addToWorkList(child); + } + + // Replace all coopvec types with sized array types + for (const auto& [key, value] : loweredCoopVecs) + key->replaceUsesWith(value->arrayType); + } +}; + +void lowerCooperativeVectors(IRModule* module, DiagnosticSink* sink) +{ + CoopVecLoweringContext context(module); + context.sink = sink; + context.processModule(); +} +} // namespace Slang diff --git a/source/slang/slang-ir-lower-coopvec.h b/source/slang/slang-ir-lower-coopvec.h new file mode 100644 index 000000000..60b8943d5 --- /dev/null +++ b/source/slang/slang-ir-lower-coopvec.h @@ -0,0 +1,12 @@ +#pragma once + +namespace Slang +{ + +struct IRModule; +class DiagnosticSink; + +/// Lower Cooperative Vectors to ordinary arrays +void lowerCooperativeVectors(IRModule* module, DiagnosticSink* sink); + +} // namespace Slang diff --git a/source/slang/slang-ir-peephole.cpp 
b/source/slang/slang-ir-peephole.cpp index ced0b5eb0..fc399954b 100644 --- a/source/slang/slang-ir-peephole.cpp +++ b/source/slang/slang-ir-peephole.cpp @@ -353,6 +353,31 @@ struct PeepholeContext : InstPassBase break; } break; + case kIROp_MakeCoopVectorFromValuePack: + { + const auto pack = inst->getOperand(0); + if (const auto packType = as<IRTypePack>(pack->getDataType())) + { + IRBuilder builder(inst); + builder.setInsertBefore(inst); + List<IRInst*> args; + for (UInt j = 0; j < packType->getOperandCount(); ++j) + { + const auto e = builder.emitGetTupleElement( + cast<IRType>(packType->getOperand(j)), + pack, + j); + args.add(e); + } + const auto cvt = builder.getCoopVectorType( + args[0]->getDataType(), + builder.getIntValue(builder.getIntType(), args.getCount())); + const auto v = builder.emitMakeCoopVector(cvt, args.getCount(), args.begin()); + inst->replaceUsesWith(v); + inst->removeAndDeallocate(); + } + } + break; case kIROp_FieldExtract: if (inst->getOperand(0)->getOp() == kIROp_MakeStruct) { @@ -1091,7 +1116,12 @@ struct PeepholeContext : InstPassBase break; auto type = inst->getOperand(0)->getDataType(); IRSizeAndAlignment sizeAlignment; - getNaturalSizeAndAlignment(targetProgram->getOptionSet(), type, &sizeAlignment); + const auto res = getNaturalSizeAndAlignment( + targetProgram->getOptionSet(), + type, + &sizeAlignment); + if (!SLANG_SUCCEEDED(res)) + break; IRBuilder builder(module); builder.setInsertBefore(inst); auto stride = diff --git a/source/slang/slang-ir-util.cpp b/source/slang/slang-ir-util.cpp index 3043e7e31..dd7819e1a 100644 --- a/source/slang/slang-ir-util.cpp +++ b/source/slang/slang-ir-util.cpp @@ -21,6 +21,8 @@ IRType* getVectorElementType(IRType* type) { if (auto vectorType = as<IRVectorType>(type)) return vectorType->getElementType(); + if (auto coopVecType = as<IRCoopVectorType>(type)) + return coopVecType->getElementType(); return type; } @@ -1388,6 +1390,7 @@ bool isZero(IRInst* inst) return as<IRFloatLit>(inst)->getValue() == 
0.0; case kIROp_BoolLit: return as<IRBoolLit>(inst)->getValue() == false; + case kIROp_MakeCoopVector: case kIROp_MakeVector: case kIROp_MakeVectorFromScalar: case kIROp_MakeMatrix: @@ -1422,6 +1425,7 @@ bool isOne(IRInst* inst) return as<IRFloatLit>(inst)->getValue() == 1.0; case kIROp_BoolLit: return as<IRBoolLit>(inst)->getValue(); + case kIROp_MakeCoopVector: case kIROp_MakeVector: case kIROp_MakeVectorFromScalar: case kIROp_MakeMatrix: diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index 32acc7baa..6a7564e67 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -2964,6 +2964,13 @@ IRVectorType* IRBuilder::getVectorType(IRType* elementType, IRIntegerValue eleme return getVectorType(elementType, getIntValue(getIntType(), elementCount)); } +IRCoopVectorType* IRBuilder::getCoopVectorType(IRType* elementType, IRInst* elementCount) +{ + IRInst* operands[] = {elementType, elementCount}; + return (IRCoopVectorType*) + getType(kIROp_CoopVectorType, sizeof(operands) / sizeof(operands[0]), operands); +} + IRMatrixType* IRBuilder::getMatrixType( IRType* elementType, IRInst* rowCount, @@ -3887,6 +3894,26 @@ IRInst* IRBuilder::emitDefaultConstruct(IRType* type, bool fallback) return nullptr; return emitIntrinsicInst(type, kIROp_MakeVectorFromScalar, 1, &inner); } + case kIROp_CoopVectorType: + { + auto coopVecType = as<IRCoopVectorType>(actualType); + if (auto count = as<IRIntLit>(coopVecType->getElementCount())) + { + auto element = emitDefaultConstruct(coopVecType->getElementType(), fallback); + if (!element) + return nullptr; + List<IRInst*> elements; + constexpr int maxCount = 4096; + if (count->getValue() > maxCount) + break; + for (IRIntegerValue i = 0; i < count->getValue(); i++) + { + elements.add(element); + } + return emitMakeCoopVector(type, elements.getCount(), elements.getBuffer()); + } + break; + } case kIROp_MatrixType: { auto inner = @@ -4171,6 +4198,18 @@ IRInst* IRBuilder::emitGetNativeString(IRInst* str) return 
emitIntrinsicInst(getNativeStringType(), kIROp_getNativeStr, 1, &str); } +IRInst* IRBuilder::emitGetElement(IRType* type, IRInst* arrayLikeType, IRIntegerValue element) +{ + IRInst* args[] = {arrayLikeType, getIntValue(getIntType(), element)}; + return emitIntrinsicInst(type, kIROp_GetElement, 2, args); +} + +IRInst* IRBuilder::emitGetElementPtr(IRType* type, IRInst* arrayLikeType, IRIntegerValue element) +{ + IRInst* args[] = {arrayLikeType, getIntValue(getIntType(), element)}; + return emitIntrinsicInst(type, kIROp_GetElementPtr, 2, args); +} + IRInst* IRBuilder::emitGetTupleElement(IRType* type, IRInst* tuple, IRInst* element) { IRInst* args[] = {tuple, element}; @@ -4345,6 +4384,11 @@ IRInst* IRBuilder::emitMakeMatrixFromScalar(IRType* type, IRInst* scalarValue) return emitIntrinsicInst(type, kIROp_MakeMatrixFromScalar, 1, &scalarValue); } +IRInst* IRBuilder::emitMakeCoopVector(IRType* type, UInt argCount, IRInst* const* args) +{ + return emitIntrinsicInst(type, kIROp_MakeCoopVector, argCount, args); +} + IRInst* IRBuilder::emitMakeArray(IRType* type, UInt argCount, IRInst* const* args) { return emitIntrinsicInst(type, kIROp_MakeArray, argCount, args); @@ -5187,6 +5231,10 @@ IRInst* IRBuilder::emitElementAddress(IRInst* basePtr, IRInst* index) { type = vectorType->getElementType(); } + else if (auto coopVecType = as<IRCoopVectorType>(valueType)) + { + type = coopVecType->getElementType(); + } else if (auto matrixType = as<IRMatrixType>(valueType)) { type = getVectorType(matrixType->getElementType(), matrixType->getColumnCount()); @@ -8143,6 +8191,7 @@ bool IRInst::mightHaveSideEffects(SideEffectAnalysisOptions options) case kIROp_GetAddr: case kIROp_GetValueFromBoundInterface: case kIROp_MakeUInt64: + case kIROp_MakeCoopVector: case kIROp_MakeVector: case kIROp_MakeMatrix: case kIROp_MakeMatrixFromScalar: @@ -8627,6 +8676,7 @@ bool isMovableInst(IRInst* inst) switch (inst->getOp()) { + case kIROp_MakeCoopVector: case kIROp_Add: case kIROp_Sub: case kIROp_Mul: 
diff --git a/source/slang/slang-ir.h b/source/slang/slang-ir.h index b29b3b815..8f53c9f14 100644 --- a/source/slang/slang-ir.h +++ b/source/slang/slang-ir.h @@ -1862,6 +1862,14 @@ struct IRHitObjectType : IRType IR_LEAF_ISA(HitObjectType) }; +struct IRCoopVectorType : IRType +{ + IRType* getElementType() { return (IRType*)getOperand(0); } + IRInst* getElementCount() { return getOperand(1); } + + IR_LEAF_ISA(CoopVectorType) +}; + bool isDefinition(IRInst* inVal); // A structure type is represented as a parent instruction, diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index ba6b3413d..5c0c3edfb 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -4852,6 +4852,29 @@ struct ExprLoweringVisitorBase : public ExprVisitor<Derived, LoweredValInfo> return LoweredValInfo::simple( getBuilder()->emitMakeMatrix(irType, args.getCount(), args.getBuffer())); } + else if (auto coopVecType = as<CoopVectorExpressionType>(type)) + { + UInt elementCount = (UInt)getIntVal(coopVecType->getElementCount()); + + for (UInt ee = 0; ee < argCount; ++ee) + { + auto argExpr = expr->args[ee]; + LoweredValInfo argVal = lowerRValueExpr(context, argExpr); + args.add(getSimpleVal(context, argVal)); + } + if (elementCount > argCount) + { + auto irDefaultValue = + getSimpleVal(context, getDefaultVal(coopVecType->getElementType())); + for (UInt ee = argCount; ee < elementCount; ++ee) + { + args.add(irDefaultValue); + } + } + + return LoweredValInfo::simple( + getBuilder()->emitMakeCoopVector(irType, args.getCount(), args.getBuffer())); + } else if (auto declRefType = as<DeclRefType>(type)) { DeclRef<Decl> declRef = declRefType->getDeclRef(); diff --git a/source/slang/slang-profile-defs.h b/source/slang/slang-profile-defs.h index 8328bf063..57ad3ab25 100644 --- a/source/slang/slang-profile-defs.h +++ b/source/slang/slang-profile-defs.h @@ -104,6 +104,7 @@ PROFILE_VERSION(DX_6_4, DX) PROFILE_VERSION(DX_6_5, DX) 
PROFILE_VERSION(DX_6_6, DX) PROFILE_VERSION(DX_6_7, DX) +PROFILE_VERSION(DX_6_8, DX) PROFILE_VERSION(GLSL_150, GLSL) PROFILE_VERSION(GLSL_330, GLSL) @@ -139,6 +140,7 @@ PROFILE(DX_Compute_6_4, cs_6_4, Compute, DX_6_4) PROFILE(DX_Compute_6_5, cs_6_5, Compute, DX_6_5) PROFILE(DX_Compute_6_6, cs_6_6, Compute, DX_6_6) PROFILE(DX_Compute_6_7, cs_6_7, Compute, DX_6_7) +PROFILE(DX_Compute_6_8, cs_6_8, Compute, DX_6_8) PROFILE(DX_Domain_5_0, ds_5_0, Domain, DX_5_0) PROFILE(DX_Domain_5_1, ds_5_1, Domain, DX_5_1) @@ -150,6 +152,7 @@ PROFILE(DX_Domain_6_4, ds_6_4, Domain, DX_6_4) PROFILE(DX_Domain_6_5, ds_6_5, Domain, DX_6_5) PROFILE(DX_Domain_6_6, ds_6_6, Domain, DX_6_6) PROFILE(DX_Domain_6_7, ds_6_7, Domain, DX_6_7) +PROFILE(DX_Domain_6_8, ds_6_8, Domain, DX_6_8) PROFILE(DX_Geometry_4_0, gs_4_0, Geometry, DX_4_0) PROFILE(DX_Geometry_4_1, gs_4_1, Geometry, DX_4_1) @@ -163,6 +166,7 @@ PROFILE(DX_Geometry_6_4, gs_6_4, Geometry, DX_6_4) PROFILE(DX_Geometry_6_5, gs_6_5, Geometry, DX_6_5) PROFILE(DX_Geometry_6_6, gs_6_6, Geometry, DX_6_6) PROFILE(DX_Geometry_6_7, gs_6_7, Geometry, DX_6_7) +PROFILE(DX_Geometry_6_8, gs_6_8, Geometry, DX_6_8) PROFILE(DX_Hull_5_0, hs_5_0, Hull, DX_5_0) PROFILE(DX_Hull_5_1, hs_5_1, Hull, DX_5_1) @@ -174,6 +178,7 @@ PROFILE(DX_Hull_6_4, hs_6_4, Hull, DX_6_4) PROFILE(DX_Hull_6_5, hs_6_5, Hull, DX_6_5) PROFILE(DX_Hull_6_6, hs_6_6, Hull, DX_6_6) PROFILE(DX_Hull_6_7, hs_6_7, Hull, DX_6_7) +PROFILE(DX_Hull_6_8, hs_6_8, Hull, DX_6_8) PROFILE(DX_Fragment_4_0, ps_4_0, Fragment, DX_4_0) PROFILE(DX_Fragment_4_1, ps_4_1, Fragment, DX_4_1) @@ -187,6 +192,7 @@ PROFILE(DX_Fragment_6_4, ps_6_4, Fragment, DX_6_4) PROFILE(DX_Fragment_6_5, ps_6_5, Fragment, DX_6_5) PROFILE(DX_Fragment_6_6, ps_6_6, Fragment, DX_6_6) PROFILE(DX_Fragment_6_7, ps_6_7, Fragment, DX_6_7) +PROFILE(DX_Fragment_6_8, ps_6_8, Fragment, DX_6_8) PROFILE(DX_Vertex_4_0, vs_4_0, Vertex, DX_4_0) PROFILE(DX_Vertex_4_1, vs_4_1, Vertex, DX_4_1) @@ -200,14 +206,17 @@ PROFILE(DX_Vertex_6_4, vs_6_4, Vertex, 
DX_6_4) PROFILE(DX_Vertex_6_5, vs_6_5, Vertex, DX_6_5) PROFILE(DX_Vertex_6_6, vs_6_6, Vertex, DX_6_6) PROFILE(DX_Vertex_6_7, vs_6_7, Vertex, DX_6_7) +PROFILE(DX_Vertex_6_8, vs_6_8, Vertex, DX_6_8) PROFILE(DX_Mesh_6_5, ms_6_5, Mesh, DX_6_5) PROFILE(DX_Mesh_6_6, ms_6_6, Mesh, DX_6_6) PROFILE(DX_Mesh_6_7, ms_6_7, Mesh, DX_6_7) +PROFILE(DX_Mesh_6_8, ms_6_8, Mesh, DX_6_8) PROFILE(DX_Amplification_6_5, as_6_5, Amplification, DX_6_5) PROFILE(DX_Amplification_6_6, as_6_6, Amplification, DX_6_6) PROFILE(DX_Amplification_6_7, as_6_7, Amplification, DX_6_7) +PROFILE(DX_Amplification_6_8, as_6_8, Amplification, DX_6_8) // TODO: consider making `lib_*_*` alias these... PROFILE(DX_None_4_0, sm_4_0, Unknown, DX_4_0) @@ -234,6 +243,7 @@ PROFILE(DX_Lib_6_4, lib_6_4, Unknown, DX_6_4) PROFILE(DX_Lib_6_5, lib_6_5, Unknown, DX_6_5) PROFILE(DX_Lib_6_6, lib_6_6, Unknown, DX_6_6) PROFILE(DX_Lib_6_7, lib_6_7, Unknown, DX_6_7) +PROFILE(DX_Lib_6_8, lib_6_8, Unknown, DX_6_8) PROFILE_ALIAS(DX_None_6_1, DX_Lib_6_1, sm_6_1) PROFILE_ALIAS(DX_None_6_2, DX_Lib_6_2, sm_6_2) @@ -242,6 +252,7 @@ PROFILE_ALIAS(DX_None_6_4, DX_Lib_6_4, sm_6_4) PROFILE_ALIAS(DX_None_6_5, DX_Lib_6_5, sm_6_5) PROFILE_ALIAS(DX_None_6_6, DX_Lib_6_6, sm_6_6) PROFILE_ALIAS(DX_None_6_7, DX_Lib_6_7, sm_6_7) +PROFILE_ALIAS(DX_None_6_8, DX_Lib_6_8, sm_6_8) PROFILE(METAL_LIB_2_3, metallib_2_3, Unknown, METAL_2_3) PROFILE(METAL_LIB_2_4, metallib_2_4, Unknown, METAL_2_4) |
