summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
authorJay Kwak <82421531+jkwak-work@users.noreply.github.com>2025-01-30 00:59:49 -0800
committerGitHub <noreply@github.com>2025-01-30 00:59:49 -0800
commitba9b2785c69c1b8c6d2b4103267b5281815f9f23 (patch)
treee4ba4ca76c6592b90764a0a7ac32502639dc93aa /source
parent2ae194d51e15c064c3d905e628f7335de7504e32 (diff)
Support cooperative vector (#6223)
* Support cooperative vector without Vulkan-header update Adding a Slang support for cooperative vector. But this commit doesn't have Vulkan-header update.
Diffstat (limited to 'source')
-rw-r--r--source/slang/core.meta.slang12
-rw-r--r--source/slang/hlsl.meta.slang2269
-rw-r--r--source/slang/slang-ast-support-types.h5
-rw-r--r--source/slang/slang-ast-type.cpp23
-rw-r--r--source/slang/slang-ast-type.h11
-rw-r--r--source/slang/slang-capabilities.capdef58
-rw-r--r--source/slang/slang-check-conversion.cpp46
-rw-r--r--source/slang/slang-compiler.cpp1
-rw-r--r--source/slang/slang-diagnostic-defs.h5
-rw-r--r--source/slang/slang-emit-c-like.cpp108
-rw-r--r--source/slang/slang-emit-glsl.cpp17
-rw-r--r--source/slang/slang-emit-hlsl.cpp11
-rw-r--r--source/slang/slang-emit-hlsl.h2
-rw-r--r--source/slang/slang-emit-spirv-ops.h13
-rw-r--r--source/slang/slang-emit-spirv.cpp67
-rw-r--r--source/slang/slang-emit.cpp18
-rw-r--r--source/slang/slang-ir-autodiff.cpp1
-rw-r--r--source/slang/slang-ir-constexpr.cpp1
-rw-r--r--source/slang/slang-ir-explicit-global-init.cpp32
-rw-r--r--source/slang/slang-ir-explicit-global-init.h3
-rw-r--r--source/slang/slang-ir-inst-defs.h8
-rw-r--r--source/slang/slang-ir-insts.h12
-rw-r--r--source/slang/slang-ir-lower-coopvec.cpp234
-rw-r--r--source/slang/slang-ir-lower-coopvec.h12
-rw-r--r--source/slang/slang-ir-peephole.cpp32
-rw-r--r--source/slang/slang-ir-util.cpp4
-rw-r--r--source/slang/slang-ir.cpp50
-rw-r--r--source/slang/slang-ir.h8
-rw-r--r--source/slang/slang-lower-to-ir.cpp23
-rw-r--r--source/slang/slang-profile-defs.h11
30 files changed, 3074 insertions, 23 deletions
diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang
index 2224b8e82..d4cce037d 100644
--- a/source/slang/core.meta.slang
+++ b/source/slang/core.meta.slang
@@ -3234,6 +3234,18 @@ int __alignOf_intrinsic<T>()
return __alignOf_intrinsic_impl<T>(__default<T>());
}
+// Converts an offset counted in elements of `T` into a byte offset by
+// multiplying it with `__naturalStrideOf<T>()`.
+[__unsafeForceInlineEarly]
+int32_t __elemToByteOffset<T>(int32_t elemOffset)
+{
+ return elemOffset * __naturalStrideOf<T>();
+}
+
+// Converts a byte offset into an offset counted in elements of `T` by
+// dividing it by `__naturalStrideOf<T>()`.
+// NOTE(review): presumably an integer division, so a byte offset that is not a
+// multiple of the stride would be rounded — confirm callers always pass multiples.
+[__unsafeForceInlineEarly]
+int32_t __byteToElemOffset<T>(int32_t byteOffset)
+{
+ return byteOffset / __naturalStrideOf<T>();
+}
+
__intrinsic_op($(kIROp_TreatAsDynamicUniform))
T asDynamicUniform<T>(T v);
diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang
index 1853a82b6..ecab7ff93 100644
--- a/source/slang/hlsl.meta.slang
+++ b/source/slang/hlsl.meta.slang
@@ -4530,6 +4530,22 @@ __intrinsic_op($(kIROp_ByteAddressBufferStore))
[require(cpp_cuda_glsl_hlsl_metal_spirv, byteaddressbuffer_rw)]
void __byteAddressBufferStore<T>(RasterizerOrderedByteAddressBuffer buffer, int offset, int alignment, T value);
+// Returns a `Ptr<uint[]>` for a read-only byte-address buffer.
+// Bound to the `GetUntypedBufferPtr` IR op; declared for SPIR-V only.
+__intrinsic_op($(kIROp_GetUntypedBufferPtr))
+[require(spirv, byteaddressbuffer)]
+Ptr<uint[]> __getByteAddressBufferPtr(ByteAddressBuffer buffer);
+
+// Read-write overload of `__getByteAddressBufferPtr`; same IR op, SPIR-V only.
+__intrinsic_op($(kIROp_GetUntypedBufferPtr))
+[require(spirv, byteaddressbuffer_rw)]
+Ptr<uint[]> __getByteAddressBufferPtr(RWByteAddressBuffer buffer);
+
+// Returns a `Ptr<T[]>` for a read-only structured buffer.
+// Bound to the `GetStructuredBufferPtr` IR op; declared for SPIR-V only.
+__intrinsic_op($(kIROp_GetStructuredBufferPtr))
+[require(spirv, structuredbuffer)]
+Ptr<T[]> __getStructuredBufferPtr<T>(StructuredBuffer<T> buffer);
+
+// Read-write overload of `__getStructuredBufferPtr`; same IR op, SPIR-V only.
+__intrinsic_op($(kIROp_GetStructuredBufferPtr))
+[require(spirv, structuredbuffer_rw)]
+Ptr<T[]> __getStructuredBufferPtr<T>(RWStructuredBuffer<T> buffer);
+
/**
Represents an opaque handle to a read-only structured buffer allocated in global memory.
A structured buffer can be viewed as an array of the specified element type.
@@ -12414,10 +12430,50 @@ __generic<T:__BuiltinArithmeticType, U:__BuiltinArithmeticType>
__intrinsic_op($(kIROp_IntCast))
T __int_cast(U val);
+// Scalar cast bound to the `FloatCast` IR op.
+// NOTE(review): the generic constraint admits any arithmetic type; `__arithmetic_cast`
+// only selects this when both `T` and `U` are floating point.
+__generic<T:__BuiltinArithmeticType, U:__BuiltinArithmeticType>
+__intrinsic_op($(kIROp_FloatCast))
+T __real_cast(U val);
+
+// Scalar cast bound to the `CastIntToFloat` IR op.
+__generic<T:__BuiltinArithmeticType, U:__BuiltinArithmeticType>
+__intrinsic_op($(kIROp_CastIntToFloat))
+T __int_to_float_cast(U val);
+
+// Scalar cast bound to the `CastFloatToInt` IR op.
+__generic<T:__BuiltinArithmeticType, U:__BuiltinArithmeticType>
+__intrinsic_op($(kIROp_CastFloatToInt))
+T __float_to_int_cast(U val);
+
+// Casts `val` from arithmetic type `U` to arithmetic type `T`, choosing the cast
+// intrinsic from the int/float classification of the two types:
+//   int   -> float : __int_to_float_cast
+//   float -> int   : __float_to_int_cast
+//   float -> float : __real_cast
+//   int   -> int   : __int_cast
+// Returns `T(0)` when none of the four combinations matches.
+__generic<T:__BuiltinArithmeticType, U:__BuiltinArithmeticType>
+[__unsafeForceInlineEarly]
+T __arithmetic_cast(U val)
+{
+ if (__isFloat<T>() && __isInt<U>())
+ return __int_to_float_cast<T>(val);
+ else if (__isInt<T>() && __isFloat<U>())
+ return __float_to_int_cast<T>(val);
+ else if (__isFloat<T>() && __isFloat<U>())
+ return __real_cast<T>(val);
+ else if (__isInt<T>() && __isInt<U>())
+ return __int_cast<T>(val);
+ // Fallback when neither type is classified as int or float.
+ return T(0);
+}
+
__generic<T:__BuiltinArithmeticType, U:__BuiltinArithmeticType, let N : int>
__intrinsic_op($(kIROp_IntCast))
vector<T,N> __int_cast(vector<U,N> val);
+// `vector<U,N>` -> `vector<T,N>` overload bound to the `FloatCast` IR op.
+__generic<T:__BuiltinArithmeticType, U:__BuiltinArithmeticType, let N : int>
+__intrinsic_op($(kIROp_FloatCast))
+vector<T,N> __real_cast(vector<U,N> val);
+
+// `vector<U,N>` -> `vector<T,N>` overload bound to the `CastIntToFloat` IR op.
+__generic<T:__BuiltinArithmeticType, U:__BuiltinArithmeticType, let N : int>
+__intrinsic_op($(kIROp_CastIntToFloat))
+vector<T,N> __int_to_float_cast(vector<U,N> val);
+
+// `vector<U,N>` -> `vector<T,N>` overload bound to the `CastFloatToInt` IR op.
+__generic<T:__BuiltinArithmeticType, U:__BuiltinArithmeticType, let N : int>
+__intrinsic_op($(kIROp_CastFloatToInt))
+vector<T,N> __float_to_int_cast(vector<U,N> val);
+
+
/// Extract sign of value.
/// @param x The value to extract the sign of.
/// @return -1 if `x` is negative, 0 if `x` is zero, and 1 if `x` is positive.
@@ -21477,6 +21533,7 @@ ${{{{
}
}}}}
+
/// Represents a bindless handle to a descriptor. A descriptor handle is always an ordinary data type and can be
/// declared in any memory location.
/// @remarks Opaque descriptor types such as textures(`Texture2D` etc.), `SamplerState` and buffers (e.g. `StructuredBuffer`)
@@ -21579,8 +21636,6 @@ extern T getDescriptorFromHandle<T:IOpaqueDescriptor>(DescriptorHandle<T> handle
__intrinsic_op($(kIROp_NonUniformResourceIndex))
DescriptorHandle<T> nonuniform<T:IOpaqueDescriptor>(DescriptorHandle<T> ptr);
-//@hidden:
-
__glsl_version(450)
__glsl_extension(GL_ARB_shader_clock)
[require(glsl_spirv, GL_ARB_shader_clock)]
@@ -21633,6 +21688,2216 @@ extension<T, L : IBufferDataLayout> RasterizerOrderedStructuredBuffer<T, L> : IR
int getCount() { uint count; uint stride; this.GetDimensions(count, stride); return count; }
}
+//
+// Cooperative Vector
+//
+
+__intrinsic_type($(kIROp_CoopVectorType))
+[require(cooperative_vector)]
+struct CoopVec<T : __BuiltinArithmeticType, let N : int> : IArray<T>, IArithmetic
+{
+ //
+ // Initialization
+ //
+
+ /// Default constructor: delegates to the scalar constructor with `T(0)`.
+ [ForceInline]
+ [require(cooperative_vector)]
+ __init()
+ {
+ this = CoopVec<T, N>(T(0));
+ }
+ /// Scalar constructor: sets every element to `t` via `fill`.
+ [ForceInline]
+ [require(cooperative_vector)]
+ __init(T t)
+ {
+ this.fill(t);
+ }
+ /// Converting constructor: copies from a vector of the same length with a
+ /// different element type via `copyFrom`.
+ [ForceInline]
+ [require(cooperative_vector)]
+ __init<U : __BuiltinArithmeticType>(CoopVec<U, N> other)
+ {
+ this.copyFrom(other);
+ }
+
+ /// Element-list constructor: takes exactly `N` arithmetic arguments (enforced by the
+ /// static_assert), casts each to `T` with `__arithmetic_cast` and packs them with `__makeCoopVec`.
+ [ForceInline]
+ [require(cooperative_vector)]
+ __init<each U : __BuiltinArithmeticType>(expand each U args)
+ {
+ static_assert(countof(U) == N, "number of arguments to CoopVec constructor must match number of elements");
+ this = __makeCoopVec<T, N>(expand (__arithmetic_cast<T>(each args)));
+ }
+ /// Constructor from an `int`: converts with `T(i)` and delegates to the scalar constructor.
+ /// NOTE(review): the negative overload rank presumably makes this lose against the
+ /// other overloads when several apply — confirm.
+ [OverloadRank(-10)]
+ [ForceInline]
+ __init(int i)
+ {
+ this = CoopVec<T, N>(T(i));
+ }
+ /// Copy constructor.
+ [ForceInline]
+ __init(This x)
+ {
+ this = x;
+ }
+
+ //
+ // Simple setters
+ //
+
+ /// Overwrites this vector with the elements of `other`, converting each element
+ /// from `U` to `T`.
+ ///
+ /// On HLSL this maps to the `.CopyFrom` member. On every other target the cast
+ /// intrinsic is chosen from the int/float classification of `T` and `U`; if neither
+ /// type is classified as int or float the vector is left unchanged.
+ ///
+ /// The capability is `cooperative_vector` rather than `hlsl`: the body has a
+ /// non-HLSL `default:` path, and the `[require(cooperative_vector)]` converting
+ /// constructor calls this method on all targets.
+ [require(cooperative_vector)]
+ [mutating]
+ [ForceInline]
+ void copyFrom<U : __BuiltinArithmeticType>(CoopVec<U,N> other)
+ {
+ __target_switch
+ {
+ case hlsl: __intrinsic_asm ".CopyFrom";
+ default:
+ if (__isFloat<T>() && __isInt<U>())
+ this = __int_to_float_cast<T>(other);
+ else if (__isInt<T>() && __isFloat<U>())
+ this = __float_to_int_cast<T>(other);
+ else if (__isFloat<T>() && __isFloat<U>())
+ this = __real_cast<T>(other);
+ else if (__isInt<T>() && __isInt<U>())
+ this = __int_cast<T>(other);
+ }
+ }
+
+ /// Sets every element of this vector to `t`.
+ ///
+ /// SPIR-V: emits `OpCompositeConstructReplicateEXT` (needs the
+ /// `SPV_EXT_replicated_composites` extension and `ReplicatedCompositesEXT` capability).
+ /// HLSL: maps to the `.Fill` member.
+ /// Other targets: writes the elements one by one.
+ [require(cooperative_vector)]
+ [mutating]
+ [ForceInline]
+ void fill(T t)
+ {
+ __target_switch
+ {
+ case spirv:
+ this = spirv_asm {
+ OpExtension "SPV_EXT_replicated_composites";
+ OpCapability ReplicatedCompositesEXT;
+ result:$$CoopVec<T, N> = OpCompositeConstructReplicateEXT $t;
+ };
+ case hlsl: __intrinsic_asm ".Fill";
+ default:
+ for(int i = 0; i < N; ++i)
+ this[i] = t;
+ }
+ }
+
+ //
+ // Loading and storing
+ //
+
+ /// Writes all `N` elements of this vector into `buffer`, starting at the given byte offset.
+ ///
+ /// SPIR-V: a single `OpCooperativeVectorStoreNV` on the buffer pointer.
+ /// Other targets: one `StoreByteOffset` per element, with element `i` placed at
+ /// `byteOffset16ByteAligned + __elemToByteOffset<T>(i)`.
+ /// NOTE(review): the parameter name suggests the offset must be 16-byte aligned;
+ /// nothing in this method checks that.
+ [ForceInline]
+ [require(cooperative_vector)]
+ void store(RWByteAddressBuffer buffer, int32_t byteOffset16ByteAligned = 0)
+ {
+ __target_switch
+ {
+ case spirv:
+ let ptr = buffer.GetBufferPointer();
+ spirv_asm
+ {
+ // TODO: Should this be a byte offset
+ OpCooperativeVectorStoreNV $ptr $byteOffset16ByteAligned $this None;
+ };
+ // Not supported
+ // case hlsl:
+ // this.__Store(buffer, byteOffset16ByteAligned);
+ default:
+ for(int i = 0; i < N; ++i)
+ buffer.StoreByteOffset(byteOffset16ByteAligned + __elemToByteOffset<T>(i), this[i]);
+ }
+ }
+
+ /// Writes all `N` elements of this vector into a structured buffer of `T`.
+ ///
+ /// SPIR-V: a single `OpCooperativeVectorStoreNV` on the structured-buffer pointer,
+ /// with the offset passed through unchanged.
+ /// Other targets: element `i` is written to index
+ /// `i + __byteToElemOffset<T>(byteOffset16ByteAligned)`, i.e. the byte offset is
+ /// converted to an element index first.
+ [ForceInline]
+ [require(cooperative_vector)]
+ void store(RWStructuredBuffer<T> buffer, int32_t byteOffset16ByteAligned = 0)
+ {
+ __target_switch
+ {
+ case spirv:
+ let ptr = __getStructuredBufferPtr(buffer);
+ spirv_asm
+ {
+ // TODO: Should this be a byte offset
+ OpCooperativeVectorStoreNV $ptr $byteOffset16ByteAligned $this None;
+ };
+ // Not supported
+ // case hlsl:
+ // this.__Store(buffer, byteOffset16ByteAligned);
+ default:
+ for(int i = 0; i < N; ++i)
+ buffer[i + __byteToElemOffset<T>(byteOffset16ByteAligned)] = this[i];
+ }
+ }
+
+ /// Writes all `N` elements of this vector into a groupshared array of `T`.
+ ///
+ /// SPIR-V: `OpCooperativeVectorStoreNV` on the array with the byte offset as given.
+ /// HLSL: the internal `__Store` helper, which takes the offset in elements.
+ /// Other targets: element-by-element writes starting at the converted element offset.
+ [ForceInline]
+ [require(cooperative_vector)]
+ void store<let M : int>(__ref groupshared T[M] data, int32_t byteOffset16ByteAligned = 0)
+ {
+ __target_switch
+ {
+ case spirv:
+ spirv_asm{
+ OpCooperativeVectorStoreNV &data $byteOffset16ByteAligned $this None;
+ };
+ case hlsl:
+ // __Store on groupshared memory expects an element offset, not bytes.
+ this.__Store(data, __byteToElemOffset<T>(byteOffset16ByteAligned));
+ default:
+ for(int i = 0; i < N; ++i)
+ data[i + __byteToElemOffset<T>(byteOffset16ByteAligned)] = this[i];
+ }
+ }
+
+ /// spirv only storing to a groupshared array of any type
+ /// The element type `U` of the destination need not match `T`; the store is a
+ /// single `OpCooperativeVectorStoreNV` at the given byte offset.
+ [ForceInline]
+ [require(spirv, cooperative_vector)]
+ void storeAny<U, let M : int>(__ref groupshared U[M] data, int32_t byteOffset16ByteAligned = 0)
+ {
+ __target_switch
+ {
+ case spirv:
+ spirv_asm{
+ OpCooperativeVectorStoreNV &data $byteOffset16ByteAligned $this None;
+ };
+ }
+ }
+ /// spirv only overload of `storeAny` for a groupshared array of `vector<U, L>`.
+ [ForceInline]
+ [require(spirv, cooperative_vector)]
+ void storeAny<U, let M : int, let L : int>(__ref groupshared vector<U, L>[M] data, int32_t byteOffset16ByteAligned = 0)
+ {
+ __target_switch
+ {
+ case spirv:
+ spirv_asm{
+ OpCooperativeVectorStoreNV &data $byteOffset16ByteAligned $this None;
+ };
+ }
+ }
+
+ /// Reads `N` elements of `T` from a read-only byte-address buffer, starting at the
+ /// given byte offset, and returns them as a new vector.
+ ///
+ /// SPIR-V: `OpCooperativeVectorLoadNV` on the buffer pointer.
+ /// HLSL: the internal `__Load` helper.
+ /// Other targets: one `LoadByteOffset<T>` per element at
+ /// `byteOffset16ByteAligned + __elemToByteOffset<T>(i)`.
+ [ForceInline]
+ [__NoSideEffect]
+ [require(cooperative_vector)]
+ static CoopVec<T, N> load(ByteAddressBuffer buffer, int32_t byteOffset16ByteAligned = 0)
+ {
+ __target_switch
+ {
+ case spirv:
+ let ptr = buffer.GetBufferPointer();
+ return spirv_asm
+ {
+ result:$$CoopVec<T, N> = OpCooperativeVectorLoadNV $ptr $byteOffset16ByteAligned None;
+ };
+ case hlsl:
+ CoopVec<T, N> ret;
+ ret.__Load(buffer, byteOffset16ByteAligned);
+ return ret;
+ default:
+ var vec = CoopVec<T, N>();
+ for(int i = 0; i < N; ++i)
+ vec[i] = buffer.LoadByteOffset<T>(byteOffset16ByteAligned + __elemToByteOffset<T>(i));
+ return vec;
+ }
+ // Every switch branch above returns; this is a trailing fallback return.
+ return CoopVec<T, N>();
+ }
+
+ /// Reads `N` elements of `T` from a read-write byte-address buffer, starting at the
+ /// given byte offset. Same per-target strategy as the `ByteAddressBuffer` overload:
+ /// `OpCooperativeVectorLoadNV` on SPIR-V, `__Load` on HLSL, per-element
+ /// `LoadByteOffset<T>` elsewhere.
+ [ForceInline]
+ [__NoSideEffect]
+ [require(cooperative_vector)]
+ static CoopVec<T, N> load(RWByteAddressBuffer buffer, int32_t byteOffset16ByteAligned = 0)
+ {
+ __target_switch
+ {
+ case spirv:
+ let ptr = buffer.GetBufferPointer();
+ return spirv_asm
+ {
+ result:$$CoopVec<T, N> = OpCooperativeVectorLoadNV $ptr $byteOffset16ByteAligned None;
+ };
+ case hlsl:
+ CoopVec<T, N> ret;
+ ret.__Load(buffer, byteOffset16ByteAligned);
+ return ret;
+ default:
+ var vec = CoopVec<T, N>();
+ for(int i = 0; i < N; ++i)
+ vec[i] = buffer.LoadByteOffset<T>(byteOffset16ByteAligned + __elemToByteOffset<T>(i));
+ return vec;
+ }
+ // Every switch branch above returns; this is a trailing fallback return.
+ return CoopVec<T, N>();
+ }
+
+ /// Reads `N` elements from a read-only structured buffer of `T`.
+ ///
+ /// SPIR-V: `OpCooperativeVectorLoadNV` on the structured-buffer pointer with the
+ /// offset passed through unchanged.
+ /// Other targets (including HLSL, whose dedicated path is commented out as
+ /// unsupported): element `i` is read from index
+ /// `__byteToElemOffset<T>(byteOffset16ByteAligned) + i`.
+ [ForceInline]
+ [__NoSideEffect]
+ [require(cooperative_vector)]
+ static CoopVec<T, N> load(StructuredBuffer<T> buffer, int32_t byteOffset16ByteAligned = 0)
+ {
+ __target_switch
+ {
+ case spirv:
+ let ptr = __getStructuredBufferPtr(buffer);
+ return spirv_asm
+ {
+ result:$$CoopVec<T, N> = OpCooperativeVectorLoadNV $ptr $byteOffset16ByteAligned None;
+ };
+ // Not supported
+ // case hlsl:
+ // CoopVec<T, N> ret;
+ // ret.__Load(buffer, byteOffset16ByteAligned);
+ // return ret;
+ default:
+ var vec = CoopVec<T, N>();
+ for(int i = 0; i < N; ++i)
+ vec[i] = buffer[__byteToElemOffset<T>(byteOffset16ByteAligned) + i];
+ return vec;
+ }
+ // Every switch branch above returns; this is a trailing fallback return.
+ return CoopVec<T, N>();
+ }
+
+ /// Reads `N` elements from a read-write structured buffer of `T`.
+ ///
+ /// SPIR-V: `OpCooperativeVectorLoadNV` on the structured-buffer pointer with the
+ /// offset passed through unchanged.
+ /// Other targets (including HLSL, whose dedicated path is commented out as
+ /// unsupported): element `i` is read from index
+ /// `__byteToElemOffset<T>(byteOffset16ByteAligned) + i`.
+ ///
+ /// The capability is `cooperative_vector` (not SPIR-V only) so that the `default:`
+ /// fallback is reachable, matching the `StructuredBuffer` load overload and the
+ /// `RWStructuredBuffer` store.
+ [ForceInline]
+ [__NoSideEffect]
+ [require(cooperative_vector)]
+ static CoopVec<T, N> load(RWStructuredBuffer<T> buffer, int32_t byteOffset16ByteAligned = 0)
+ {
+ __target_switch
+ {
+ case spirv:
+ let ptr = __getStructuredBufferPtr(buffer);
+ return spirv_asm
+ {
+ result:$$CoopVec<T, N> = OpCooperativeVectorLoadNV $ptr $byteOffset16ByteAligned None;
+ };
+ // Not supported
+ // case hlsl:
+ // CoopVec<T, N> ret;
+ // ret.__Load(buffer, byteOffset16ByteAligned);
+ // return ret;
+ default:
+ var vec = CoopVec<T, N>();
+ for(int i = 0; i < N; ++i)
+ vec[i] = buffer[__byteToElemOffset<T>(byteOffset16ByteAligned) + i];
+ return vec;
+ }
+ }
+
+ // Groupshared
+ /// Reads `N` elements from a groupshared array of `T`.
+ ///
+ /// SPIR-V: `OpCooperativeVectorLoadNV` on the array with the byte offset as given.
+ /// HLSL: the internal `__Load` helper, which takes the offset in elements.
+ /// Other targets: element-by-element reads starting at the converted element offset.
+ [ForceInline]
+ [__NoSideEffect]
+ [require(cooperative_vector)]
+ static CoopVec<T, N> load<let M : int>(__constref groupshared const T[M] data, int32_t byteOffset16ByteAligned = 0)
+ {
+ __target_switch
+ {
+ case spirv:
+ return spirv_asm{
+ result:$$CoopVec<T, N> = OpCooperativeVectorLoadNV &data $byteOffset16ByteAligned None
+ };
+ case hlsl:
+ CoopVec<T, N> ret;
+ // __Load on groupshared memory expects an element offset, not bytes.
+ ret.__Load(data, __byteToElemOffset<T>(byteOffset16ByteAligned));
+ return ret;
+ default:
+ CoopVec<T,N> result;
+ for(int i = 0; i < N; ++i)
+ result[i] = data[i + __byteToElemOffset<T>(byteOffset16ByteAligned)];
+ return result;
+ }
+ }
+
+ /// spirv only loading from a groupshared array of any type
+ /// The element type `U` of the source array need not match `T`; the load is a
+ /// single `OpCooperativeVectorLoadNV` at the given byte offset.
+ [ForceInline]
+ [__NoSideEffect]
+ [require(spirv, cooperative_vector)]
+ static CoopVec<T, N> loadAny<U : __BuiltinArithmeticType, let M : int>(__constref groupshared const U[M] data, int32_t byteOffset16ByteAligned = 0)
+ {
+ __target_switch
+ {
+ case spirv:
+ return spirv_asm{
+ result:$$CoopVec<T, N> = OpCooperativeVectorLoadNV &data $byteOffset16ByteAligned None
+ };
+ }
+ }
+ /// spirv only overload of `loadAny` for a groupshared array of `vector<U, L>`.
+ [ForceInline]
+ [__NoSideEffect]
+ [require(spirv, cooperative_vector)]
+ static CoopVec<T, N> loadAny<U : __BuiltinArithmeticType, let M : int, let L : int>(__constref groupshared const vector<U, L>[M] data, int32_t byteOffset16ByteAligned = 0)
+ {
+ __target_switch
+ {
+ case spirv:
+ return spirv_asm{
+ result:$$CoopVec<T, N> = OpCooperativeVectorLoadNV &data $byteOffset16ByteAligned None
+ };
+ }
+ }
+
+ //
+ // Subscript
+ //
+
+ // Reads the element at `index`; bound to the `GetElement` IR op.
+ __intrinsic_op($(kIROp_GetElement))
+ [__NoSideEffect]
+ T __indexRead(int index);
+
+ // Yields a reference to the element at `index`; bound to the `GetElementPtr` IR op.
+ __intrinsic_op($(kIROp_GetElementPtr))
+ [__ref]
+ [__NoSideEffect]
+ Ref<T> __indexRef(int index);
+
+ /// Returns the number of elements, i.e. the generic parameter `N`.
+ [ForceInline]
+ [__NoSideEffect]
+ int getCount()
+ {
+ return N;
+ }
+
+ /// Element access by index.
+ /// HLSL uses the `.ReadFromIndex` / `.WriteToIndex` members; other targets go
+ /// through `__indexRead` / `__indexRef`.
+ __subscript(int index) -> T
+ {
+ [__NoSideEffect]
+ [nonmutating]
+ get
+ {
+ __target_switch
+ {
+ case hlsl: __intrinsic_asm ".ReadFromIndex";
+ default: return __indexRead(index);
+ }
+ }
+
+ [mutating]
+ set
+ {
+ __target_switch
+ {
+ case hlsl: __intrinsic_asm ".WriteToIndex";
+ default: __indexRef(index) = newValue;
+ }
+ }
+
+ // Unavailable on HLSL
+ // The CoopVector HLSL spec says that indexing with a subscript
+ // operation can work, but dxc currently crashes with this
+ // __intrinsic_op($(kIROp_GetElementPtr))
+ // [__ref]
+ // ref;
+ }
+
+ /// Builds a vector whose elements are all `t`: declares a vector and fills it.
+ static CoopVec<T, N> replicate(T t)
+ {
+ CoopVec<T, N> filled;
+ filled.fill(t);
+ return filled;
+ }
+
+ //
+ // Equality and ordering
+ //
+
+ /// True only when every element of `this` compares equal to the element of
+ /// `other` at the same index; stops at the first mismatch.
+ bool equals(This other)
+ {
+ for (int idx = 0; idx < N; idx++)
+ if (this[idx] != other[idx])
+ return false;
+ return true;
+ }
+ /// Lexicographic strict ordering: the first index at which the elements differ
+ /// decides the result. Returns false when no index decides.
+ bool lessThan(This other)
+ {
+ for (int idx = 0; idx < N; idx++)
+ {
+ if (this[idx] < other[idx])
+ return true;
+ if (this[idx] > other[idx])
+ return false;
+ }
+ return false;
+ }
+ /// Lexicographic non-strict ordering: the first index at which the elements differ
+ /// decides the result. Returns true when no index decides.
+ bool lessThanOrEquals(This other)
+ {
+ for (int idx = 0; idx < N; idx++)
+ {
+ if (this[idx] < other[idx])
+ return true;
+ if (this[idx] > other[idx])
+ return false;
+ }
+ return true;
+ }
+
+ //
+ // Arithmetic
+ //
+
+ // Value-returning addition bound to the `Add` IR op (used on non-HLSL targets).
+ __intrinsic_op($(kIROp_Add))
+ This __pureAdd(This other);
+
+ // HLSL-only in-place addition; maps to the `.Add` member.
+ [mutating]
+ [require(hlsl)]
+ void __mutAdd(This other)
+ {
+ __target_switch
+ {
+ case hlsl: __intrinsic_asm ".Add";
+ }
+ }
+
+ /// Returns `this + other`. On HLSL a copy of `this` is mutated with `__mutAdd`;
+ /// elsewhere the `__pureAdd` intrinsic is used.
+ // TODO: Why is this ForceInline necessary for hlsl, dxc bug?
+ [ForceInline]
+ This add(This other)
+ {
+ __target_switch
+ {
+ case hlsl:
+ This ret = this;
+ ret.__mutAdd(other);
+ return ret;
+ default: return __pureAdd(other);
+ }
+ }
+
+ // Value-returning subtraction bound to the `Sub` IR op (used on non-HLSL targets).
+ __intrinsic_op($(kIROp_Sub))
+ This __pureSub(This other);
+ // HLSL-only in-place subtraction; maps to the `.Subtract` member.
+ [mutating]
+ [require(hlsl)]
+ void __mutSub(This other)
+ { __target_switch { case hlsl: __intrinsic_asm ".Subtract"; } }
+
+ /// Returns `this - other`. On HLSL a copy of `this` is mutated with `__mutSub`;
+ /// elsewhere the `__pureSub` intrinsic is used.
+ [ForceInline]
+ This sub(This other)
+ {
+ __target_switch
+ {
+ case hlsl:
+ This ret = this;
+ ret.__mutSub(other);
+ return ret;
+ default: return __pureSub(other);
+ }
+ }
+
+ // Value-returning multiplication bound to the `Mul` IR op (used on non-HLSL targets).
+ __intrinsic_op($(kIROp_Mul))
+ This __pureMul(This other);
+
+ // HLSL-only in-place multiplication; maps to the `.Multiply` member.
+ [mutating]
+ [require(hlsl)]
+ void __mutMul(This other)
+ { __target_switch { case hlsl: __intrinsic_asm ".Multiply"; } }
+
+ /// Returns `this * other`. On HLSL a copy of `this` is mutated with `__mutMul`;
+ /// elsewhere the `__pureMul` intrinsic is used.
+ [ForceInline]
+ This mul(This other)
+ {
+ __target_switch
+ {
+ case hlsl:
+ This ret = this;
+ ret.__mutMul(other);
+ return ret;
+ default: return __pureMul(other);
+ }
+ }
+
+ // Value-returning division bound to the `Div` IR op (used on non-HLSL targets).
+ __intrinsic_op($(kIROp_Div))
+ This __pureDiv(This other);
+
+ // HLSL-only in-place division; maps to the `.Divide` member.
+ [mutating]
+ [require(hlsl)]
+ void __mutDiv(This other)
+ { __target_switch { case hlsl: __intrinsic_asm ".Divide"; } }
+
+ /// Returns `this / other`. On HLSL a copy of `this` is mutated with `__mutDiv`;
+ /// elsewhere the `__pureDiv` intrinsic is used.
+ [ForceInline]
+ This div(This other)
+ {
+ __target_switch
+ {
+ case hlsl:
+ This ret = this;
+ ret.__mutDiv(other);
+ return ret;
+ default: return __pureDiv(other);
+ }
+ }
+
+ // HLSL-only in-place modulo; maps to the `.Mod` member.
+ [mutating]
+ [require(hlsl)]
+ void __mutMod(This other)
+ { __target_switch { case hlsl: __intrinsic_asm ".Mod"; } }
+
+ /// Returns the element-wise `this % other`. On HLSL a copy of `this` is mutated
+ /// with `__mutMod`; elsewhere `%` is applied to each element in a loop (there is
+ /// no value-returning intrinsic for this operation).
+ [ForceInline]
+ This mod(This other)
+ {
+ __target_switch
+ {
+ case hlsl:
+ This ret = this;
+ ret.__mutMod(other);
+ return ret;
+ default:
+ This ret;
+ for(int i = 0; i < N; ++i)
+ ret[i] = this[i] % other[i];
+ return ret;
+ }
+ }
+
+ // Negation bound to the `Neg` IR op. Declared as a method taking the vector to
+ // negate as `other`; `neg()` calls it with `this`.
+ __intrinsic_op($(kIROp_Neg))
+ This __pureNeg(This other);
+
+ /// Returns the element-wise negation of this vector. HLSL negates each element
+ /// in a loop; other targets use the `__pureNeg` intrinsic.
+ //[ForceInline]
+ This neg()
+ {
+ __target_switch
+ {
+ case hlsl:
+ This ret = this;
+ for(int i = 0; i < N; ++i)
+ ret[i] = -this[i];
+ return ret;
+ default: return __pureNeg(this);
+ }
+ }
+
+ // HLSL-only in-place helpers. Each maps to the HLSL member named in its
+ // `__intrinsic_asm` string and has no implementation for other targets.
+
+ // In-place multiplication by the scalar `t` (`.ScalarMultiply`).
+ [mutating]
+ [require(hlsl)]
+ void __mutScalarMul(T t)
+ { __target_switch { case hlsl: __intrinsic_asm ".ScalarMultiply"; } }
+
+ // In-place minimum with `other` (`.Min`).
+ [mutating]
+ [require(hlsl)]
+ void __mutMin(This other)
+ { __target_switch { case hlsl: __intrinsic_asm ".Min"; } }
+ // In-place maximum with `other` (`.Max`).
+ [mutating]
+ [require(hlsl)]
+ void __mutMax(This other)
+ { __target_switch { case hlsl: __intrinsic_asm ".Max"; } }
+ // In-place clamp between `minVal` and `maxVal` (`.Clamp`).
+ [mutating]
+ [require(hlsl)]
+ void __mutClamp(This minVal, This maxVal)
+ { __target_switch { case hlsl: __intrinsic_asm ".Clamp"; } }
+
+ //
+ // Internal utilities for loading and storing
+ //
+
+ // HLSL-only load into this vector from a read-only byte-address buffer (`.Load`).
+ // The offset is in bytes.
+ [mutating]
+ [require(hlsl, byteaddressbuffer)]
+ void __Load(const ByteAddressBuffer buffer, uint byteOffset, uint alignment = 0)
+ { __target_switch { case hlsl: __intrinsic_asm ".Load"; } }
+
+ // HLSL-only load into this vector from a read-write byte-address buffer (`.Load`).
+ // The offset is in bytes.
+ [mutating]
+ [require(hlsl, byteaddressbuffer_rw)]
+ void __Load(const RWByteAddressBuffer buffer, uint byteOffset, uint alignment = 0)
+ { __target_switch { case hlsl: __intrinsic_asm ".Load"; } }
+
+ // HLSL-only load into this vector from a groupshared array (`.Load`).
+ __generic<let M : int>
+ [mutating]
+ // Careful, this takes the offset in elements
+ [require(hlsl)]
+ void __Load(__constref groupshared T buffer[M], uint elemOffset)
+ { __target_switch { case hlsl: __intrinsic_asm ".Load"; } }
+
+ // HLSL-only store of this vector into a read-write byte-address buffer (`.Store`).
+ // The offset is in bytes.
+ [require(hlsl, byteaddressbuffer_rw)]
+ void __Store(RWByteAddressBuffer buffer, uint byteOffset, uint alignment = 0)
+ { __target_switch { case hlsl: __intrinsic_asm ".Store"; } }
+
+ // HLSL-only store of this vector into a groupshared array (`.Store`).
+ __generic<let M : int>
+ [require(hlsl)]
+ // Careful, this takes the offset in elements
+ void __Store(__ref groupshared T buffer[M], uint elemOffset)
+ { __target_switch { case hlsl: __intrinsic_asm ".Store"; } }
+
+${{{{
+static const struct {
+ bool isRW;
+ char const* type;
+} kByteAddressBufferCases[] =
+{
+ {true, "RWByteAddressBuffer"},
+ {false, "ByteAddressBuffer"}
+};
+for(auto buffer : kByteAddressBufferCases) {
+}}}}
+ // HLSL-only helper mapping to the `.MatMul` member. Takes the HLSL-encoded
+ // interpretation and layout values. Generated once per buffer type by the
+ // enclosing meta loop (`$(buffer.type)`).
+ [mutating]
+ [require(hlsl, byteaddressbuffer_rw)]
+ void __mutMatMul<U : __BuiltinArithmeticType, let K : int>(
+ CoopVec<U, K> input, uint inputInterpretationHLSL,
+ $(buffer.type) matrix, uint matrixOffset, uint matrixInterpretationHLSL,
+ uint m, uint k, uint memoryLayoutHLSL, bool transpose, uint matrixStride)
+ {
+ __target_switch
+ {
+ case hlsl: __intrinsic_asm ".MatMul";
+ }
+ }
+
+ // HLSL-only helper mapping to the `.MatMulAdd` member; same as `__mutMatMul`
+ // with an additional bias buffer, offset and interpretation.
+ [mutating]
+ [require(hlsl, byteaddressbuffer_rw)]
+ void __mutMatMulAdd<U : __BuiltinArithmeticType, let K : int>(
+ CoopVec<U, K> input, uint inputInterpretationHLSL,
+ $(buffer.type) matrix, uint matrixOffset, uint matrixInterpretationHLSL,
+ $(buffer.type) bias, uint biasOffset, uint biasInterpretationHLSL,
+ uint m, uint k, uint memoryLayoutHLSL, bool transpose, uint matrixStride)
+ {
+ __target_switch
+ {
+ case hlsl: __intrinsic_asm ".MatMulAdd";
+ }
+ }
+
+ /// Adds a matrix-vector product to this vector in place, for a possibly packed input.
+ ///
+ /// `k` is the logical input length. For non-packed interpretations it must equal
+ /// `PackedK`; for packed interpretations it must not exceed `PackedK` times the
+ /// packing factor. Both rules are enforced by the static_asserts below.
+ ///
+ /// HLSL: the product is computed into a temporary copy with `__mutMatMul`
+ /// (using `N` as the row count) and then added with `__mutAdd`.
+ /// Other targets: `this = this + coopVecMatMulPacked(...)`.
+ [mutating]
+ [ForceInline]
+ void matMulAccumPacked<U : __BuiltinArithmeticType, let PackedK : int>(
+ CoopVec<U, PackedK> input,
+ constexpr CoopVecComponentType inputInterpretation,
+ constexpr int k,
+ $(buffer.type) matrix,
+ int32_t matrixOffset,
+ constexpr CoopVecComponentType matrixInterpretation,
+ constexpr CoopVecMatrixLayout memoryLayout,
+ constexpr bool transpose,
+ constexpr uint matrixStride
+ )
+ {
+ static_assert(__isPackedInputInterpretation(inputInterpretation) || k == PackedK
+ , "for non-packed inputInterpretation values k must be equal to the input vector length");
+ static_assert(!__isPackedInputInterpretation(inputInterpretation)
+ || k <= __inputInterpretationPackingFactor(inputInterpretation)*PackedK
+ , "for packed inputInterpretation values k must be less than or equal to the input vector length times the packing factor");
+
+ __target_switch
+ {
+ case hlsl:
+ // Translate the Slang enums into the values the HLSL intrinsic expects.
+ let inputInterpretationHLSL = __getHLSLCoopVecComponentType(inputInterpretation);
+ let matrixInterpretationHLSL = __getHLSLCoopVecComponentType(matrixInterpretation);
+ let memoryLayoutHLSL = __getHLSLCoopVecMatrixLayout(memoryLayout);
+ This temp = this;
+ temp.__mutMatMul(
+ input,
+ inputInterpretationHLSL,
+ matrix,
+ matrixOffset,
+ matrixInterpretationHLSL,
+ N,
+ k,
+ memoryLayoutHLSL,
+ transpose,
+ matrixStride
+ );
+ this.__mutAdd(temp);
+ default: this = this + coopVecMatMulPacked<T, N, PackedK, U>(
+ input,
+ inputInterpretation,
+ k,
+ matrix,
+ matrixOffset,
+ matrixInterpretation,
+ memoryLayout,
+ transpose,
+ matrixStride
+ );
+ }
+ }
+
+ /// Adds a matrix-vector product to this vector in place, for a non-packed input.
+ ///
+ /// Forwards to `matMulAccumPacked` with `k` set to the input vector length `K`.
+ /// Packed input interpretations are rejected at compile time; the diagnostic names
+ /// `matMulAccumPacked`, the packed counterpart of this method.
+ [mutating]
+ [ForceInline]
+ void matMulAccum<U : __BuiltinArithmeticType, let K : int>(
+ CoopVec<U, K> input,
+ constexpr CoopVecComponentType inputInterpretation,
+ $(buffer.type) matrix,
+ int32_t matrixOffset,
+ constexpr CoopVecComponentType matrixInterpretation,
+ constexpr CoopVecMatrixLayout memoryLayout,
+ constexpr bool transpose,
+ constexpr uint matrixStride
+ )
+ {
+ static_assert(!__isPackedInputInterpretation(inputInterpretation)
+ , "for packed inputInterpretation values please use matMulAccumPacked and specify k manually");
+ this.matMulAccumPacked<U, K>(
+ input,
+ inputInterpretation,
+ K,
+ matrix,
+ matrixOffset,
+ matrixInterpretation,
+ memoryLayout,
+ transpose,
+ matrixStride
+ );
+ }
+
+ /// Adds a matrix-vector product plus bias to this vector in place, for a possibly
+ /// packed input.
+ ///
+ /// `k` is the logical input length. For non-packed interpretations it must equal
+ /// `PackedK`; for packed interpretations it must not exceed `PackedK` times the
+ /// packing factor. Both rules are enforced by the static_asserts below.
+ ///
+ /// HLSL: the result is computed into a temporary copy with `__mutMatMulAdd`
+ /// (using `N` as the row count) and then added with `__mutAdd`.
+ /// Other targets: `this = this + coopVecMatMulAddPacked(...)`.
+ [mutating]
+ [ForceInline]
+ void matMulAddAccumPacked<U : __BuiltinArithmeticType, let PackedK : int>(
+ CoopVec<U, PackedK> input,
+ constexpr CoopVecComponentType inputInterpretation,
+ constexpr int k,
+ $(buffer.type) matrix,
+ int32_t matrixOffset,
+ constexpr CoopVecComponentType matrixInterpretation,
+ $(buffer.type) bias,
+ int32_t biasOffset,
+ constexpr CoopVecComponentType biasInterpretation,
+ constexpr CoopVecMatrixLayout memoryLayout,
+ constexpr bool transpose,
+ constexpr uint matrixStride
+ )
+ {
+ static_assert(__isPackedInputInterpretation(inputInterpretation) || k == PackedK
+ , "for non-packed inputInterpretation values k must be equal to the input vector length");
+ static_assert(!__isPackedInputInterpretation(inputInterpretation)
+ || k <= __inputInterpretationPackingFactor(inputInterpretation)*PackedK
+ , "for packed inputInterpretation values k must be less than or equal to the input vector length times the packing factor");
+
+ __target_switch
+ {
+ case hlsl:
+ // Translate the Slang enums into the values the HLSL intrinsic expects.
+ let inputInterpretationHLSL = __getHLSLCoopVecComponentType(inputInterpretation);
+ let matrixInterpretationHLSL = __getHLSLCoopVecComponentType(matrixInterpretation);
+ let biasInterpretationHLSL = __getHLSLCoopVecComponentType(biasInterpretation);
+ let memoryLayoutHLSL = __getHLSLCoopVecMatrixLayout(memoryLayout);
+ This temp = this;
+ temp.__mutMatMulAdd(
+ input,
+ inputInterpretationHLSL,
+ matrix,
+ matrixOffset,
+ matrixInterpretationHLSL,
+ bias,
+ biasOffset,
+ biasInterpretationHLSL,
+ N,
+ k,
+ memoryLayoutHLSL,
+ transpose,
+ matrixStride
+ );
+ this.__mutAdd(temp);
+ default: this = this + coopVecMatMulAddPacked<T, N, PackedK, U>(
+ input,
+ inputInterpretation,
+ k,
+ matrix,
+ matrixOffset,
+ matrixInterpretation,
+ bias,
+ biasOffset,
+ biasInterpretation,
+ memoryLayout,
+ transpose,
+ matrixStride
+ );
+ }
+ }
+
+ /// Adds a matrix-vector product plus bias to this vector in place, for a non-packed input.
+ ///
+ /// Forwards to `matMulAddAccumPacked` with `k` set to the input vector length `K`.
+ /// Packed input interpretations are rejected at compile time; the diagnostic names
+ /// `matMulAddAccumPacked`, the packed counterpart of this method.
+ [mutating]
+ [ForceInline]
+ void matMulAddAccum<U : __BuiltinArithmeticType, let K : int>(
+ CoopVec<U, K> input,
+ constexpr CoopVecComponentType inputInterpretation,
+ $(buffer.type) matrix,
+ int32_t matrixOffset,
+ constexpr CoopVecComponentType matrixInterpretation,
+ $(buffer.type) bias,
+ int32_t biasOffset,
+ constexpr CoopVecComponentType biasInterpretation,
+ constexpr CoopVecMatrixLayout memoryLayout,
+ constexpr bool transpose,
+ constexpr uint matrixStride
+ )
+ {
+ static_assert(!__isPackedInputInterpretation(inputInterpretation)
+ , "for packed inputInterpretation values please use matMulAddAccumPacked and specify k manually");
+ this.matMulAddAccumPacked<U, K>(
+ input,
+ inputInterpretation,
+ K,
+ matrix,
+ matrixOffset,
+ matrixInterpretation,
+ bias,
+ biasOffset,
+ biasInterpretation,
+ memoryLayout,
+ transpose,
+ matrixStride
+ );
+ }
+
+
+${{{{
+}
+}}}}
+}
+
+// Builds a `CoopVec<T, N>` from a pack of values; bound to the
+// `MakeCoopVectorFromValuePack` IR op. The element-list constructor casts every
+// argument to `T` before calling this.
+__intrinsic_op($(kIROp_MakeCoopVectorFromValuePack))
+CoopVec<T, N> __makeCoopVec<T : __BuiltinArithmeticType, let N : int, each U>(expand each U args);
+
+// `CoopVec<U,N>` -> `CoopVec<T,N>` overload bound to the `IntCast` IR op.
+__generic<T:__BuiltinArithmeticType, U:__BuiltinArithmeticType, let N : int>
+__intrinsic_op($(kIROp_IntCast))
+[require(cooperative_vector)]
+CoopVec<T,N> __int_cast(CoopVec<U,N> val);
+
+// `CoopVec<U,N>` -> `CoopVec<T,N>` overload bound to the `FloatCast` IR op.
+__generic<T:__BuiltinArithmeticType, U:__BuiltinArithmeticType, let N : int>
+__intrinsic_op($(kIROp_FloatCast))
+[require(cooperative_vector)]
+CoopVec<T,N> __real_cast(CoopVec<U,N> val);
+
+// `CoopVec<U,N>` -> `CoopVec<T,N>` overload bound to the `CastIntToFloat` IR op.
+__generic<T:__BuiltinArithmeticType, U:__BuiltinArithmeticType, let N : int>
+__intrinsic_op($(kIROp_CastIntToFloat))
+[require(cooperative_vector)]
+CoopVec<T,N> __int_to_float_cast(CoopVec<U,N> val);
+
+// `CoopVec<U,N>` -> `CoopVec<T,N>` overload bound to the `CastFloatToInt` IR op.
+__generic<T:__BuiltinArithmeticType, U:__BuiltinArithmeticType, let N : int>
+__intrinsic_op($(kIROp_CastFloatToInt))
+[require(cooperative_vector)]
+CoopVec<T,N> __float_to_int_cast(CoopVec<U,N> val);
+
+// Multiplies every element of `lhs` by the scalar `rhs`.
+// SPIR-V: `OpVectorTimesScalar` when `T` is floating point, otherwise a
+// per-element loop. HLSL: `__mutScalarMul` on a copy. Other targets: per-element loop.
+__generic<T : __BuiltinArithmeticType, let N : int>
+[ForceInline]
+[require(cooperative_vector)]
+CoopVec<T, N> operator *(CoopVec<T, N> lhs, const T rhs)
+{
+ __target_switch
+ {
+ case spirv:
+ if (__isFloat<T>())
+ {
+ return spirv_asm
+ {
+ result:$$CoopVec<T, N> = OpVectorTimesScalar $lhs $rhs;
+ };
+ }
+ else
+ {
+ // Non-float element types take the element-wise path.
+ for (int i = 0; i < N; ++i)
+ {
+ lhs[i] *= rhs;
+ }
+ return lhs;
+ }
+ case hlsl:
+ CoopVec<T, N> ret = lhs;
+ ret.__mutScalarMul(rhs);
+ return ret;
+ default:
+ for (int i = 0; i < N; ++i)
+ {
+ lhs[i] *= rhs;
+ }
+ return lhs;
+ }
+}
+
+// Scalar-on-the-left multiplication; forwards to the vector-times-scalar overload
+// with the operands swapped.
+__generic<T : __BuiltinArithmeticType, let N : int>
+[ForceInline]
+[require(cooperative_vector)]
+CoopVec<T, N> operator *(const T lhs, CoopVec<T, N> rhs)
+{
+ return rhs * lhs;
+}
+
+// Element-wise minimum of two floating-point cooperative vectors.
+// SPIR-V: GLSL.std.450 `FMin`. HLSL: `__mutMin` on a copy of `x`.
+// Other targets: scalar `min` per element.
+[ForceInline]
+[require(cooperative_vector)]
+CoopVec<T, N> min<T : __BuiltinFloatingPointType, let N : int>(CoopVec<T, N> x, CoopVec<T, N> y)
+{
+ __target_switch
+ {
+ case spirv:
+ return spirv_asm
+ {
+ result:$$CoopVec<T, N> = OpExtInst glsl450 FMin $x $y;
+ };
+ case hlsl:
+ CoopVec<T, N> ret = x;
+ ret.__mutMin(y);
+ return ret;
+ default:
+ CoopVec<T, N> ret;
+ for(int i = 0; i < N; ++i)
+ ret[i] = min(x[i], y[i]);
+
+ return ret;
+ }
+}
+
+// Element-wise maximum of two floating-point cooperative vectors.
+// SPIR-V: GLSL.std.450 `FMax`. HLSL: `__mutMax` on a copy of `x`.
+// Other targets: scalar `max` per element.
+[ForceInline]
+[require(cooperative_vector)]
+CoopVec<T, N> max<T : __BuiltinFloatingPointType, let N : int>(CoopVec<T, N> x, CoopVec<T, N> y)
+{
+ __target_switch
+ {
+ case spirv:
+ return spirv_asm
+ {
+ result:$$CoopVec<T, N> = OpExtInst glsl450 FMax $x $y;
+ };
+ case hlsl:
+ CoopVec<T, N> ret = x;
+ ret.__mutMax(y);
+ return ret;
+ default:
+ CoopVec<T, N> ret;
+ for(int i = 0; i < N; ++i)
+ ret[i] = max(x[i], y[i]);
+ return ret;
+ }
+}
+
+// Element-wise clamp of a floating-point cooperative vector between `minVal` and `maxVal`.
+// SPIR-V: GLSL.std.450 `FClamp`. HLSL: `__mutClamp` on a copy of `x`.
+// Other targets: scalar `clamp` per element.
+[ForceInline]
+[require(cooperative_vector)]
+CoopVec<T, N> clamp<T : __BuiltinFloatingPointType, let N : int>(CoopVec<T, N> x, CoopVec<T, N> minVal, CoopVec<T, N> maxVal)
+{
+ __target_switch
+ {
+ case spirv:
+ return spirv_asm
+ {
+ result:$$CoopVec<T, N> = OpExtInst glsl450 FClamp $x $minVal $maxVal;
+ };
+ case hlsl:
+ CoopVec<T, N> ret = x;
+ ret.__mutClamp(minVal, maxVal);
+ return ret;
+ default:
+ CoopVec<T, N> ret;
+ for(int i = 0; i < N; ++i)
+ ret[i] = clamp(x[i], minVal[i], maxVal[i]);
+ return ret;
+ }
+}
+
+// Element-wise minimum of two integer cooperative vectors.
+// SPIR-V: GLSL.std.450 `SMin` for signed element types, `UMin` otherwise.
+// HLSL: `__mutMin` on a copy of `x`. Other targets: scalar `min` per element.
+[ForceInline]
+[require(cooperative_vector)]
+CoopVec<T, N> min<T : __BuiltinIntegerType, let N : int>(CoopVec<T, N> x, CoopVec<T, N> y)
+{
+ __target_switch
+ {
+ case spirv:
+ if (__isSignedInt<T>())
+ {
+ return spirv_asm
+ {
+ result:$$CoopVec<T, N> = OpExtInst glsl450 SMin $x $y
+ };
+ }
+ else
+ {
+ return spirv_asm
+ {
+ result:$$CoopVec<T, N> = OpExtInst glsl450 UMin $x $y
+ };
+ }
+ case hlsl:
+ CoopVec<T, N> ret = x;
+ ret.__mutMin(y);
+ return ret;
+ default:
+ CoopVec<T, N> ret;
+ for(int i = 0; i < N; ++i)
+ ret[i] = min(x[i], y[i]);
+
+ return ret;
+ }
+}
+
+// Element-wise maximum of two integer cooperative vectors.
+// SPIR-V: GLSL.std.450 `SMax` for signed element types, `UMax` otherwise.
+// HLSL: `__mutMax` on a copy of `x`. Other targets: scalar `max` per element.
+// [ForceInline]
+[require(cooperative_vector)]
+CoopVec<T, N> max<T : __BuiltinIntegerType, let N : int>(CoopVec<T, N> x, CoopVec<T, N> y)
+{
+ __target_switch
+ {
+ case spirv:
+ if (__isSignedInt<T>())
+ {
+ return spirv_asm
+ {
+ result:$$CoopVec<T, N> = OpExtInst glsl450 SMax $x $y
+ };
+ }
+ else
+ {
+ return spirv_asm
+ {
+ result:$$CoopVec<T, N> = OpExtInst glsl450 UMax $x $y
+ };
+ }
+ case hlsl:
+ CoopVec<T, N> ret = x;
+ ret.__mutMax(y);
+ return ret;
+ default:
+ CoopVec<T, N> ret;
+ for(int i = 0; i < N; ++i)
+ ret[i] = max(x[i], y[i]);
+ return ret;
+ }
+}
+
+// Element-wise clamp of an integer cooperative vector between `minVal` and `maxVal`.
+// SPIR-V: GLSL.std.450 `SClamp` for signed element types, `UClamp` otherwise.
+// HLSL: `__mutClamp` on a copy of `x`. Other targets: scalar `clamp` per element.
+// [ForceInline]
+[require(cooperative_vector)]
+CoopVec<T, N> clamp<T : __BuiltinIntegerType, let N : int>(CoopVec<T, N> x, CoopVec<T, N> minVal, CoopVec<T, N> maxVal)
+{
+ __target_switch
+ {
+ case spirv:
+ if (__isSignedInt<T>())
+ {
+ return spirv_asm
+ {
+ result:$$CoopVec<T, N> = OpExtInst glsl450 SClamp $x $minVal $maxVal
+ };
+ }
+ else
+ {
+ return spirv_asm
+ {
+ result:$$CoopVec<T, N> = OpExtInst glsl450 UClamp $x $minVal $maxVal
+ };
+ }
+ case hlsl:
+ CoopVec<T, N> ret = x;
+ ret.__mutClamp(minVal, maxVal);
+ return ret;
+ default:
+ CoopVec<T, N> ret;
+ for(int i = 0; i < N; ++i)
+ ret[i] = clamp(x[i], minVal[i], maxVal[i]);
+ return ret;
+ }
+}
+
+// Element-wise `step(edge, x)` for floating-point cooperative vectors.
+// SPIR-V: GLSL.std.450 `Step`. Other targets: scalar `step` per element.
+// [ForceInline]
+[require(cooperative_vector)]
+CoopVec<T, N> step<T : __BuiltinFloatingPointType, let N : int>(CoopVec<T, N> edge, CoopVec<T, N> x)
+{
+ __target_switch
+ {
+ case spirv:
+ return spirv_asm
+ {
+ result:$$CoopVec<T, N> = OpExtInst glsl450 Step $edge $x;
+ };
+ default:
+ CoopVec<T, N> ret;
+ for(int i = 0; i < N; ++i)
+ ret[i] = step(edge[i], x[i]);
+ return ret;
+ }
+}
+
+// Element-wise `exp` for floating-point cooperative vectors.
+// SPIR-V: GLSL.std.450 `Exp`. Other targets: scalar `exp` per element.
+// [ForceInline]
+[require(cooperative_vector)]
+CoopVec<T, N> exp<T : __BuiltinFloatingPointType, let N : int>(CoopVec<T, N> x)
+{
+ __target_switch
+ {
+ case spirv:
+ return spirv_asm
+ {
+ result:$$CoopVec<T, N> = OpExtInst glsl450 Exp $x;
+ };
+ default:
+ CoopVec<T, N> ret;
+ for(int i = 0; i < N; ++i)
+ ret[i] = exp(x[i]);
+ return ret;
+ }
+}
+
+// Component-wise `log` for floating-point cooperative vectors.
+//  - spirv: emits the `glsl450` Log extended instruction on the whole vector.
+//  - other targets: loops over the N elements and calls the scalar `log`.
+// [ForceInline]
+[require(cooperative_vector)]
+CoopVec<T, N> log<T : __BuiltinFloatingPointType, let N : int>(CoopVec<T, N> x)
+{
+ __target_switch
+ {
+ case spirv:
+ return spirv_asm
+ {
+ result:$$CoopVec<T, N> = OpExtInst glsl450 Log $x;
+ };
+ default:
+ // Generic fallback: element-by-element scalar log.
+ CoopVec<T, N> ret;
+ for(int i = 0; i < N; ++i)
+ ret[i] = log(x[i]);
+ return ret;
+ }
+}
+
+// Component-wise `tanh` for floating-point cooperative vectors.
+//  - spirv: emits the `glsl450` Tanh extended instruction on the whole vector.
+//  - other targets: loops over the N elements and calls the scalar `tanh`.
+// [ForceInline]
+[require(cooperative_vector)]
+CoopVec<T, N> tanh<T : __BuiltinFloatingPointType, let N : int>(CoopVec<T, N> x)
+{
+ __target_switch
+ {
+ case spirv:
+ return spirv_asm
+ {
+ result:$$CoopVec<T, N> = OpExtInst glsl450 Tanh $x;
+ };
+ default:
+ // Generic fallback: element-by-element scalar tanh.
+ CoopVec<T, N> ret;
+ for(int i = 0; i < N; ++i)
+ ret[i] = tanh(x[i]);
+ return ret;
+ }
+}
+
+// Component-wise single-argument `atan` for floating-point cooperative vectors.
+//  - spirv: emits the `glsl450` Atan extended instruction on the whole vector.
+//  - other targets: loops over the N elements and calls the scalar `atan`.
+//
+// TODO: Why does this fail when inlined on HLSL?
+// We generate some really weird code in that case, which is why
+// [ForceInline] is left commented out below.
+// [ForceInline]
+[require(cooperative_vector)]
+CoopVec<T, N> atan<T : __BuiltinFloatingPointType, let N : int>(CoopVec<T, N> yOverX)
+{
+ __target_switch
+ {
+ case spirv:
+ return spirv_asm
+ {
+ result:$$CoopVec<T, N> = OpExtInst glsl450 Atan $yOverX;
+ };
+ default:
+ // Generic fallback: element-by-element scalar atan.
+ CoopVec<T, N> ret;
+ for(int i = 0; i < N; ++i)
+ ret[i] = atan(yOverX[i]);
+ return ret;
+ }
+}
+
+// Component-wise fused multiply-add (`a * b + c`) for floating-point
+// cooperative vectors.
+//  - spirv: emits the `glsl450` Fma extended instruction on the whole vector.
+//  - other targets: loops over the N elements and calls the scalar `mad`
+//    (not `fma`; see the TODO inside the body).
+// [ForceInline]
+[require(cooperative_vector)]
+CoopVec<T, N> fma<T : __BuiltinFloatingPointType, let N : int>(CoopVec<T, N> a, CoopVec<T, N> b, CoopVec<T, N> c)
+{
+ // TODO: Investigate, why does this fail if it's not inlined
+ // replacing fma with mad below also fixes things...
+ // dxc generated substantially different code
+ __target_switch
+ {
+ case spirv:
+ return spirv_asm
+ {
+ result:$$CoopVec<T, N> = OpExtInst glsl450 Fma $a $b $c;
+ };
+ default:
+ // Generic fallback: element-by-element scalar mad.
+ CoopVec<T, N> ret;
+ for(int i = 0; i < N; ++i)
+ ret[i] = mad(a[i], b[i], c[i]);
+ return ret;
+ }
+}
+
+// Buffers from which values of arbitrary type can be loaded from byte offsets
+interface IPhysicalBuffer
+{
+ // Loads a value of type T located `offset` bytes into the buffer.
+ [__unsafeForceInlineEarly]
+ T LoadByteOffset<T>(int offset);
+
+ // Returns a pointer to the buffer contents viewed as an unsized array of
+ // uint32_t; the cooperative-vector SPIR-V paths in this file pass it as
+ // the pointer operand of their instructions.
+ [__unsafeForceInlineEarly]
+ Ptr<uint32_t[]> GetBufferPointer();
+}
+
+// Buffers to which values of arbitrary type can be stored at byte offsets
+// (in addition to everything IPhysicalBuffer provides for loading).
+interface IRWPhysicalBuffer : IPhysicalBuffer
+{
+ // Stores `element` at `offset` bytes into the buffer.
+ [__unsafeForceInlineEarly]
+ void StoreByteOffset<T>(int offset, T element);
+}
+
+// Makes ByteAddressBuffer usable wherever an IPhysicalBuffer is expected.
+extension ByteAddressBuffer : IPhysicalBuffer
+{
+ // Obtains the pointer by going through the equivalent structured buffer
+ // of uint32_t elements.
+ [__unsafeForceInlineEarly]
+ Ptr<uint32_t[]> GetBufferPointer()
+ {
+ return __getStructuredBufferPtr(__getEquivalentStructuredBuffer<uint32_t>(this));
+ }
+
+ // Forwards to the buffer's own typed Load<T> at the given byte offset.
+ [__unsafeForceInlineEarly]
+ T LoadByteOffset<T>(int offset)
+ {
+ return this.Load<T>(offset);
+ }
+}
+
+// Makes RWByteAddressBuffer usable wherever an IPhysicalBuffer is expected
+// (the store side is added by the separate IRWPhysicalBuffer extension).
+extension RWByteAddressBuffer : IPhysicalBuffer
+{
+ // Obtains the pointer by going through the equivalent structured buffer
+ // of uint32_t elements.
+ [__unsafeForceInlineEarly]
+ Ptr<uint32_t[]> GetBufferPointer()
+ {
+ return __getStructuredBufferPtr(__getEquivalentStructuredBuffer<uint32_t>(this));
+ }
+
+ // Forwards to the buffer's own typed Load<T> at the given byte offset.
+ [__unsafeForceInlineEarly]
+ T LoadByteOffset<T>(int offset)
+ {
+ return this.Load<T>(offset);
+ }
+}
+
+// Adds the store half of the physical-buffer interface to RWByteAddressBuffer.
+extension RWByteAddressBuffer : IRWPhysicalBuffer
+{
+ // Forwards to the buffer's own typed Store<T> at the given byte offset.
+ [__unsafeForceInlineEarly]
+ void StoreByteOffset<T>(int offset, T element)
+ {
+ return this.Store<T>(offset, element);
+ }
+}
+
+
+//
+// Convenience loading functions for cooperative vectors which infer the
+// element type for structured buffers and groupshared arrays
+// (ByteAddressBuffer overloads are provided as well, for consistency).
+//
+
+// Loads a CoopVec<T, N> from a ByteAddressBuffer by forwarding to
+// CoopVec<T, N>.load. T cannot be inferred from this buffer type and must be
+// supplied by the caller.
+// NOTE(review): the parameter name suggests the byte offset must be
+// 16-byte aligned; nothing here checks it.
+[ForceInline]
+[require(cooperative_vector)]
+CoopVec<T, N> coopVecLoad<let N : int, T : __BuiltinArithmeticType>(ByteAddressBuffer buffer, int32_t byteOffset16ByteAligned = 0)
+{
+ return CoopVec<T, N>.load(buffer, byteOffset16ByteAligned);
+}
+
+// Loads a CoopVec<T, N> from a RWByteAddressBuffer by forwarding to
+// CoopVec<T, N>.load. T cannot be inferred from this buffer type and must be
+// supplied by the caller.
+// NOTE(review): the parameter name suggests the byte offset must be
+// 16-byte aligned; nothing here checks it.
+[ForceInline]
+[require(cooperative_vector)]
+CoopVec<T, N> coopVecLoad<let N : int, T : __BuiltinArithmeticType>(RWByteAddressBuffer buffer, int32_t byteOffset16ByteAligned = 0)
+{
+ return CoopVec<T, N>.load(buffer, byteOffset16ByteAligned);
+}
+
+// Loads a CoopVec<T, N> from a StructuredBuffer<T> by forwarding to
+// CoopVec<T, N>.load; the element type T comes from the buffer argument.
+// NOTE(review): the parameter name suggests the byte offset must be
+// 16-byte aligned; nothing here checks it.
+[ForceInline]
+[require(cooperative_vector)]
+CoopVec<T, N> coopVecLoad<let N : int, T : __BuiltinArithmeticType>(StructuredBuffer<T> buffer, int32_t byteOffset16ByteAligned = 0)
+{
+ return CoopVec<T, N>.load(buffer, byteOffset16ByteAligned);
+}
+
+// Loads a CoopVec<T, N> from a RWStructuredBuffer<T> by forwarding to
+// CoopVec<T, N>.load; the element type T comes from the buffer argument.
+// NOTE(review): the parameter name suggests the byte offset must be
+// 16-byte aligned; nothing here checks it.
+[ForceInline]
+[require(cooperative_vector)]
+CoopVec<T, N> coopVecLoad<let N : int, T : __BuiltinArithmeticType>(RWStructuredBuffer<T> buffer, int32_t byteOffset16ByteAligned = 0)
+{
+ return CoopVec<T, N>.load(buffer, byteOffset16ByteAligned);
+}
+
+// Groupshared
+// Loads a CoopVec<T, N> from a groupshared array of M elements of type T by
+// forwarding to CoopVec<T, N>.load; T and M come from the array argument.
+// NOTE(review): the parameter name suggests the byte offset must be
+// 16-byte aligned; nothing here checks it.
+[ForceInline]
+[require(cooperative_vector)]
+CoopVec<T, N> coopVecLoadGroupshared<let N : int, T : __BuiltinArithmeticType, let M : int>(__constref groupshared const T[M] data, int32_t byteOffset16ByteAligned = 0)
+{
+ return CoopVec<T, N>.load(data, byteOffset16ByteAligned);
+}
+
+//
+// Coop Vector matrix multiplication
+//
+
+// Memory layouts a caller can name for the matrix operand of the
+// cooperative-vector matrix functions. The values are translated to
+// target-specific integer codes by __getSpvCoopVecMatrixLayout and
+// __getHLSLCoopVecMatrixLayout.
+enum CoopVecMatrixLayout
+{
+ RowMajor,
+ ColumnMajor,
+ InferencingOptimal,
+ TrainingOptimal
+};
+
+// Component interpretations a caller can name for the input vector, matrix and
+// bias operands of the cooperative-vector matrix functions. The values are
+// translated to target-specific integer codes by __getSpvCoopVecComponentType
+// and __getHLSLCoopVecComponentType. The two `*Packed` entries are the ones
+// __inputInterpretationPackingFactor reports a packing factor of 4 for.
+enum CoopVecComponentType
+{
+ FloatE4M3,
+ FloatE5M2,
+ Float16,
+ Float32,
+ Float64,
+ SignedInt8,
+ SignedInt16,
+ SignedInt32,
+ SignedInt64,
+ SignedInt8Packed,
+ UnsignedInt8,
+ UnsignedInt16,
+ UnsignedInt32,
+ UnsignedInt64,
+ UnsignedInt8Packed
+};
+
+// Returns how many logical components one element of the input vector holds
+// under the given interpretation: 4 for the two packed 8-bit interpretations,
+// 1 for everything else.
+[ForceInline]
+int __inputInterpretationPackingFactor(CoopVecComponentType componentType)
+{
+ switch (componentType)
+ {
+ case CoopVecComponentType::SignedInt8Packed:
+ case CoopVecComponentType::UnsignedInt8Packed:
+ return 4;
+ }
+ // Every non-packed interpretation stores one component per element.
+ return 1;
+}
+
+// True when the interpretation packs more than one component per element,
+// i.e. when its packing factor is not 1.
+[ForceInline]
+bool __isPackedInputInterpretation(CoopVecComponentType componentType)
+{
+ return __inputInterpretationPackingFactor(componentType) != 1;
+}
+
+// Maps a CoopVecMatrixLayout to the integer code that the SPIR-V paths in this
+// file pass as the memory-layout operand of the OpCooperativeVector*
+// instructions. Any other value triggers the static_assert in `default`.
+// TODO: We might consider some way of specifying these from our lookup tables
+[ForceInline]
+uint32_t __getSpvCoopVecMatrixLayout(CoopVecMatrixLayout layout)
+{
+ switch (layout)
+ {
+ case CoopVecMatrixLayout::RowMajor:
+ return 0;
+ case CoopVecMatrixLayout::ColumnMajor:
+ return 1;
+ case CoopVecMatrixLayout::InferencingOptimal:
+ return 2;
+ case CoopVecMatrixLayout::TrainingOptimal:
+ return 3;
+ default:
+ static_assert(false, "unsupported layout value");
+ }
+ // Sentinel for the case where no branch above returned.
+ return 0xffffffff;
+}
+
+// Maps a CoopVecMatrixLayout to the integer code that the HLSL paths in this
+// file pass to the `__mutMatMul` / `__mutMatMulAdd` helpers. The codes are
+// currently the same numbers as the SPIR-V ones (see the TODO in the body).
+// Any other value triggers the static_assert in `default`.
+[ForceInline]
+uint32_t __getHLSLCoopVecMatrixLayout(CoopVecMatrixLayout layout)
+{
+ switch (layout)
+ {
+ // TODO: Check these are the same
+ case CoopVecMatrixLayout::RowMajor:
+ return 0;
+ case CoopVecMatrixLayout::ColumnMajor:
+ return 1;
+ case CoopVecMatrixLayout::InferencingOptimal:
+ return 2;
+ case CoopVecMatrixLayout::TrainingOptimal:
+ return 3;
+ default:
+ static_assert(false, "unsupported layout value");
+ }
+ // Sentinel for the case where no branch above returned.
+ return 0xffffffff;
+}
+
+// Maps a CoopVecComponentType to the integer code that the SPIR-V paths in
+// this file pass as an interpretation operand of the OpCooperativeVector*
+// instructions. The packed 8-bit and 8-bit float interpretations use codes in
+// the 1000491000.. range; the others use small consecutive numbers.
+// Any other value triggers the static_assert in `default`.
+[ForceInline]
+uint32_t __getSpvCoopVecComponentType(CoopVecComponentType componentType)
+{
+ switch (componentType)
+ {
+ case CoopVecComponentType::Float16:
+ return 0;
+ case CoopVecComponentType::Float32:
+ return 1;
+ case CoopVecComponentType::Float64:
+ return 2;
+ case CoopVecComponentType::SignedInt8:
+ return 3;
+ case CoopVecComponentType::SignedInt16:
+ return 4;
+ case CoopVecComponentType::SignedInt32:
+ return 5;
+ case CoopVecComponentType::SignedInt8Packed:
+ return 1000491000;
+ case CoopVecComponentType::SignedInt64:
+ return 6;
+ case CoopVecComponentType::UnsignedInt8:
+ return 7;
+ case CoopVecComponentType::UnsignedInt16:
+ return 8;
+ case CoopVecComponentType::UnsignedInt32:
+ return 9;
+ case CoopVecComponentType::UnsignedInt8Packed:
+ return 1000491001;
+ case CoopVecComponentType::UnsignedInt64:
+ return 10;
+ case CoopVecComponentType::FloatE4M3:
+ return 1000491002;
+ case CoopVecComponentType::FloatE5M2:
+ return 1000491003;
+ default:
+ static_assert(false, "unsupported componentType value");
+ }
+ // Sentinel for the case where no branch above returned.
+ return 0xffffffff;
+}
+
+// Maps a CoopVecComponentType to the integer code that the HLSL paths in this
+// file pass to the `__mutMatMul` / `__mutMatMulAdd` helpers.
+// Float64, SignedInt64 and UnsignedInt64 have no case here, so they fall into
+// `default` and trigger its static_assert.
+// NOTE(review): the trailing sentinel is 32 here, whereas the sibling mapping
+// helpers return 0xffffffff; confirm whether that difference is intentional.
+[ForceInline]
+uint32_t __getHLSLCoopVecComponentType(CoopVecComponentType componentType)
+{
+ switch (componentType)
+ {
+ case CoopVecComponentType::Float16:
+ return 0;
+ case CoopVecComponentType::Float32:
+ return 1;
+ case CoopVecComponentType::UnsignedInt8:
+ return 2;
+ case CoopVecComponentType::UnsignedInt16:
+ return 3;
+ case CoopVecComponentType::UnsignedInt32:
+ return 4;
+ case CoopVecComponentType::SignedInt8:
+ return 5;
+ case CoopVecComponentType::SignedInt16:
+ return 6;
+ case CoopVecComponentType::SignedInt32:
+ return 7;
+ case CoopVecComponentType::SignedInt8Packed:
+ return 8;
+ case CoopVecComponentType::UnsignedInt8Packed:
+ return 9;
+ case CoopVecComponentType::FloatE4M3:
+ return 10;
+ case CoopVecComponentType::FloatE5M2:
+ return 11;
+ default:
+ static_assert(false, "unsupported componentType value");
+ }
+ return 32;
+}
+
+// Returns the stride, in bytes, used by the generic (non-SPIR-V, non-HLSL)
+// fallbacks in this file to step from one matrix or bias element to the next
+// for the given interpretation. The packed 8-bit interpretations report 4.
+// FloatE4M3 and FloatE5M2 have no case here, so they fall into `default` and
+// trigger its static_assert.
+[ForceInline]
+uint32_t __coopVecComponentTypeStride(CoopVecComponentType componentType)
+{
+ switch (componentType)
+ {
+ case CoopVecComponentType::Float16:
+ return 2;
+ case CoopVecComponentType::Float32:
+ return 4;
+ case CoopVecComponentType::Float64:
+ return 8;
+ case CoopVecComponentType::SignedInt8:
+ return 1;
+ case CoopVecComponentType::SignedInt16:
+ return 2;
+ case CoopVecComponentType::SignedInt32:
+ return 4;
+ case CoopVecComponentType::SignedInt8Packed:
+ return 4;
+ case CoopVecComponentType::SignedInt64:
+ return 8;
+ case CoopVecComponentType::UnsignedInt8:
+ return 1;
+ case CoopVecComponentType::UnsignedInt16:
+ return 2;
+ case CoopVecComponentType::UnsignedInt32:
+ return 4;
+ case CoopVecComponentType::UnsignedInt8Packed:
+ return 4;
+ case CoopVecComponentType::UnsignedInt64:
+ return 8;
+ default:
+ static_assert(false, "unsupported componentType value");
+ }
+ // Sentinel for the case where no branch above returned.
+ return 0xffffffff;
+}
+
+${{{{
+// Generator-side table of the byte-address buffer types. The loop opened at
+// the end of this snippet repeats the declarations that follow (up to the
+// closing brace marked "buffer type loop") once per entry, with
+// $(buffer.type) substituted by the entry's type name. `isRW` is what the
+// later `if(buffer.isRW)` snippet tests before the accumulation functions.
+static const struct {
+ bool isRW;
+ char const* type;
+} kByteAddressBufferCases_[] =
+{
+ {true, "RWByteAddressBuffer"},
+ {false, "ByteAddressBuffer"},
+};
+for(auto buffer : kByteAddressBufferCases_) {
+}}}}
+
+// Multiplies a matrix stored in a byte-address buffer by a cooperative vector
+// and returns the M-element result.
+//
+//   input                 the input vector; holds PackedK elements of type U
+//   inputInterpretation   how the elements of `input` are to be interpreted
+//   k                     logical number of input components; must equal
+//                         PackedK for non-packed interpretations and be at
+//                         most packingFactor * PackedK for packed ones
+//   matrix, matrixOffset  buffer holding the matrix and the byte offset of
+//                         its first element
+//   matrixInterpretation  element type of the matrix as stored in the buffer
+//   memoryLayout          layout of the matrix in memory
+//   transpose             whether the matrix is used transposed
+//   matrixStride          stride in bytes between consecutive rows (or
+//                         columns, see the fallback's row/col selection)
+//
+//  - spirv: emits OpCooperativeVectorMatrixMulNV.
+//  - hlsl: calls the in-place `__mutMatMul` helper on a zero vector.
+//  - other targets: scalar reference loop reading the matrix with
+//    LoadByteOffset. This fallback only supports k == PackedK or
+//    k == 4 * PackedK with an 8-bit packed input interpretation.
+//
+// TODO: Can we ForceInline for just hlsl? the other platforms don't really
+// need it
+[ForceInline]
+[require(cooperative_vector)]
+__generic<T : __BuiltinArithmeticType, let M : int, let PackedK : int, U : __BuiltinArithmeticType>
+CoopVec<T, M> coopVecMatMulPacked(
+ CoopVec<U, PackedK> input,
+ constexpr CoopVecComponentType inputInterpretation,
+ constexpr int k,
+ $(buffer.type) matrix,
+ int32_t matrixOffset,
+ constexpr CoopVecComponentType matrixInterpretation,
+ constexpr CoopVecMatrixLayout memoryLayout,
+ constexpr bool transpose,
+ constexpr uint matrixStride
+)
+{
+ static_assert(__isPackedInputInterpretation(inputInterpretation) || k == PackedK
+ , "for non-packed inputInterpretation values k must be equal to the input vector length");
+ static_assert(!__isPackedInputInterpretation(inputInterpretation)
+ || k <= __inputInterpretationPackingFactor(inputInterpretation)*PackedK
+ , "for packed inputInterpretation values k must be less than or equal to the input vector length times the packing factor");
+
+ __target_switch
+ {
+ case spirv:
+ let m : int32_t = M;
+ let matrixInterpretationSpirv : int32_t = __getSpvCoopVecComponentType(matrixInterpretation);
+ let inputInterpretationSpirv : int32_t = __getSpvCoopVecComponentType(inputInterpretation);
+ let memoryLayoutSpirv : int32_t = __getSpvCoopVecMatrixLayout(memoryLayout);
+ let matrixPtr = matrix.GetBufferPointer();
+ return spirv_asm
+ {
+ result:$$CoopVec<T, M> = OpCooperativeVectorMatrixMulNV $input $inputInterpretationSpirv $matrixPtr $matrixOffset $matrixInterpretationSpirv $m $k $memoryLayoutSpirv $transpose $matrixStride;
+ };
+
+ case hlsl:
+ var ret = CoopVec<T, M>(0);
+ let inputInterpretationHLSL = __getHLSLCoopVecComponentType(inputInterpretation);
+ let matrixInterpretationHLSL = __getHLSLCoopVecComponentType(matrixInterpretation);
+ let memoryLayoutHLSL = __getHLSLCoopVecMatrixLayout(memoryLayout);
+ ret.__mutMatMul(
+ input,
+ inputInterpretationHLSL,
+ matrix,
+ matrixOffset,
+ matrixInterpretationHLSL,
+ M,
+ k,
+ memoryLayoutHLSL,
+ transpose,
+ matrixStride
+ );
+ return ret;
+
+ default:
+ var result = CoopVec<T, M>(0);
+ // Unpacked copy of the input, sized for the largest supported packing factor.
+ var v = CoopVec<T, PackedK*4>();
+ // TODO: Insert language from the spec to describe this madness
+ if(k == PackedK)
+ {
+ // One logical component per input element: plain conversion to T.
+ for(int i = 0; i < k; ++i)
+ v[i] = __arithmetic_cast<T>(input[i]);
+ }
+ else
+ {
+ static_assert(k == PackedK*4, "K must be 4 * PackedK for the non-spirv coopVecMatMulPacked backend");
+ static_assert(inputInterpretation == CoopVecComponentType::SignedInt8Packed ||
+ inputInterpretation == CoopVecComponentType::UnsignedInt8Packed,
+ "Packing is only supported for 4*int8 or 4*uint8 vectors");
+ for(int i = 0; i < k; ++i)
+ {
+ // Four 8-bit components live in each input element; pick byte (i % 4).
+ let n = __arithmetic_cast<int32_t>(input[i/4]);
+ let b = (n >> ((i % 4) * 8)) & 0xff;
+ // Honour the signedness requested by inputInterpretation: going
+ // through int8_t unconditionally would turn unsigned components
+ // >= 128 into negative values.
+ if (inputInterpretation == CoopVecComponentType::UnsignedInt8Packed)
+ {
+ v[i] = __arithmetic_cast<T>(uint8_t(b));
+ }
+ else
+ {
+ v[i] = T(int8_t(b));
+ }
+ }
+ }
+
+ for (int i = 0; i < M; ++i)
+ {
+ for (int j = 0; j < k; ++j)
+ {
+ // `==` binds tighter than `^`: row and column swap when exactly one of
+ // `transpose` and "layout is column-major" holds.
+ int row = (transpose ^ memoryLayout == CoopVecMatrixLayout::ColumnMajor) ? j : i;
+ int col = (transpose ^ memoryLayout == CoopVecMatrixLayout::ColumnMajor) ? i : j;
+ int offset = matrixOffset + (row * matrixStride + col * __coopVecComponentTypeStride(matrixInterpretation));
+
+ // Load the matrix element with the stored type, convert to T, accumulate.
+ switch (matrixInterpretation)
+ {
+ case CoopVecComponentType::Float16:
+ result[i] += __arithmetic_cast<T>(matrix.LoadByteOffset<half>(offset)) * v[j];
+ break;
+ case CoopVecComponentType::Float32:
+ result[i] += __arithmetic_cast<T>(matrix.LoadByteOffset<float>(offset)) * v[j];
+ break;
+ case CoopVecComponentType::Float64:
+ result[i] += __arithmetic_cast<T>(matrix.LoadByteOffset<double>(offset)) * v[j];
+ break;
+ case CoopVecComponentType::SignedInt8:
+ result[i] += __arithmetic_cast<T>(matrix.LoadByteOffset<int8_t>(offset)) * v[j];
+ break;
+ case CoopVecComponentType::SignedInt16:
+ result[i] += __arithmetic_cast<T>(matrix.LoadByteOffset<int16_t>(offset)) * v[j];
+ break;
+ case CoopVecComponentType::SignedInt32:
+ result[i] += __arithmetic_cast<T>(matrix.LoadByteOffset<int32_t>(offset)) * v[j];
+ break;
+ case CoopVecComponentType::SignedInt64:
+ result[i] += __arithmetic_cast<T>(matrix.LoadByteOffset<int64_t>(offset)) * v[j];
+ break;
+ case CoopVecComponentType::SignedInt8Packed:
+ result[i] += __arithmetic_cast<T>(matrix.LoadByteOffset<int8_t>(offset)) * v[j];
+ break;
+ case CoopVecComponentType::UnsignedInt8:
+ result[i] += __arithmetic_cast<T>(matrix.LoadByteOffset<uint8_t>(offset)) * v[j];
+ break;
+ case CoopVecComponentType::UnsignedInt16:
+ result[i] += __arithmetic_cast<T>(matrix.LoadByteOffset<uint16_t>(offset)) * v[j];
+ break;
+ case CoopVecComponentType::UnsignedInt32:
+ result[i] += __arithmetic_cast<T>(matrix.LoadByteOffset<uint32_t>(offset)) * v[j];
+ break;
+ case CoopVecComponentType::UnsignedInt64:
+ result[i] += __arithmetic_cast<T>(matrix.LoadByteOffset<uint64_t>(offset)) * v[j];
+ break;
+ case CoopVecComponentType::UnsignedInt8Packed:
+ result[i] += __arithmetic_cast<T>(matrix.LoadByteOffset<uint8_t>(offset)) * v[j];
+ break;
+ }
+ }
+ }
+
+ return result;
+ }
+}
+
+// Matrix-vector multiply for non-packed inputs, with the matrix in a
+// byte-address buffer. Rejects packed input interpretations at compile time
+// and otherwise forwards to coopVecMatMulPacked with k set to the input
+// vector length K.
+[ForceInline]
+[require(cooperative_vector)]
+__generic<T : __BuiltinArithmeticType, let M : int, let K : int, U : __BuiltinArithmeticType>
+CoopVec<T, M> coopVecMatMul(
+ CoopVec<U, K> input,
+ constexpr CoopVecComponentType inputInterpretation,
+ $(buffer.type) matrix,
+ int32_t matrixOffset,
+ constexpr CoopVecComponentType matrixInterpretation,
+ constexpr CoopVecMatrixLayout memoryLayout,
+ constexpr bool transpose,
+ constexpr uint matrixStride
+)
+{
+ static_assert(!__isPackedInputInterpretation(inputInterpretation)
+ , "for packed inputInterpretation values please use coopVecMatMulPacked and specify k manually");
+ return coopVecMatMulPacked<
+ T, M, K, U>(
+ input,
+ inputInterpretation,
+ K,
+ matrix,
+ matrixOffset,
+ matrixInterpretation,
+ memoryLayout,
+ transpose,
+ matrixStride);
+}
+
+// Multiplies a matrix stored in a byte-address buffer by a cooperative vector,
+// adds a bias vector read from a second buffer, and returns the M-element
+// result. Parameters match coopVecMatMulPacked, plus:
+//
+//   bias, biasOffset     buffer holding the bias and the byte offset of its
+//                        first element
+//   biasInterpretation   element type of the bias as stored in the buffer
+//
+//  - spirv: emits OpCooperativeVectorMatrixMulAddNV.
+//  - hlsl: calls the in-place `__mutMatMulAdd` helper on a zero vector.
+//  - other targets: calls coopVecMatMulPacked, then adds the bias one element
+//    at a time using LoadByteOffset.
+[ForceInline]
+[require(cooperative_vector)]
+CoopVec<T, M> coopVecMatMulAddPacked<T : __BuiltinArithmeticType, let M : int, let PackedK : int, U : __BuiltinArithmeticType>(
+ CoopVec<U, PackedK> input,
+ constexpr CoopVecComponentType inputInterpretation,
+ constexpr int k,
+ $(buffer.type) matrix,
+ int32_t matrixOffset,
+ constexpr CoopVecComponentType matrixInterpretation,
+ $(buffer.type) bias,
+ int32_t biasOffset,
+ constexpr CoopVecComponentType biasInterpretation,
+ constexpr CoopVecMatrixLayout memoryLayout,
+ constexpr bool transpose,
+ constexpr uint matrixStride
+)
+{
+ static_assert(__isPackedInputInterpretation(inputInterpretation) || k == PackedK
+ , "for non-packed inputInterpretation values k must be equal to the input vector length");
+ static_assert(!__isPackedInputInterpretation(inputInterpretation)
+ || k <= __inputInterpretationPackingFactor(inputInterpretation)*PackedK
+ , "for packed inputInterpretation values k must be less than or equal to the input vector length times the packing factor");
+
+ __target_switch
+ {
+ case spirv:
+ let m : int32_t = M;
+ let matrixInterpretationSpirv : int32_t = __getSpvCoopVecComponentType(matrixInterpretation);
+ let biasInterpretationSpirv : int32_t = __getSpvCoopVecComponentType(biasInterpretation);
+ let inputInterpretationSpirv : int32_t = __getSpvCoopVecComponentType(inputInterpretation);
+ let memoryLayoutSpirv : int32_t = __getSpvCoopVecMatrixLayout(memoryLayout);
+ let matrixPtr = matrix.GetBufferPointer();
+ let biasPtr = bias.GetBufferPointer();
+ return spirv_asm
+ {
+ result:$$CoopVec<T, M> = OpCooperativeVectorMatrixMulAddNV $input $inputInterpretationSpirv $matrixPtr $matrixOffset $matrixInterpretationSpirv $biasPtr $biasOffset $biasInterpretationSpirv $m $k $memoryLayoutSpirv $transpose $matrixStride;
+ };
+
+ case hlsl:
+ var ret = CoopVec<T, M>(0);
+ let inputInterpretationHLSL = __getHLSLCoopVecComponentType(inputInterpretation);
+ let matrixInterpretationHLSL = __getHLSLCoopVecComponentType(matrixInterpretation);
+ let biasInterpretationHLSL = __getHLSLCoopVecComponentType(biasInterpretation);
+ let memoryLayoutHLSL = __getHLSLCoopVecMatrixLayout(memoryLayout);
+ ret.__mutMatMulAdd(
+ input,
+ inputInterpretationHLSL,
+ matrix,
+ matrixOffset,
+ matrixInterpretationHLSL,
+ bias,
+ biasOffset,
+ biasInterpretationHLSL,
+ M,
+ k,
+ memoryLayoutHLSL,
+ transpose,
+ matrixStride
+ );
+ return ret;
+
+ default:
+ // Generic fallback: do the multiply first, then add the bias per element.
+ var result = coopVecMatMulPacked<T, M, PackedK, U>(
+ input,
+ inputInterpretation,
+ k,
+ matrix,
+ matrixOffset,
+ matrixInterpretation,
+ memoryLayout,
+ transpose,
+ matrixStride);
+
+ for (int i = 0; i < M; ++i)
+ {
+ // Byte offset of bias element i.
+ int b = biasOffset + i * __coopVecComponentTypeStride(biasInterpretation);
+ // Load the bias element with its stored type, convert to T, accumulate.
+ switch (biasInterpretation)
+ {
+ case CoopVecComponentType::Float16:
+ result[i] += __arithmetic_cast<T>(bias.LoadByteOffset<half>(b));
+ break;
+ case CoopVecComponentType::Float32:
+ result[i] += __arithmetic_cast<T>(bias.LoadByteOffset<float>(b));
+ break;
+ case CoopVecComponentType::Float64:
+ result[i] += __arithmetic_cast<T>(bias.LoadByteOffset<double>(b));
+ break;
+ case CoopVecComponentType::SignedInt8:
+ result[i] += __arithmetic_cast<T>(bias.LoadByteOffset<int8_t>(b));
+ break;
+ case CoopVecComponentType::SignedInt16:
+ result[i] += __arithmetic_cast<T>(bias.LoadByteOffset<int16_t>(b));
+ break;
+ case CoopVecComponentType::SignedInt32:
+ result[i] += __arithmetic_cast<T>(bias.LoadByteOffset<int32_t>(b));
+ break;
+ case CoopVecComponentType::SignedInt64:
+ result[i] += __arithmetic_cast<T>(bias.LoadByteOffset<int64_t>(b));
+ break;
+ case CoopVecComponentType::SignedInt8Packed:
+ result[i] += __arithmetic_cast<T>(bias.LoadByteOffset<int8_t>(b));
+ break;
+ case CoopVecComponentType::UnsignedInt8:
+ result[i] += __arithmetic_cast<T>(bias.LoadByteOffset<uint8_t>(b));
+ break;
+ case CoopVecComponentType::UnsignedInt16:
+ result[i] += __arithmetic_cast<T>(bias.LoadByteOffset<uint16_t>(b));
+ break;
+ case CoopVecComponentType::UnsignedInt32:
+ result[i] += __arithmetic_cast<T>(bias.LoadByteOffset<uint32_t>(b));
+ break;
+ case CoopVecComponentType::UnsignedInt64:
+ result[i] += __arithmetic_cast<T>(bias.LoadByteOffset<uint64_t>(b));
+ break;
+ case CoopVecComponentType::UnsignedInt8Packed:
+ result[i] += __arithmetic_cast<T>(bias.LoadByteOffset<uint8_t>(b));
+ break;
+ }
+ }
+
+ return result;
+ }
+}
+
+// Matrix-vector multiply with bias add for non-packed inputs, with the matrix
+// and bias in byte-address buffers. Rejects packed input interpretations at
+// compile time and otherwise forwards to coopVecMatMulAddPacked with k set to
+// the input vector length K.
+[ForceInline]
+[require(cooperative_vector)]
+__generic<T : __BuiltinArithmeticType, let M : int, let K : int, U : __BuiltinArithmeticType>
+CoopVec<T, M> coopVecMatMulAdd(
+ CoopVec<U, K> input,
+ constexpr CoopVecComponentType inputInterpretation,
+ $(buffer.type) matrix,
+ int32_t matrixOffset,
+ constexpr CoopVecComponentType matrixInterpretation,
+ $(buffer.type) bias,
+ int32_t biasOffset,
+ constexpr CoopVecComponentType biasInterpretation,
+ constexpr CoopVecMatrixLayout memoryLayout,
+ constexpr bool transpose,
+ constexpr uint matrixStride
+)
+{
+ // The diagnostic names coopVecMatMulAddPacked, the packed counterpart of
+ // this function (coopVecMatMulPacked would drop the bias).
+ static_assert(!__isPackedInputInterpretation(inputInterpretation)
+ , "for packed inputInterpretation values please use coopVecMatMulAddPacked and specify k manually");
+ return coopVecMatMulAddPacked<
+ T, M, K, U>(
+ input,
+ inputInterpretation,
+ K,
+ matrix,
+ matrixOffset,
+ matrixInterpretation,
+ bias,
+ biasOffset,
+ biasInterpretation,
+ memoryLayout,
+ transpose,
+ matrixStride);
+}
+
+//
+// Coop Vector accumulation
+//
+
+${{{{
+if(buffer.isRW)
+{
+}}}}
+// Accumulates the outer product of `a` (M elements) and `b` (N elements) into
+// a matrix stored in a writable byte-address buffer.
+//
+//   matrix, matrixOffset  destination buffer and byte offset of its first element
+//   matrixStride          stride in bytes between consecutive rows (columns
+//                         when the layout is column-major)
+//   memoryLayout          layout of the matrix in memory
+//   matrixInterpretation  element type of the matrix as stored in the buffer
+//
+//  - spirv: emits OpCooperativeVectorOuterProductAccumulateNV and declares the
+//    CooperativeVectorTrainingNV capability.
+//  - other targets: scalar loop that read-modify-writes each matrix element
+//    with LoadByteOffset / StoreByteOffset.
+[require(cooperative_vector)]
+void coopVecOuterProductAccumulate<T : __BuiltinArithmeticType, let M : int, let N : int>(
+ CoopVec<T, M> a,
+ CoopVec<T, N> b,
+ $(buffer.type) matrix,
+ int32_t matrixOffset,
+ constexpr uint matrixStride,
+ constexpr CoopVecMatrixLayout memoryLayout,
+ constexpr CoopVecComponentType matrixInterpretation,
+)
+{
+ __target_switch
+ {
+ case spirv:
+ let matrixInterpretationSpirv : int = __getSpvCoopVecComponentType(matrixInterpretation);
+ let memoryLayoutSpirv : int = __getSpvCoopVecMatrixLayout(memoryLayout);
+ let matrixPtr = matrix.GetBufferPointer();
+ spirv_asm
+ {
+ OpCapability CooperativeVectorTrainingNV;
+ OpCooperativeVectorOuterProductAccumulateNV $matrixPtr $matrixOffset $a $b $memoryLayoutSpirv $matrixInterpretationSpirv $matrixStride;
+ };
+ default:
+ for (int i = 0; i < M; ++i)
+ {
+ for (int j = 0; j < N; ++j)
+ {
+ T product = a[i] * b[j];
+ // Element (i, j) is addressed as (row i, col j), swapped for column-major.
+ int row = (memoryLayout == CoopVecMatrixLayout::ColumnMajor) ? j : i;
+ int col = (memoryLayout == CoopVecMatrixLayout::ColumnMajor) ? i : j;
+ int offset = matrixOffset + (row * matrixStride + col * __coopVecComponentTypeStride(matrixInterpretation));
+
+ // Read the current element with its stored type, add the converted
+ // product, and write it back.
+ switch (matrixInterpretation)
+ {
+ case CoopVecComponentType::Float16:
+ matrix.StoreByteOffset<half>(offset, matrix.LoadByteOffset<half>(offset) + __arithmetic_cast<half>(product));
+ break;
+ case CoopVecComponentType::Float32:
+ matrix.StoreByteOffset<float>(offset, matrix.LoadByteOffset<float>(offset) + __arithmetic_cast<float>(product));
+ break;
+ case CoopVecComponentType::Float64:
+ matrix.StoreByteOffset<double>(offset, matrix.LoadByteOffset<double>(offset) + __arithmetic_cast<double>(product));
+ break;
+ case CoopVecComponentType::SignedInt8:
+ matrix.StoreByteOffset<int8_t>(offset, matrix.LoadByteOffset<int8_t>(offset) + __arithmetic_cast<int8_t>(product));
+ break;
+ case CoopVecComponentType::SignedInt16:
+ matrix.StoreByteOffset<int16_t>(offset, matrix.LoadByteOffset<int16_t>(offset) + __arithmetic_cast<int16_t>(product));
+ break;
+ case CoopVecComponentType::SignedInt32:
+ matrix.StoreByteOffset<int32_t>(offset, matrix.LoadByteOffset<int32_t>(offset) + __arithmetic_cast<int32_t>(product));
+ break;
+ case CoopVecComponentType::SignedInt64:
+ matrix.StoreByteOffset<int64_t>(offset, matrix.LoadByteOffset<int64_t>(offset) + __arithmetic_cast<int64_t>(product));
+ break;
+ case CoopVecComponentType::SignedInt8Packed:
+ matrix.StoreByteOffset<int8_t>(offset, matrix.LoadByteOffset<int8_t>(offset) + __arithmetic_cast<int8_t>(product));
+ break;
+ case CoopVecComponentType::UnsignedInt8:
+ matrix.StoreByteOffset<uint8_t>(offset, matrix.LoadByteOffset<uint8_t>(offset) + __arithmetic_cast<uint8_t>(product));
+ break;
+ case CoopVecComponentType::UnsignedInt16:
+ matrix.StoreByteOffset<uint16_t>(offset, matrix.LoadByteOffset<uint16_t>(offset) + __arithmetic_cast<uint16_t>(product));
+ break;
+ case CoopVecComponentType::UnsignedInt32:
+ matrix.StoreByteOffset<uint32_t>(offset, matrix.LoadByteOffset<uint32_t>(offset) + __arithmetic_cast<uint32_t>(product));
+ break;
+ case CoopVecComponentType::UnsignedInt64:
+ matrix.StoreByteOffset<uint64_t>(offset, matrix.LoadByteOffset<uint64_t>(offset) + __arithmetic_cast<uint64_t>(product));
+ break;
+ case CoopVecComponentType::UnsignedInt8Packed:
+ matrix.StoreByteOffset<uint8_t>(offset, matrix.LoadByteOffset<uint8_t>(offset) + __arithmetic_cast<uint8_t>(product));
+ break;
+ }
+ }
+ }
+ }
+}
+
+// Adds the elements of `v` into consecutive values of type T stored in a
+// writable byte-address buffer, starting `offset` bytes into it.
+//  - spirv: emits OpCooperativeVectorReduceSumAccumulateNV and declares the
+//    CooperativeVectorTrainingNV capability.
+//  - other targets: scalar loop that read-modify-writes one element per
+//    iteration, spaced by the natural stride of T.
+// NOTE(review): the fallback adds only this invocation's values; whether the
+// SPIR-V instruction additionally sums across invocations (as its name
+// suggests) should be confirmed against the extension specification.
+[require(cooperative_vector)]
+void coopVecReduceSumAccumulate<T : __BuiltinArithmeticType, let N : int>(
+ CoopVec<T, N> v,
+ $(buffer.type) buffer,
+ int32_t offset
+)
+{
+ __target_switch
+ {
+ case spirv:
+ let bufferPtr = buffer.GetBufferPointer();
+ spirv_asm
+ {
+ OpCapability CooperativeVectorTrainingNV;
+ OpCooperativeVectorReduceSumAccumulateNV $bufferPtr $offset $v;
+ };
+ default:
+ for (int i = 0; i < N; ++i)
+ {
+ // Byte offset of destination element i.
+ int byteOffset = offset + i * __naturalStrideOf<T>();
+ T currentValue = buffer.LoadByteOffset<T>(byteOffset);
+ T newValue = currentValue + __arithmetic_cast<T>(v[i]);
+ buffer.StoreByteOffset(byteOffset, newValue);
+ }
+ }
+}
+
+${{{{
+} // if rw
+} // buffer type loop
+}}}}
+
+${{{{
+// Generator-side table of the structured buffer types. The loop opened at the
+// end of this snippet repeats the declarations that follow (up to the closing
+// brace marked "buffer type loop") once per entry, with $(buffer.type)
+// substituted by the entry's type name. `IgnoredBufferElementType` in the type
+// names refers to a generic parameter that each of those declarations
+// introduces itself.
+static const struct {
+ bool isRW;
+ char const* type;
+} kStructuredBufferCases_[] =
+{
+ {true, "RWStructuredBuffer<IgnoredBufferElementType>"},
+ {false, "StructuredBuffer<IgnoredBufferElementType>"},
+};
+for(auto buffer : kStructuredBufferCases_) {
+}}}}
+
+// Structured-buffer variant of coopVecMatMulPacked. Only a SPIR-V
+// implementation is provided: it takes a pointer to the structured buffer and
+// emits OpCooperativeVectorMatrixMulNV. The buffer's element type is a
+// generic parameter that the body does not otherwise use.
+[require(spirv, cooperative_vector)]
+__generic<T : __BuiltinArithmeticType, let M : int, let PackedK : int, U : __BuiltinArithmeticType,IgnoredBufferElementType>
+CoopVec<T, M> coopVecMatMulPacked(
+ CoopVec<U, PackedK> input,
+ constexpr CoopVecComponentType inputInterpretation,
+ constexpr int k,
+ $(buffer.type) matrix,
+ int32_t matrixOffset,
+ constexpr CoopVecComponentType matrixInterpretation,
+ constexpr CoopVecMatrixLayout memoryLayout,
+ constexpr bool transpose,
+ constexpr uint matrixStride
+)
+{
+ // Same compile-time checks on k as the byte-address buffer variant.
+ static_assert(__isPackedInputInterpretation(inputInterpretation) || k == PackedK
+ , "for non-packed inputInterpretation values k must be equal to the input vector length");
+ static_assert(!__isPackedInputInterpretation(inputInterpretation)
+ || k <= __inputInterpretationPackingFactor(inputInterpretation)*PackedK
+ , "for packed inputInterpretation values k must be less than or equal to the input vector length times the packing factor");
+ __target_switch
+ {
+ case spirv:
+ let m : int32_t = M;
+ let matrixInterpretationSpirv : int32_t = __getSpvCoopVecComponentType(matrixInterpretation);
+ let inputInterpretationSpirv : int32_t = __getSpvCoopVecComponentType(inputInterpretation);
+ let memoryLayoutSpirv : int32_t = __getSpvCoopVecMatrixLayout(memoryLayout);
+ let matrixPtr = __getStructuredBufferPtr(matrix);
+ return spirv_asm
+ {
+ result:$$CoopVec<T, M> = OpCooperativeVectorMatrixMulNV $input $inputInterpretationSpirv $matrixPtr $matrixOffset $matrixInterpretationSpirv $m $k $memoryLayoutSpirv $transpose $matrixStride;
+ };
+ }
+}
+
+// specialized coopVecMatMul for non-packed inputs
+// Structured-buffer variant (SPIR-V only): rejects packed input
+// interpretations at compile time and forwards to the structured-buffer
+// coopVecMatMulPacked with k set to the input vector length K.
+[require(spirv, cooperative_vector)]
+__generic<T : __BuiltinArithmeticType, let M : int, let K : int, U : __BuiltinArithmeticType,IgnoredBufferElementType>
+CoopVec<T, M> coopVecMatMul(
+ CoopVec<U, K> input,
+ constexpr CoopVecComponentType inputInterpretation,
+ $(buffer.type) matrix,
+ int32_t matrixOffset,
+ constexpr CoopVecComponentType matrixInterpretation,
+ constexpr CoopVecMatrixLayout memoryLayout,
+ constexpr bool transpose,
+ constexpr uint matrixStride
+)
+{
+ static_assert(!__isPackedInputInterpretation(inputInterpretation)
+ , "for packed inputInterpretation values please use coopVecMatMulPacked and specify k manually");
+ return coopVecMatMulPacked<
+ T, M, K, U, IgnoredBufferElementType>(
+ input,
+ inputInterpretation,
+ K,
+ matrix,
+ matrixOffset,
+ matrixInterpretation,
+ memoryLayout,
+ transpose,
+ matrixStride);
+}
+
+// Structured-buffer variant of coopVecMatMulAddPacked. Only a SPIR-V
+// implementation is provided: it takes pointers to the matrix and bias
+// structured buffers and emits OpCooperativeVectorMatrixMulAddNV. The
+// buffers' element type is a generic parameter that the body does not
+// otherwise use.
+[require(spirv, cooperative_vector)]
+CoopVec<T, M> coopVecMatMulAddPacked<T : __BuiltinArithmeticType, let M : int, let PackedK : int, U : __BuiltinArithmeticType, IgnoredBufferElementType>(
+ CoopVec<U, PackedK> input,
+ constexpr CoopVecComponentType inputInterpretation,
+ constexpr int k,
+ $(buffer.type) matrix,
+ int32_t matrixOffset,
+ constexpr CoopVecComponentType matrixInterpretation,
+ $(buffer.type) bias,
+ int32_t biasOffset,
+ constexpr CoopVecComponentType biasInterpretation,
+ constexpr CoopVecMatrixLayout memoryLayout,
+ constexpr bool transpose,
+ constexpr uint matrixStride
+)
+{
+ // Same compile-time checks on k as the byte-address buffer variant.
+ static_assert(__isPackedInputInterpretation(inputInterpretation) || k == PackedK
+ , "for non-packed inputInterpretation values k must be equal to the input vector length");
+ static_assert(!__isPackedInputInterpretation(inputInterpretation)
+ || k <= __inputInterpretationPackingFactor(inputInterpretation)*PackedK
+ , "for packed inputInterpretation values k must be less than or equal to the input vector length times the packing factor");
+
+ __target_switch
+ {
+ case spirv:
+ let m : int32_t = M;
+ let matrixInterpretationSpirv : int32_t = __getSpvCoopVecComponentType(matrixInterpretation);
+ let biasInterpretationSpirv : int32_t = __getSpvCoopVecComponentType(biasInterpretation);
+ let inputInterpretationSpirv : int32_t = __getSpvCoopVecComponentType(inputInterpretation);
+ let memoryLayoutSpirv : int32_t = __getSpvCoopVecMatrixLayout(memoryLayout);
+ let matrixPtr = __getStructuredBufferPtr(matrix);
+ let biasPtr = __getStructuredBufferPtr(bias);
+ return spirv_asm
+ {
+ result:$$CoopVec<T, M> = OpCooperativeVectorMatrixMulAddNV $input $inputInterpretationSpirv $matrixPtr $matrixOffset $matrixInterpretationSpirv $biasPtr $biasOffset $biasInterpretationSpirv $m $k $memoryLayoutSpirv $transpose $matrixStride;
+ };
+ }
+}
+
+// Structured-buffer variant (SPIR-V only) of the non-packed matrix-vector
+// multiply with bias add: rejects packed input interpretations at compile time
+// and forwards to the structured-buffer coopVecMatMulAddPacked with k set to
+// the input vector length K.
+[require(spirv, cooperative_vector)]
+CoopVec<T, M> coopVecMatMulAdd<T : __BuiltinArithmeticType, let M : int, let K : int, U : __BuiltinArithmeticType, IgnoredBufferElementType>(
+ CoopVec<U, K> input,
+ constexpr CoopVecComponentType inputInterpretation,
+ $(buffer.type) matrix,
+ int32_t matrixOffset,
+ constexpr CoopVecComponentType matrixInterpretation,
+ $(buffer.type) bias,
+ int32_t biasOffset,
+ constexpr CoopVecComponentType biasInterpretation,
+ constexpr CoopVecMatrixLayout memoryLayout,
+ constexpr bool transpose,
+ constexpr uint matrixStride
+)
+{
+ // The diagnostic names coopVecMatMulAddPacked, the packed counterpart of
+ // this function (coopVecMatMulPacked would drop the bias).
+ static_assert(!__isPackedInputInterpretation(inputInterpretation)
+ , "for packed inputInterpretation values please use coopVecMatMulAddPacked and specify k manually");
+ return coopVecMatMulAddPacked<
+ T, M, K, U, IgnoredBufferElementType>(
+ input,
+ inputInterpretation,
+ K,
+ matrix,
+ matrixOffset,
+ matrixInterpretation,
+ bias,
+ biasOffset,
+ biasInterpretation,
+ memoryLayout,
+ transpose,
+ matrixStride);
+}
+
+//
+// Coop Vector accumulation
+//
+
+${{{{
+if(buffer.isRW)
+{
+}}}}
+// Structured-buffer variant of coopVecOuterProductAccumulate. Only a SPIR-V
+// implementation is provided: it takes a pointer to the structured buffer and
+// emits OpCooperativeVectorOuterProductAccumulateNV, declaring the
+// CooperativeVectorTrainingNV capability.
+[require(spirv, cooperative_vector_training)]
+void coopVecOuterProductAccumulate<T : __BuiltinArithmeticType, let M : int, let N : int, IgnoredBufferElementType>(
+ CoopVec<T, M> a,
+ CoopVec<T, N> b,
+ $(buffer.type) matrix,
+ int32_t matrixOffset,
+ constexpr uint matrixStride,
+ constexpr CoopVecMatrixLayout memoryLayout,
+ constexpr CoopVecComponentType matrixInterpretation,
+)
+{
+ __target_switch
+ {
+ case spirv:
+ let matrixInterpretationSpirv : int = __getSpvCoopVecComponentType(matrixInterpretation);
+ let memoryLayoutSpirv : int = __getSpvCoopVecMatrixLayout(memoryLayout);
+ let matrixPtr = __getStructuredBufferPtr(matrix);
+ spirv_asm
+ {
+ OpCapability CooperativeVectorTrainingNV;
+ OpCooperativeVectorOuterProductAccumulateNV $matrixPtr $matrixOffset $a $b $memoryLayoutSpirv $matrixInterpretationSpirv $matrixStride;
+ };
+ }
+}
+
+// Structured-buffer variant of coopVecReduceSumAccumulate. Only a SPIR-V
+// implementation is provided: it takes a pointer to the structured buffer and
+// emits OpCooperativeVectorReduceSumAccumulateNV, declaring the
+// CooperativeVectorTrainingNV capability.
+[require(spirv, cooperative_vector_training)]
+void coopVecReduceSumAccumulate<T : __BuiltinArithmeticType, let N : int, IgnoredBufferElementType>(
+ CoopVec<T, N> v,
+ $(buffer.type) buffer,
+ int32_t offset
+)
+{
+ __target_switch
+ {
+ case spirv:
+ let bufferPtr = __getStructuredBufferPtr(buffer);
+ spirv_asm
+ {
+ OpCapability CooperativeVectorTrainingNV;
+ OpCooperativeVectorReduceSumAccumulateNV $bufferPtr $offset $v;
+ };
+ }
+}
+
+${{{{
+} // if rw
+} // buffer type loop
+}}}}
+
+
+// Groupshared-array variant of coopVecOuterProductAccumulate. Only a SPIR-V
+// implementation is provided: it passes the address of the groupshared array
+// directly to OpCooperativeVectorOuterProductAccumulateNV, declaring the
+// CooperativeVectorTrainingNV capability. The array's element type U and size
+// are generic parameters that the body does not otherwise use.
+[require(spirv, cooperative_vector_training)]
+void coopVecOuterProductAccumulate<T : __BuiltinArithmeticType, let M : int, let N : int, U : __BuiltinArithmeticType, let IgnoredBufferSize : int>(
+ CoopVec<T, M> a,
+ CoopVec<T, N> b,
+ __ref groupshared U[IgnoredBufferSize] matrix,
+ int32_t matrixOffset,
+ constexpr uint matrixStride,
+ constexpr CoopVecMatrixLayout memoryLayout,
+ constexpr CoopVecComponentType matrixInterpretation,
+)
+{
+ __target_switch
+ {
+ case spirv:
+ let matrixInterpretationSpirv : int = __getSpvCoopVecComponentType(matrixInterpretation);
+ let memoryLayoutSpirv : int = __getSpvCoopVecMatrixLayout(memoryLayout);
+ spirv_asm
+ {
+ OpCapability CooperativeVectorTrainingNV;
+ OpCooperativeVectorOuterProductAccumulateNV &matrix $matrixOffset $a $b $memoryLayoutSpirv $matrixInterpretationSpirv $matrixStride;
+ };
+ }
+}
+
+// Groupshared-array variant of coopVecReduceSumAccumulate. Only a SPIR-V
+// implementation is provided: it passes the address of the groupshared array
+// directly to OpCooperativeVectorReduceSumAccumulateNV, declaring the
+// CooperativeVectorTrainingNV capability. The array's element type U and size
+// are generic parameters that the body does not otherwise use.
+[require(spirv, cooperative_vector_training)]
+void coopVecReduceSumAccumulate<T : __BuiltinArithmeticType, let N : int, U, let IgnoredBufferSize : int>(
+ CoopVec<T, N> v,
+ __ref groupshared U[IgnoredBufferSize] buffer,
+ int32_t offset
+)
+{
+ __target_switch
+ {
+ case spirv:
+ spirv_asm
+ {
+ OpCapability CooperativeVectorTrainingNV;
+ OpCooperativeVectorReduceSumAccumulateNV &buffer $offset $v;
+ };
+ }
+}
+
//@public:
/// Mark a variable as being workgroup uniform.
diff --git a/source/slang/slang-ast-support-types.h b/source/slang/slang-ast-support-types.h
index 56e07d430..2581630dd 100644
--- a/source/slang/slang-ast-support-types.h
+++ b/source/slang/slang-ast-support-types.h
@@ -170,6 +170,11 @@ enum : ConversionCost
kConversionCost_ScalarIntegerToFloatMatrix =
kConversionCost_IntegerToFloatConversion + kConversionCost_ScalarToMatrix,
+ // Additional conversion cost to add when promoting from a scalar to
+ // a CoopVector (this will be added to the cost, if any, of converting
+ // the element type of the CoopVector)
+ kConversionCost_ScalarToCoopVector = 1,
+
// Additional cost when casting an LValue.
kConversionCost_LValueCast = 800,
diff --git a/source/slang/slang-ast-type.cpp b/source/slang/slang-ast-type.cpp
index 23d89957a..29a52a93a 100644
--- a/source/slang/slang-ast-type.cpp
+++ b/source/slang/slang-ast-type.cpp
@@ -355,6 +355,29 @@ Type* AtomicType::getElementType()
return as<Type>(_getGenericTypeArg(this, 0));
}
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! CoopVectorExpressionType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
+
+Type* CoopVectorExpressionType::getElementType() // Element type of the cooperative vector.
+{
+ return as<Type>(_getGenericTypeArg(this, 0)); // Generic argument 0 holds the element type.
+}
+
+IntVal* CoopVectorExpressionType::getElementCount() // Element count of the cooperative vector, as an IntVal.
+{
+ return as<IntVal>(_getGenericTypeArg(this, 1)); // Generic argument 1 holds the element count.
+}
+
+void CoopVectorExpressionType::_toTextOverride(StringBuilder& out) // Appends the type's textual form to `out`.
+{
+ out << toSlice("CoopVector<") << getElementType() << toSlice(",") << getElementCount() // Written as CoopVector<element-type,element-count>, with no space after the comma.
+ << toSlice(">");
+}
+
+BasicExpressionType* CoopVectorExpressionType::_getScalarTypeOverride() // Scalar type of the vector: its element type viewed as a BasicExpressionType.
+{
+ return as<BasicExpressionType>(getElementType()); // NOTE(review): result depends on as<> when the element type is not a basic type — presumably null; confirm.
+}
+
// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! TypeType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
void TypeType::_toTextOverride(StringBuilder& out)
diff --git a/source/slang/slang-ast-type.h b/source/slang/slang-ast-type.h
index 9d4297919..f60c0485e 100644
--- a/source/slang/slang-ast-type.h
+++ b/source/slang/slang-ast-type.h
@@ -475,6 +475,17 @@ class AtomicType : public DeclRefType
Type* getElementType();
};
+class CoopVectorExpressionType : public ArithmeticExpressionType // AST type for a cooperative vector with an element type and an element count.
+{
+ SLANG_AST_CLASS(CoopVectorExpressionType)
+
+ void _toTextOverride(StringBuilder& out); // Appends "CoopVector<elementType,elementCount>" to `out`.
+ BasicExpressionType* _getScalarTypeOverride(); // Element type viewed as a BasicExpressionType.
+
+ Type* getElementType(); // Generic argument 0.
+ IntVal* getElementCount(); // Generic argument 1.
+};
+
// The "type" of an expression that resolves to a type.
// For example, in the expression `float(2)` the sub-expression,
// `float` would have the type `TypeType(float)`.
diff --git a/source/slang/slang-capabilities.capdef b/source/slang/slang-capabilities.capdef
index 3bc54c080..3be09b8d3 100644
--- a/source/slang/slang-capabilities.capdef
+++ b/source/slang/slang-capabilities.capdef
@@ -208,6 +208,7 @@ def _sm_6_4 : _sm_6_3;
def _sm_6_5 : _sm_6_4;
def _sm_6_6 : _sm_6_5;
def _sm_6_7 : _sm_6_6;
+def _sm_6_8 : _sm_6_7;
/// Represents HLSL NVAPI support.
/// [Version]
@@ -524,6 +525,14 @@ def SPV_NV_compute_shader_derivatives : _spirv_1_0;
/// [EXT]
def SPV_GOOGLE_user_type : _spirv_1_0;
+/// Represents the SPIR-V extension for SPV_EXT_replicated_composites.
+/// [EXT]
+def SPV_EXT_replicated_composites : _spirv_1_0;
+
+/// Represents the SPIR-V extension for SPV_NV_cooperative_vector.
+/// [EXT]
+def SPV_NV_cooperative_vector : _spirv_1_6 + SPV_EXT_replicated_composites;
+
// SPIRV Capabilities.
/// Represents the SPIR-V capability for atomic float 32 add operations.
@@ -662,6 +671,18 @@ def spvDemoteToHelperInvocationEXT : SPV_EXT_demote_to_helper_invocation;
/// [EXT]
def spvDemoteToHelperInvocation : spvDemoteToHelperInvocationEXT;
+/// Represents the SPIR-V capability for replicated composites
+/// [EXT]
+def spvReplicatedCompositesEXT : SPV_EXT_replicated_composites;
+
+/// Represents the SPIR-V capability for cooperative vectors
+/// [EXT]
+def spvCooperativeVectorNV : SPV_NV_cooperative_vector;
+
+/// Represents the SPIR-V capability for cooperative vector training
+/// [EXT]
+def spvCooperativeVectorTrainingNV : SPV_NV_cooperative_vector;
+
/// Represents the SPIR-V capability for maximal reconvergence.
/// [EXT]
def spvMaximalReconvergenceKHR : SPV_KHR_maximal_reconvergence;
@@ -1033,6 +1054,14 @@ alias bufferreference = GL_EXT_buffer_reference;
/// Capabilities needed to use GLSL buffer-reference's with int64
/// [Compound]
alias bufferreference_int64 = bufferreference + GL_EXT_shader_explicit_arithmetic_types_int64;
+/// Capabilities needed to use cooperative vectors
+/// Note that cpp and cuda are supported via a fallback non-cooperative implementation
+/// No HLSL shader model bound yet
+/// [Compound]
+alias cooperative_vector = _sm_6_8 | cpp | _cuda_sm_9_0 | spvCooperativeVectorNV;
+/// Capabilities needed to train cooperative vectors
+/// [Compound]
+alias cooperative_vector_training = spvCooperativeVectorTrainingNV;
// Non-internal shader stages
//
@@ -1479,6 +1508,25 @@ alias sm_6_7 = sm_6_7_version
| sm_6_6
;
+/// HLSL shader model 6.8 and related capabilities of other targets.
+/// Does not include related GLSL/SPIRV extensions.
+/// [Version]
+alias sm_6_8_version = _sm_6_8
+ | _GLSL_460
+ | spirv_1_5
+ | _cuda_sm_9_0
+ | metal
+ | cpp
+ ;
+
+/// HLSL shader model 6.8 and related capabilities of other targets.
+/// Includes related GLSL/SPIRV extensions.
+/// [Version]
+alias sm_6_8 = sm_6_8_version
+ | sm_6_7
+ ;
+
+// Profiles
/// Use `sm_4_0` instead
/// [Other]
alias DX_4_0 = sm_4_0;
@@ -1527,6 +1575,10 @@ alias DX_6_6 = sm_6_6;
/// [Other]
alias DX_6_7 = sm_6_7;
+/// Use `sm_6_8` instead
+/// [Other]
+alias DX_6_8 = sm_6_8;
+
// GLSL profile capabilities
//
/// GLSL 130 and related capabilities of other targets.
@@ -2042,10 +2094,10 @@ alias ser_motion_raygen = raygen + ser_motion;
/// User should not use this capability
/// [Other]
-alias all = _sm_6_7 + hlsl_nvapi
- | sm_6_7
+alias all = _sm_6_8 + hlsl_nvapi
+ | glsl_spirv_1_5 + sm_6_8
+ ser + shaderclock + texturefootprint + fragmentshaderinterlock + _GL_NV_shader_subgroup_partitioned
+ _GL_NV_ray_tracing_motion_blur + _GL_NV_shader_texture_footprint
- | spirv_1_5 + sm_6_7
+ | spirv_1_5 + sm_6_8
+ ser + shaderclock + texturefootprint + fragmentshaderinterlock + spvGroupNonUniformPartitionedNV
+ spvRayTracingMotionBlurNV + spvRayTracingMotionBlurNV;
diff --git a/source/slang/slang-check-conversion.cpp b/source/slang/slang-check-conversion.cpp
index 248c83fe5..db8548dce 100644
--- a/source/slang/slang-check-conversion.cpp
+++ b/source/slang/slang-check-conversion.cpp
@@ -34,6 +34,8 @@ BuiltinConversionKind SemanticsVisitor::getImplicitConversionBuiltinKind(Decl* d
bool SemanticsVisitor::isEffectivelyScalarForInitializerLists(Type* type)
{
+ if (as<CoopVectorExpressionType>(type))
+ return false;
if (as<ArrayExpressionType>(type))
return false;
if (as<VectorExpressionType>(type))
@@ -282,6 +284,50 @@ bool SemanticsVisitor::_readAggregateValueFromInitializerList(
}
}
}
+ else if (auto toCoopVectorType = as<CoopVectorExpressionType>(toType)) // Initializer list targeting a CoopVector: read one argument per element.
+ {
+ auto toElementCount = toCoopVectorType->getElementCount();
+ auto toElementType = toCoopVectorType->getElementType();
+
+ UInt elementCount = 0;
+ if (auto constElementCount = as<ConstantIntVal>(toElementCount)) // Only a compile-time constant count can be expanded.
+ {
+ elementCount = (UInt)constElementCount->getValue();
+ }
+ else
+ {
+ // We don't know the element count statically,
+ // so what are we supposed to be doing?
+ //
+ if (outToExpr) // Diagnose only when the caller asked for a result expression.
+ {
+ getSink()->diagnose(
+ fromInitializerListExpr,
+ Diagnostics::cannotUseInitializerListForCoopVectorOfUnknownSize,
+ toElementCount);
+ }
+ return false;
+ }
+
+ for (UInt ee = 0; ee < elementCount; ++ee) // Each iteration reads the next element via ioArgIndex.
+ {
+ Expr* coercedArg = nullptr;
+ bool argResult = _readValueFromInitializerList(
+ toElementType,
+ outToExpr ? &coercedArg : nullptr, // Request a coerced expression only when one is wanted.
+ fromInitializerListExpr,
+ ioArgIndex);
+
+ // No point in trying further if any argument fails
+ if (!argResult)
+ return false;
+
+ if (coercedArg) // Stays null when no output expression was requested.
+ {
+ coercedArgs.add(coercedArg);
+ }
+ }
+ }
else if (auto toArrayType = as<ArrayExpressionType>(toType))
{
// TODO(tfoley): If we can compute the size of the array statically,
diff --git a/source/slang/slang-compiler.cpp b/source/slang/slang-compiler.cpp
index 5ace72708..448534ce8 100644
--- a/source/slang/slang-compiler.cpp
+++ b/source/slang/slang-compiler.cpp
@@ -871,6 +871,7 @@ String GetHLSLProfileName(Profile profile)
CASE(DX_6_5, _6_5);
CASE(DX_6_6, _6_6);
CASE(DX_6_7, _6_7);
+ CASE(DX_6_8, _6_8);
#undef CASE
default:
diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h
index 345be7a54..682980107 100644
--- a/source/slang/slang-diagnostic-defs.h
+++ b/source/slang/slang-diagnostic-defs.h
@@ -1415,6 +1415,11 @@ DIAGNOSTIC(
Error,
cannotUseInitializerListForType,
"cannot use initializer list for type '$0'")
+DIAGNOSTIC(
+ 30505,
+ Error,
+ cannotUseInitializerListForCoopVectorOfUnknownSize,
+ "cannot use initializer list for CoopVector of statically unknown size '$0'")
// 3062x: variables
DIAGNOSTIC(
diff --git a/source/slang/slang-emit-c-like.cpp b/source/slang/slang-emit-c-like.cpp
index 180ef9909..db2c0150f 100644
--- a/source/slang/slang-emit-c-like.cpp
+++ b/source/slang/slang-emit-c-like.cpp
@@ -1469,6 +1469,8 @@ bool CLikeSourceEmitter::shouldFoldInstIntoUseSites(IRInst* inst)
case kIROp_MakeArray:
case kIROp_swizzleSet:
case kIROp_MakeArrayFromElement:
+ case kIROp_MakeCoopVector:
+
return false;
}
@@ -2362,6 +2364,7 @@ void CLikeSourceEmitter::defaultEmitInstExpr(IRInst* inst, const EmitOpInfo& inO
emitSimpleValue(inst);
break;
+ case kIROp_MakeCoopVector:
case kIROp_MakeVector:
case kIROp_MakeMatrix:
case kIROp_VectorReshape:
@@ -3096,6 +3099,51 @@ void CLikeSourceEmitter::_emitInst(IRInst* inst)
m_writer->advanceToSourceLocation(inst->sourceLoc);
}
+ if (auto coopVecType = as<IRCoopVectorType>(inst->getDataType())) // CoopVector results are declared first, then filled with member calls instead of `=`.
+ {
+ switch (inst->getOp())
+ {
+ case kIROp_MakeCoopVector: // Declare the variable, then write each element with WriteToIndex.
+ {
+ emitType(coopVecType, getName(inst));
+ m_writer->emit(";\n");
+
+ auto elemCount = as<IRIntLit>(coopVecType->getOperand(1)); // NOTE(review): assumes the element count is an integer literal; elemCount is dereferenced unchecked.
+ IRIntegerValue elemCountValue = elemCount->getValue();
+ for (IRIntegerValue i = 0; i < elemCountValue; ++i)
+ {
+ m_writer->emit(getName(inst));
+ m_writer->emit(".WriteToIndex(");
+ m_writer->emit(i);
+ m_writer->emit(", ");
+ emitDereferenceOperand(inst->getOperand(i), getInfo(EmitOp::General)); // NOTE(review): assumes the inst has one operand per element — confirm.
+ m_writer->emit(");\n");
+ }
+ return; // Fully emitted; skip the generic switch below.
+ }
+ case kIROp_Call: // Declare the variable, then CopyFrom the call result.
+ emitType(coopVecType, getName(inst));
+ m_writer->emit(";\n");
+
+ m_writer->emit(getName(inst));
+ m_writer->emit(".CopyFrom(");
+ emitCallExpr((IRCall*)inst, getInfo(EmitOp::General));
+ m_writer->emit(");\n");
+ return; // Fully emitted; skip the generic switch below.
+ case kIROp_Load: // Declare the variable, then CopyFrom the dereferenced source.
+ emitType(coopVecType, getName(inst));
+ m_writer->emit(";\n");
+
+ m_writer->emit(getName(inst));
+ m_writer->emit(".CopyFrom(");
+ emitDereferenceOperand(inst->getOperand(0), getInfo(EmitOp::General));
+ m_writer->emit(");\n");
+ return; // Fully emitted; skip the generic switch below.
+ default: // Any other opcode falls through to the generic emission below.
+ break;
+ }
+ }
+
switch (inst->getOp())
{
default:
@@ -3342,11 +3390,21 @@ void CLikeSourceEmitter::_emitStoreImpl(IRStore* store)
{
auto srcVal = store->getVal();
auto dstPtr = store->getPtr();
- auto prec = getInfo(EmitOp::Assign);
- emitDereferenceOperand(dstPtr, leftSide(getInfo(EmitOp::General), prec));
- m_writer->emit(" = ");
- emitOperand(srcVal, rightSide(prec, getInfo(EmitOp::General)));
- m_writer->emit(";\n");
+ if (isPointerOfType(dstPtr->getDataType(), kIROp_CoopVectorType)) // Stores into a CoopVector are emitted as `dst.CopyFrom(src);`.
+ {
+ emitDereferenceOperand(dstPtr, getInfo(EmitOp::General));
+ m_writer->emit(".CopyFrom(");
+ emitDereferenceOperand(srcVal, getInfo(EmitOp::General));
+ m_writer->emit(");\n");
+ }
+ else // Every other type keeps the plain `dst = src;` assignment.
+ {
+ auto prec = getInfo(EmitOp::Assign);
+ emitDereferenceOperand(dstPtr, leftSide(getInfo(EmitOp::General), prec));
+ m_writer->emit(" = ");
+ emitOperand(srcVal, rightSide(prec, getInfo(EmitOp::General)));
+ m_writer->emit(";\n");
+ }
}
void CLikeSourceEmitter::_emitInstAsDefaultInitializedVar(IRInst* inst, IRType* type)
@@ -4627,7 +4685,45 @@ void CLikeSourceEmitter::emitVar(IRVar* varDecl)
{
if (store->getPtr() == varDecl)
{
- _emitInstAsVarInitializerImpl(store->getVal());
+ const bool isCoopVectorType = varType->getOp() == kIROp_CoopVectorType; // CoopVector variables are initialized by member calls after the declaration.
+ if (isCoopVectorType && store->getVal()->getOp() == kIROp_Load) // Value is a load: copy from the loaded location.
+ {
+ m_writer->emit(";\n"); // Ends the declaration before the call; no ';' is emitted after the call in this branch.
+ m_writer->emit(getName(varDecl));
+ m_writer->emit(".CopyFrom(");
+ emitDereferenceOperand(store->getVal()->getOperand(0), getInfo(EmitOp::General));
+ m_writer->emit(")");
+ }
+ else if (isCoopVectorType && store->getVal()->getOp() == kIROp_Call) // Value is a call: copy from the call result.
+ {
+ m_writer->emit(";\n"); // Ends the declaration before the call.
+ m_writer->emit(getName(varDecl));
+ m_writer->emit(".CopyFrom(");
+ emitCallExpr((IRCall*)store->getVal(), getInfo(EmitOp::General));
+ m_writer->emit(")");
+ }
+ else if (isCoopVectorType && store->getVal()->getOp() == kIROp_MakeCoopVector) // Value is a constructor: write each element separately.
+ {
+ auto coopVecType = as<IRCoopVectorType>(store->getVal()->getDataType());
+ auto elemCount = as<IRIntLit>(coopVecType->getOperand(1)); // NOTE(review): assumes the element count is an integer literal; dereferenced unchecked.
+ IRIntegerValue elemCountValue = elemCount->getValue();
+ for (IRIntegerValue i = 0; i < elemCountValue; ++i)
+ {
+ m_writer->emit(";\n"); // Terminates the declaration or the previous WriteToIndex call.
+ m_writer->emit(getName(varDecl));
+ m_writer->emit(".WriteToIndex(");
+ m_writer->emit(i);
+ m_writer->emit(", ");
+ emitDereferenceOperand(
+ store->getVal()->getOperand(i),
+ getInfo(EmitOp::General));
+ m_writer->emit(")");
+ }
+ }
+ else // Non-CoopVector values, and other CoopVector opcodes, use the normal initializer path.
+ {
+ _emitInstAsVarInitializerImpl(store->getVal());
+ }
}
}
diff --git a/source/slang/slang-emit-glsl.cpp b/source/slang/slang-emit-glsl.cpp
index 61457d98e..1429139f9 100644
--- a/source/slang/slang-emit-glsl.cpp
+++ b/source/slang/slang-emit-glsl.cpp
@@ -3370,6 +3370,23 @@ void GLSLSourceEmitter::emitSimpleTypeImpl(IRType* type)
return;
}
+ else if (auto specializedType = as<IRSpecialize>(type)) // Emits `Base<arg0, arg1, ... >` for a specialized type.
+ {
+ // If a `specialize` instruction made it this far, then
+ // it represents an intrinsic generic type.
+ //
+ emitSimpleType((IRType*)getSpecializedValue(specializedType)); // The generic's base name comes first.
+ m_writer->emit("<");
+ UInt argCount = specializedType->getArgCount();
+ for (UInt ii = 0; ii < argCount; ++ii)
+ {
+ if (ii != 0) // Separator goes between arguments only.
+ m_writer->emit(", ");
+ emitVal(specializedType->getArg(ii), getInfo(EmitOp::General));
+ }
+ m_writer->emit(" >"); // A space is emitted before the closing bracket.
+ return;
+ }
auto decorated = getResolvedInstForDecorations(type);
UnownedStringSlice intrinsicDef;
diff --git a/source/slang/slang-emit-hlsl.cpp b/source/slang/slang-emit-hlsl.cpp
index 7bd3bb3db..53545b31f 100644
--- a/source/slang/slang-emit-hlsl.cpp
+++ b/source/slang/slang-emit-hlsl.cpp
@@ -765,6 +765,7 @@ bool HLSLSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOu
m_writer->emit("GroupMemoryBarrierWithGroupSync();\n"); // Must match the HLSL intrinsic name exactly; a misspelling produces HLSL that fails to compile.
return true;
}
+ case kIROp_MakeCoopVector:
case kIROp_MakeVector:
case kIROp_MakeMatrix:
{
@@ -1359,6 +1360,16 @@ void HLSLSourceEmitter::emitSimpleTypeImpl(IRType* type)
emitSimpleTypeImpl(cast<IRAtomicType>(type)->getElementType());
return;
}
+ case kIROp_CoopVectorType: // Emitted as `CoopVector<elementType,elementCount>`.
+ {
+ auto coopVecType = (IRCoopVectorType*)type;
+ m_writer->emit("CoopVector<");
+ emitType(coopVecType->getElementType());
+ m_writer->emit(","); // No space follows the comma in the emitted text.
+ m_writer->emit(getIntVal(coopVecType->getElementCount())); // Count is emitted as an integer value.
+ m_writer->emit(">");
+ return;
+ }
case kIROp_ConstRefType:
{
emitSimpleTypeImpl(as<IRConstRefType>(type)->getValueType());
diff --git a/source/slang/slang-emit-hlsl.h b/source/slang/slang-emit-hlsl.h
index 6b99a7f50..07c7af3d2 100644
--- a/source/slang/slang-emit-hlsl.h
+++ b/source/slang/slang-emit-hlsl.h
@@ -82,7 +82,7 @@ protected:
virtual void emitPostKeywordTypeAttributesImpl(IRInst* inst) SLANG_OVERRIDE;
- void _emitPrefixTypeAttr(IRAttr* attr) SLANG_OVERRIDE;
+ virtual void _emitPrefixTypeAttr(IRAttr* attr) SLANG_OVERRIDE;
// Emit a single `register` semantic, as appropriate for a given resource-type-specific layout
// info Keyword to use in the uniform case (`register` for globals, `packoffset` inside a
diff --git a/source/slang/slang-emit-spirv-ops.h b/source/slang/slang-emit-spirv-ops.h
index 1f2996646..8c5316f51 100644
--- a/source/slang/slang-emit-spirv-ops.h
+++ b/source/slang/slang-emit-spirv-ops.h
@@ -138,6 +138,19 @@ SpvInst* emitOpTypeVector(
componentCount);
}
+template<typename T1, typename T2> // T1: component type operand; T2: component count operand.
+SpvInst* emitOpTypeCoopVec(IRInst* inst, const T1& componentType, const T2& componentCount) // Emits an OpTypeCooperativeVectorNV for `inst`.
+{
+ static_assert(isSingular<T1>); // The component type must be a single operand.
+ return emitInstMemoized( // Emitted via the memoizing helper, keyed on `inst`.
+ getSection(SpvLogicalSectionID::ConstantsAndTypes), // Placed in the constants-and-types section.
+ inst,
+ SpvOpTypeCooperativeVectorNV,
+ kResultID,
+ componentType,
+ componentCount);
+}
+
// https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpTypeMatrix
template<typename T>
SpvInst* emitOpTypeMatrix(IRInst* inst, const T& columnType, const SpvLiteralInteger& columnCount)
diff --git a/source/slang/slang-emit-spirv.cpp b/source/slang/slang-emit-spirv.cpp
index 24d8cc0c6..2c36ae5f7 100644
--- a/source/slang/slang-emit-spirv.cpp
+++ b/source/slang/slang-emit-spirv.cpp
@@ -1628,6 +1628,16 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
static_cast<IRIntLit*>(vectorType->getElementCount())->getValue(),
vectorType);
}
+ case kIROp_CoopVectorType: // Cooperative vector type: require the capability and extension, then emit the type.
+ {
+ auto coopVecType = static_cast<IRCoopVectorType*>(inst);
+ requireSPIRVCapability(SpvCapabilityCooperativeVectorNV);
+ ensureExtensionDeclaration(UnownedStringSlice("SPV_NV_cooperative_vector"));
+ return ensureCoopVecType(
+ static_cast<IRBasicType*>(coopVecType->getElementType())->getBaseType(), // NOTE(review): unchecked cast — assumes the element type is a basic type.
+ static_cast<IRIntLit*>(coopVecType->getElementCount())->getValue(), // NOTE(review): unchecked cast — assumes the count is an integer literal.
+ coopVecType);
+ }
case kIROp_MatrixType:
{
auto matrixType = static_cast<IRMatrixType*>(inst);
@@ -1778,6 +1788,7 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
numElems->getValue());
}
case kIROp_MakeVector:
+ case kIROp_MakeCoopVector:
case kIROp_MakeArray:
case kIROp_MakeStruct:
return emitCompositeConstruct(getSection(SpvLogicalSectionID::ConstantsAndTypes), inst);
@@ -2361,6 +2372,27 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
return result;
}
+ /// Similar to ensureVectorType but for CoopVecType: emits the SPIR-V type for `inst`, creating the IR type from `baseType` and `elementCount` when `inst` is null.
+ SpvInst* ensureCoopVecType(
+ BaseType baseType, // Only read when `inst` is null.
+ IRIntegerValue elementCount, // Emitted as an int constant for the component-count operand.
+ IRCoopVectorType* inst) // Existing IR type, or null to have one created here.
+ {
+ IRBuilder builder(m_irModule);
+ if (!inst)
+ {
+ builder.setInsertInto(m_irModule->getModuleInst()); // The new type is inserted into the module inst.
+ inst = builder.getCoopVectorType(
+ builder.getBasicType(baseType),
+ builder.getIntValue(builder.getIntType(), elementCount));
+ }
+ auto result = emitOpTypeCoopVec(
+ inst,
+ inst->getElementType(),
+ emitIntConstant(elementCount, builder.getIntType())); // The count is passed as an emitted int constant.
+ return result;
+ }
+
bool _maybeEmitInterpolationModifierDecoration(IRInterpolationMode mode, SpvId varInst)
{
switch (mode)
@@ -3417,6 +3449,10 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
case kIROp_StructuredBufferGetDimensions:
result = emitStructuredBufferGetDimensions(parent, inst);
break;
+ case kIROp_GetStructuredBufferPtr:
+ case kIROp_GetUntypedBufferPtr:
+ result = emitGetBufferPtr(parent, inst);
+ break;
case kIROp_swizzle:
result = emitSwizzle(parent, as<IRSwizzle>(inst));
break;
@@ -3711,6 +3747,9 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
result = emitSplat(parent, inst, scalar, numElems->getValue());
}
break;
+ case kIROp_MakeCoopVector:
+ result = emitConstruct(parent, inst);
+ break;
case kIROp_MakeArray:
result = emitConstruct(parent, inst);
break;
@@ -6079,7 +6118,7 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
const auto baseTy = base->getDataType();
SLANG_ASSERT(
as<IRPointerLikeType>(baseTy) || as<IRArrayType>(baseTy) || as<IRVectorType>(baseTy) ||
- as<IRMatrixType>(baseTy));
+ as<IRCoopVectorType>(baseTy) || as<IRMatrixType>(baseTy));
IRBuilder builder(m_irModule);
builder.setInsertBefore(inst);
@@ -6097,7 +6136,7 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
}
else
{
- SLANG_ASSERT(as<IRVectorType>(baseTy));
+ SLANG_ASSERT(as<IRVectorType>(baseTy) || as<IRCoopVectorType>(baseTy));
// SPIRV Only allows dynamic element extract on vector types.
return emitOpVectorExtractDynamic(
parent,
@@ -6307,6 +6346,27 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
return result;
}
+ SpvInst* emitGetBufferPtr(SpvInstParent* parent, IRInst* inst) // Emits an OpAccessChain to the first field (an array) of the struct behind the buffer operand.
+ {
+ IRBuilder builder(inst);
+ auto addressSpace =
+ isSpirv14OrLater() ? AddressSpace::StorageBuffer : AddressSpace::Uniform; // Address space of the result pointer depends on the SPIR-V version.
+ // The buffer is a global parameter, so it's a pointer
+ IRPtrTypeBase* bufPtrType = cast<IRPtrTypeBase>(inst->getOperand(0)->getDataType());
+ // It's lowered to a struct type..
+ IRStructType* bufType = cast<IRStructType>(bufPtrType->getValueType());
+ // containing an unsized array, specifically one with an explicit
+ // stride, which is not expressible in spirv_asm blocks
+ IRArrayTypeBase* arrayType =
+ cast<IRArrayTypeBase>(bufType->getFields().getFirst()->getFieldType()); // The array is taken from the struct's first field.
+ return emitOpAccessChain(
+ parent,
+ inst,
+ builder.getPtrType(arrayType, addressSpace), // Result type: pointer to the array in the chosen address space.
+ inst->getOperand(0),
+ makeArray(emitIntConstant(0, builder.getIntType()))); // Index 0 selects the first struct field.
+ }
+
SpvInst* emitSwizzle(SpvInstParent* parent, IRSwizzle* inst)
{
if (inst->getElementCount() == 1)
@@ -6478,7 +6538,8 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
IRType* toType = nullptr;
bool isMatrixCast = false;
- if (as<IRVectorType>(fromTypeV) || as<IRVectorType>(toTypeV))
+ if (as<IRVectorType>(fromTypeV) || as<IRVectorType>(toTypeV) ||
+ as<IRCoopVectorType>(fromTypeV) || as<IRCoopVectorType>(toTypeV))
{
fromType = getVectorElementType(fromTypeV);
toType = getVectorElementType(toTypeV);
diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp
index cf4070df4..d8594ffdd 100644
--- a/source/slang/slang-emit.cpp
+++ b/source/slang/slang-emit.cpp
@@ -66,6 +66,7 @@
#include "slang-ir-lower-bit-cast.h"
#include "slang-ir-lower-buffer-element-type.h"
#include "slang-ir-lower-combined-texture-sampler.h"
+#include "slang-ir-lower-coopvec.h"
#include "slang-ir-lower-dynamic-resource-heap.h"
#include "slang-ir-lower-generics.h"
#include "slang-ir-lower-glsl-ssbo-types.h"
@@ -1016,6 +1017,16 @@ Result linkAndOptimizeIR(
#endif
validateIRModuleIfEnabled(codeGenContext, irModule);
+ switch (target) // Selects which targets have their cooperative vectors lowered.
+ {
+ case CodeGenTarget::SPIRV:
+ case CodeGenTarget::SPIRVAssembly:
+ case CodeGenTarget::HLSL:
+ break; // These three targets skip the lowering pass.
+ default: // Every other target runs the lowering pass.
+ lowerCooperativeVectors(irModule, sink);
+ }
+
// Inline calls to any functions marked with [__unsafeInlineEarly] or [ForceInline].
performForceInlining(irModule);
@@ -1476,14 +1487,15 @@ Result linkAndOptimizeIR(
{
default:
break;
+ case CodeGenTarget::HLSL:
case CodeGenTarget::GLSL:
case CodeGenTarget::WGSL:
- moveGlobalVarInitializationToEntryPoints(irModule);
+ moveGlobalVarInitializationToEntryPoints(irModule, targetProgram);
break;
// For SPIR-V to SROA across 2 entry-points a value must not be a global
case CodeGenTarget::SPIRV:
case CodeGenTarget::SPIRVAssembly:
- moveGlobalVarInitializationToEntryPoints(irModule);
+ moveGlobalVarInitializationToEntryPoints(irModule, targetProgram);
if (targetProgram->getOptionSet().getBoolOption(
CompilerOptionName::EnableExperimentalPasses))
introduceExplicitGlobalContext(irModule, target);
@@ -1494,7 +1506,7 @@ Result linkAndOptimizeIR(
case CodeGenTarget::Metal:
case CodeGenTarget::CPPSource:
case CodeGenTarget::CUDASource:
- moveGlobalVarInitializationToEntryPoints(irModule);
+ moveGlobalVarInitializationToEntryPoints(irModule, targetProgram);
introduceExplicitGlobalContext(irModule, target);
if (target == CodeGenTarget::CPPSource)
{
diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp
index 7507e2fac..d65a22e77 100644
--- a/source/slang/slang-ir-autodiff.cpp
+++ b/source/slang/slang-ir-autodiff.cpp
@@ -2619,6 +2619,7 @@ bool canTypeBeStored(IRInst* type)
case kIROp_ClassType:
case kIROp_FloatType:
case kIROp_VectorType:
+ case kIROp_CoopVectorType:
case kIROp_MatrixType:
case kIROp_BackwardDiffIntermediateContextType:
return true;
diff --git a/source/slang/slang-ir-constexpr.cpp b/source/slang/slang-ir-constexpr.cpp
index 6c297fd30..ff6d64319 100644
--- a/source/slang/slang-ir-constexpr.cpp
+++ b/source/slang/slang-ir-constexpr.cpp
@@ -104,6 +104,7 @@ bool opCanBeConstExpr(IROp op)
case kIROp_MakeMatrix:
case kIROp_MakeMatrixFromScalar:
case kIROp_MatrixReshape:
+ case kIROp_MakeCoopVector:
case kIROp_VectorReshape:
case kIROp_CastFloatToInt:
case kIROp_CastIntToFloat:
diff --git a/source/slang/slang-ir-explicit-global-init.cpp b/source/slang/slang-ir-explicit-global-init.cpp
index 97a04f48a..8945f0c5f 100644
--- a/source/slang/slang-ir-explicit-global-init.cpp
+++ b/source/slang/slang-ir-explicit-global-init.cpp
@@ -39,6 +39,7 @@ namespace Slang
struct MoveGlobalVarInitializationToEntryPointsPass
{
IRModule* m_module;
+ TargetProgram* m_targetProgram;
// In the Slang IR, a global variable represents a pointer
// to the storage for the variable but it *also* encodes
@@ -66,9 +67,10 @@ struct MoveGlobalVarInitializationToEntryPointsPass
};
List<GlobalVarInfo> m_globalVarsWithInit;
- void processModule(IRModule* module)
+ void processModule(IRModule* module, TargetProgram* targetProgram)
{
m_module = module;
+ m_targetProgram = targetProgram;
// We start by looking for global variables with
// initialization logic in the IR, and processing
@@ -113,8 +115,32 @@ struct MoveGlobalVarInitializationToEntryPointsPass
}
}
+ bool shouldMoveGlobalVarInitialization(IRGlobalVar* globalVar) // True when this global's initializer should be moved into the entry points.
+ {
+ // Currently CoopVector for DXC cannot be created from
+ // constructors with arguments. When CoopVector is used as a
+ // global variable, its initialization has to happen at the
+ // beginning of the entry point.
+ //
+ // At the same time, we don't want to apply
+ // "moveGlobalVarInitializationToEntryPoints" to the rest of
+ // the global variables when targeting HLSL.
+ //
+ if (isD3DTarget(m_targetProgram->getTargetReq())) // D3D targets: move only CoopVector globals.
+ {
+ auto valueType = globalVar->getDataType()->getValueType(); // Type stored in the global, not the pointer type.
+ if (as<IRCoopVectorType>(valueType))
+ return true;
+ return false;
+ }
+ return true; // Non-D3D targets move every global initializer.
+ }
+
void processGlobalVarWithInit(IRGlobalVar* globalVar, IRBlock* firstBlock)
{
+ if (!shouldMoveGlobalVarInitialization(globalVar))
+ return;
+
IRBuilder builder(m_module);
builder.setInsertBefore(globalVar);
@@ -216,10 +242,10 @@ struct MoveGlobalVarInitializationToEntryPointsPass
};
/// Move initialization logic off of global variables and onto each entry point
-void moveGlobalVarInitializationToEntryPoints(IRModule* module)
+void moveGlobalVarInitializationToEntryPoints(IRModule* module, TargetProgram* targetProgram) // targetProgram is forwarded to the pass.
{
MoveGlobalVarInitializationToEntryPointsPass pass;
- pass.processModule(module);
+ pass.processModule(module, targetProgram); // The pass stores both arguments before scanning the module.
}
} // namespace Slang
diff --git a/source/slang/slang-ir-explicit-global-init.h b/source/slang/slang-ir-explicit-global-init.h
index 3244d90a2..8554a7e3b 100644
--- a/source/slang/slang-ir-explicit-global-init.h
+++ b/source/slang/slang-ir-explicit-global-init.h
@@ -4,7 +4,8 @@
namespace Slang
{
struct IRModule;
+class TargetProgram;
/// Move initialization logic off of global variables and onto each entry point
-void moveGlobalVarInitializationToEntryPoints(IRModule* module);
+void moveGlobalVarInitializationToEntryPoints(IRModule* module, TargetProgram* targetProgram);
} // namespace Slang
diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h
index 591ffabf0..f1e9624f3 100644
--- a/source/slang/slang-ir-inst-defs.h
+++ b/source/slang/slang-ir-inst-defs.h
@@ -234,6 +234,7 @@ INST(Nop, nop, 0, 0)
INST(RayQueryType, RayQuery, 1, HOISTABLE)
INST(HitObjectType, HitObject, 0, HOISTABLE)
+INST(CoopVectorType, CoopVectorType, 2, HOISTABLE)
// Opaque type that can be dynamically cast to other resource types.
INST(DynamicResourceType, DynamicResource, 0, HOISTABLE)
@@ -363,6 +364,8 @@ INST(MatrixReshape, matrixReshape, 1, 0)
INST(VectorReshape, vectorReshape, 1, 0)
INST(MakeArray, makeArray, 0, 0)
INST(MakeArrayFromElement, makeArrayFromElement, 1, 0)
+INST(MakeCoopVector, makeCoopVector, 0, 0)
+INST(MakeCoopVectorFromValuePack, makeCoopVectorFromValuePack, 1, 0)
INST(MakeStruct, makeStruct, 0, 0)
INST(MakeTuple, makeTuple, 0, 0)
INST(MakeTargetTuple, makeTuple, 0, 0)
@@ -1250,6 +1253,11 @@ INST(CudaKernelLaunch, CudaKernelLaunch, 6, 0)
// Converts other resources (such as ByteAddressBuffer) to the equivalent StructuredBuffer
INST(GetEquivalentStructuredBuffer, getEquivalentStructuredBuffer, 1, 0)
+// Gets a T[] pointer to the underlying data of a StructuredBuffer etc...
+INST(GetStructuredBufferPtr, getStructuredBufferPtr, 1, 0)
+// Gets a uint[] pointer to the underlying data of a ByteAddressBuffer etc...
+INST(GetUntypedBufferPtr, getUntypedBufferPtr, 1, 0)
+
/* Layout */
INST(VarLayout, varLayout, 1, HOISTABLE)
diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h
index caee45ba8..a883172ff 100644
--- a/source/slang/slang-ir-insts.h
+++ b/source/slang/slang-ir-insts.h
@@ -3018,6 +3018,11 @@ struct IRMakeVectorFromScalar : IRInst
IR_LEAF_ISA(MakeVectorFromScalar)
};
+struct IRMakeCoopVector : IRInst // IR instruction wrapper for the MakeCoopVector opcode.
+{
+ IR_LEAF_ISA(MakeCoopVector)
+};
+
// An Instruction that creates a differential pair value from a
// primal and differential.
@@ -3735,6 +3740,8 @@ public:
IRVectorType* getVectorType(IRType* elementType, IRIntegerValue elementCount);
+ IRCoopVectorType* getCoopVectorType(IRType* elementType, IRInst* elementCount);
+
IRMatrixType* getMatrixType(
IRType* elementType,
IRInst* rowCount,
@@ -4056,6 +4063,9 @@ public:
IRInst* emitGetTupleElement(IRType* type, IRInst* tuple, UInt element);
IRInst* emitGetTupleElement(IRType* type, IRInst* tuple, IRInst* element);
+ IRInst* emitGetElement(IRType* type, IRInst* arrayLikeType, IRIntegerValue element);
+ IRInst* emitGetElementPtr(IRType* type, IRInst* arrayLikeType, IRIntegerValue element);
+
IRInst* emitMakeResultError(IRType* resultType, IRInst* errorVal);
IRInst* emitMakeResultValue(IRType* resultType, IRInst* val);
IRInst* emitIsResultError(IRInst* result);
@@ -4094,6 +4104,8 @@ public:
IRInst* emitMakeMatrixFromScalar(IRType* type, IRInst* scalarValue);
+ IRInst* emitMakeCoopVector(IRType* type, UInt argCount, IRInst* const* args);
+
IRInst* emitMakeArray(IRType* type, UInt argCount, IRInst* const* args);
IRInst* emitMakeArrayList(IRType* type, UInt argCount, IRInst* const* args);
diff --git a/source/slang/slang-ir-lower-coopvec.cpp b/source/slang/slang-ir-lower-coopvec.cpp
new file mode 100644
index 000000000..c8b24597c
--- /dev/null
+++ b/source/slang/slang-ir-lower-coopvec.cpp
@@ -0,0 +1,234 @@
+#include "slang-ir-lower-coopvec.h"
+
+#include "slang-ir-insts.h"
+#include "slang-ir.h"
+
+namespace Slang
+{
+struct CoopVecLoweringContext
+{
+ IRModule* module;
+ DiagnosticSink* sink;
+
+ InstWorkList workList;
+ InstHashSet workListSet;
+
+ // Binds the context to `inModule` and constructs the work list and its
+ // membership set from the same module. `sink` is not initialized here; it
+ // is assigned by `lowerCooperativeVectors` after construction.
+ CoopVecLoweringContext(IRModule* inModule)
+ : module(inModule), workList(inModule), workListSet(inModule)
+ {
+ }
+
+ // Record pairing a cooperative-vector type with the array type it is
+ // lowered to. Instances are created in `getLoweredCoopVecType`.
+ struct LoweredCoopVecInfo : public RefObject
+ {
+ // The original cooperative-vector type.
+ IRType* coopvecType;
+ // The array type (same element type and element count) that replaces it.
+ IRArrayType* arrayType;
+ };
+ // Lookup keyed by the lowered array type.
+ Dictionary<IRInst*, RefPtr<LoweredCoopVecInfo>> mapLoweredArrayToCoopVecInfo;
+ // Lookup keyed by the original cooperative-vector type; `processModule`
+ // iterates this map to replace every use of each key with its array type.
+ Dictionary<IRInst*, RefPtr<LoweredCoopVecInfo>> loweredCoopVecs;
+
+    /// Returns the lowered array type when `type` is a cooperative-vector type
+    /// for which a lowering record is available; otherwise returns `type`
+    /// unchanged.
+    IRType* maybeLowerCoopVecType(IRBuilder* builder, IRType* type)
+    {
+        const auto coopVecType = as<IRCoopVectorType>(type);
+        if (!coopVecType)
+            return type;
+
+        auto loweredInfo = getLoweredCoopVecType(builder, coopVecType);
+        if (!loweredInfo)
+            return type;
+
+        return loweredInfo->arrayType;
+    }
+
+    /// Looks up, or creates on first request, the lowering record for `type`.
+    ///
+    /// Both dictionaries are consulted first, so a key already present in
+    /// either one returns its existing record. A null `type` yields `nullptr`.
+    /// Otherwise a new record is built whose array type has the same element
+    /// type and element count as `type`, tagged with the name hint "CoopVec",
+    /// and the record is registered in both dictionaries.
+    LoweredCoopVecInfo* getLoweredCoopVecType(IRBuilder* builder, IRCoopVectorType* type)
+    {
+        if (auto cached = loweredCoopVecs.tryGetValue(type))
+            return cached->Ptr();
+        if (auto cached = mapLoweredArrayToCoopVecInfo.tryGetValue(type))
+            return cached->Ptr();
+
+        if (type == nullptr)
+            return nullptr;
+
+        RefPtr<LoweredCoopVecInfo> record = new LoweredCoopVecInfo();
+        record->coopvecType = (IRType*)type;
+
+        // The replacement is an array with identical element type and count.
+        auto elementType = type->getElementType();
+        auto elementCount = type->getElementCount();
+        record->arrayType = builder->getArrayType(elementType, elementCount);
+        builder->addNameHintDecoration(record->arrayType, UnownedStringSlice("CoopVec"));
+
+        // Register under both the array key and the cooperative-vector key.
+        mapLoweredArrayToCoopVecInfo[record->arrayType] = record;
+        loweredCoopVecs[type] = record;
+        return record.Ptr();
+    }
+
+    /// Queues `inst` for processing unless one of its ancestors is an
+    /// `IRGeneric` or it is already queued.
+    void addToWorkList(IRInst* inst)
+    {
+        // Anything nested (at any depth) inside a generic is left alone.
+        auto ancestor = inst->getParent();
+        while (ancestor)
+        {
+            if (as<IRGeneric>(ancestor))
+                return;
+            ancestor = ancestor->getParent();
+        }
+
+        // The set mirrors the list so membership can be checked directly.
+        if (!workListSet.contains(inst))
+        {
+            workList.add(inst);
+            workListSet.add(inst);
+        }
+    }
+
+    /// Replaces a `MakeCoopVector` instruction with a `MakeArray` of the
+    /// lowered array type built from the same operands, then removes the
+    /// original instruction.
+    void processMakeCoopVec(IRInst* inst)
+    {
+        IRBuilder builder(module);
+        builder.setInsertBefore(inst);
+
+        const auto coopVecType = as<IRCoopVectorType>(inst->getDataType());
+        SLANG_ASSERT(coopVecType);
+        auto loweredInfo = getLoweredCoopVecType(&builder, coopVecType);
+
+        // Copy the operands, in order, into a flat argument list.
+        const UInt operandCount = inst->getOperandCount();
+        List<IRInst*> elements;
+        elements.setCount(Index(operandCount));
+        for (UInt i = 0; i < operandCount; ++i)
+            elements[Index(i)] = inst->getOperand(i);
+
+        IRInst* loweredValue =
+            builder.emitMakeArray(loweredInfo->arrayType, elements.getCount(), elements.begin());
+        inst->replaceUsesWith(loweredValue);
+        inst->removeAndDeallocate();
+    }
+
+ // No-op: the body is empty. This handler is not referenced by `processInst`.
+ void processGetCoopVecElement(IRGetElement*) {}
+
+ // No-op: the body is empty, so `GetElementPtr` instructions are left as-is
+ // by this handler.
+ void processGetElementPtr(IRGetElementPtr*) {}
+
+ // No-op: the body is empty, so `GetElement` instructions are left as-is by
+ // this handler.
+ void processGetElement(IRGetElement*) {}
+
+    /// Ensures a lowering record exists for the cooperative-vector type
+    /// `inst`. The type itself is not rewritten here; uses are replaced at the
+    /// end of `processModule`.
+    void processCoopVecType(IRCoopVectorType* inst)
+    {
+        IRBuilder builder(module);
+        builder.setInsertBefore(inst);
+
+        LoweredCoopVecInfo* info = getLoweredCoopVecType(&builder, inst);
+        SLANG_ASSERT(info);
+        SLANG_UNUSED(info);
+    }
+
+ // No-op: the body is empty, so `UpdateElement` instructions are left as-is
+ // by this handler.
+ void processUpdateElement(IRUpdateElement*) {}
+
+    /// Expands an instruction that produces a cooperative vector into one
+    /// instruction of the same opcode per element.
+    ///
+    /// A temporary variable of the instruction's full type is created. For
+    /// each element index, every cooperative-vector operand contributes its
+    /// element at that index (other operands are passed through unchanged),
+    /// the opcode is re-emitted on those values, and the result is stored
+    /// through an element pointer into the temporary. The temporary is finally
+    /// loaded, the load replaces all uses of `inst`, and `inst` is removed.
+    ///
+    /// Nothing is done when the result is not a cooperative vector or when no
+    /// operand is a cooperative vector with a non-zero element count.
+    void processEntrywiseOp(IRInst* inst)
+    {
+        SLANG_ASSERT(inst->getOperandCount());
+
+        const auto resultVecType = as<IRCoopVectorType>(inst->getDataType());
+        if (!resultVecType)
+            return;
+
+        // The per-element result type must come from the *result* vector.
+        // This function also handles the cast opcodes, where the operand's
+        // element type differs from the result's; using the operand's type
+        // would give the scalar instruction and the element pointer the
+        // source element type instead of the destination one.
+        IRType* resultElementType = resultVecType->getElementType();
+
+        // Gather the operands and read the width from the cooperative-vector
+        // operands (the last one seen wins; the loop below asserts they agree).
+        // NOTE(review): the count is read with getIntVal, so it presumably has
+        // to be an integer literal at this point - confirm.
+        List<IRInst*> operands;
+        IRIntegerValue width = 0;
+        const UInt operandCount = inst->getOperandCount();
+        for (UInt opIndex = 0; opIndex < operandCount; ++opIndex)
+        {
+            IRInst* operand = inst->getOperand(opIndex);
+            operands.add(operand);
+            if (const auto cv = as<IRCoopVectorType>(operand->getDataType()))
+                width = getIntVal(cv->getElementCount());
+        }
+        if (width == 0)
+            return;
+
+        IRBuilder builder(module);
+        IRType* resultElementPtrType = builder.getPtrType(resultElementType);
+        builder.setInsertBefore(inst);
+        const auto result = builder.emitVar(inst->getFullType());
+
+        List<IRInst*> entrywiseOperands;
+        entrywiseOperands.setCount(operands.getCount());
+        for (IRIntegerValue i = 0; i < width; ++i)
+        {
+            for (Index j = 0; j < operands.getCount(); ++j)
+            {
+                if (const auto cv = as<IRCoopVectorType>(operands[j]->getDataType()))
+                {
+                    SLANG_ASSERT(getIntVal(cv->getElementCount()) == width);
+                    const auto elementType = cv->getElementType();
+                    entrywiseOperands[j] = builder.emitGetElement(elementType, operands[j], i);
+                }
+                else
+                {
+                    // Non-vector operands are shared by every element.
+                    entrywiseOperands[j] = operands[j];
+                }
+            }
+            const auto elementResult = builder.emitIntrinsicInst(
+                resultElementType,
+                inst->getOp(),
+                entrywiseOperands.getCount(),
+                entrywiseOperands.begin());
+            const auto elementPtr = builder.emitGetElementPtr(resultElementPtrType, result, i);
+            builder.emitStore(elementPtr, elementResult);
+        }
+
+        const auto loaded = builder.emitLoad(result);
+        inst->replaceUsesWith(loaded);
+        inst->removeAndDeallocate();
+    }
+
+    /// Routes `inst` to the handler for its opcode. Opcodes not listed here
+    /// are ignored.
+    void processInst(IRInst* inst)
+    {
+        const auto op = inst->getOp();
+        switch (op)
+        {
+        // Type and construction.
+        case kIROp_CoopVectorType:
+            processCoopVecType(static_cast<IRCoopVectorType*>(inst));
+            return;
+        case kIROp_MakeCoopVector:
+            processMakeCoopVec(inst);
+            return;
+
+        // Element access and update.
+        case kIROp_GetElement:
+            processGetElement(static_cast<IRGetElement*>(inst));
+            return;
+        case kIROp_GetElementPtr:
+            processGetElementPtr(static_cast<IRGetElementPtr*>(inst));
+            return;
+        case kIROp_UpdateElement:
+            processUpdateElement(static_cast<IRUpdateElement*>(inst));
+            return;
+
+        // Arithmetic and conversions expanded element by element.
+        case kIROp_Neg:
+        case kIROp_Add:
+        case kIROp_Sub:
+        case kIROp_Mul:
+        case kIROp_Div:
+        case kIROp_IntCast:
+        case kIROp_FloatCast:
+        case kIROp_CastIntToFloat:
+        case kIROp_CastFloatToInt:
+            processEntrywiseOp(inst);
+            return;
+
+        default:
+            return;
+        }
+    }
+
+    /// Runs the lowering over the whole module.
+    ///
+    /// Starting from the module instruction, instructions are popped from the
+    /// back of the work list, processed, and their children queued. Once the
+    /// list is empty, every use of each recorded cooperative-vector type is
+    /// replaced with its lowered array type.
+    void processModule()
+    {
+        addToWorkList(module->getModuleInst());
+
+        while (workList.getCount() != 0)
+        {
+            IRInst* inst = workList.getLast();
+
+            workList.removeLast();
+            workListSet.remove(inst);
+
+            processInst(inst);
+
+            // Children are queued last-to-first; because the list is popped
+            // from the back, the first child is processed first.
+            // NOTE(review): processMakeCoopVec and processEntrywiseOp call
+            // removeAndDeallocate() on `inst`, and `inst` is read here
+            // afterwards - confirm this is safe for removed instructions.
+            for (auto child = inst->getLastChild(); child; child = child->getPrevInst())
+                addToWorkList(child);
+        }
+
+        // Replace all coopvec types with sized array types
+        for (const auto& [key, value] : loweredCoopVecs)
+            key->replaceUsesWith(value->arrayType);
+    }
+};
+
+/// Entry point of the pass: builds a lowering context for `module`, hands it
+/// `sink`, and processes the module.
+void lowerCooperativeVectors(IRModule* module, DiagnosticSink* sink)
+{
+    CoopVecLoweringContext lowering(module);
+    lowering.sink = sink;
+    lowering.processModule();
+}
+} // namespace Slang
diff --git a/source/slang/slang-ir-lower-coopvec.h b/source/slang/slang-ir-lower-coopvec.h
new file mode 100644
index 000000000..60b8943d5
--- /dev/null
+++ b/source/slang/slang-ir-lower-coopvec.h
@@ -0,0 +1,12 @@
+#pragma once
+
+namespace Slang
+{
+
+struct IRModule;
+class DiagnosticSink;
+
+/// Lower Cooperative Vectors to ordinary arrays.
+///
+/// @param module The IR module to transform in place.
+/// @param sink   Diagnostic sink handed to the lowering context.
+void lowerCooperativeVectors(IRModule* module, DiagnosticSink* sink);
+
+} // namespace Slang
diff --git a/source/slang/slang-ir-peephole.cpp b/source/slang/slang-ir-peephole.cpp
index ced0b5eb0..fc399954b 100644
--- a/source/slang/slang-ir-peephole.cpp
+++ b/source/slang/slang-ir-peephole.cpp
@@ -353,6 +353,31 @@ struct PeepholeContext : InstPassBase
break;
}
break;
+ case kIROp_MakeCoopVectorFromValuePack:
+ {
+ const auto pack = inst->getOperand(0);
+ if (const auto packType = as<IRTypePack>(pack->getDataType()))
+ {
+ IRBuilder builder(inst);
+ builder.setInsertBefore(inst);
+ List<IRInst*> args;
+ for (UInt j = 0; j < packType->getOperandCount(); ++j)
+ {
+ const auto e = builder.emitGetTupleElement(
+ cast<IRType>(packType->getOperand(j)),
+ pack,
+ j);
+ args.add(e);
+ }
+ const auto cvt = builder.getCoopVectorType(
+ args[0]->getDataType(),
+ builder.getIntValue(builder.getIntType(), args.getCount()));
+ const auto v = builder.emitMakeCoopVector(cvt, args.getCount(), args.begin());
+ inst->replaceUsesWith(v);
+ inst->removeAndDeallocate();
+ }
+ }
+ break;
case kIROp_FieldExtract:
if (inst->getOperand(0)->getOp() == kIROp_MakeStruct)
{
@@ -1091,7 +1116,12 @@ struct PeepholeContext : InstPassBase
break;
auto type = inst->getOperand(0)->getDataType();
IRSizeAndAlignment sizeAlignment;
- getNaturalSizeAndAlignment(targetProgram->getOptionSet(), type, &sizeAlignment);
+ const auto res = getNaturalSizeAndAlignment(
+ targetProgram->getOptionSet(),
+ type,
+ &sizeAlignment);
+ if (!SLANG_SUCCEEDED(res))
+ break;
IRBuilder builder(module);
builder.setInsertBefore(inst);
auto stride =
diff --git a/source/slang/slang-ir-util.cpp b/source/slang/slang-ir-util.cpp
index 3043e7e31..dd7819e1a 100644
--- a/source/slang/slang-ir-util.cpp
+++ b/source/slang/slang-ir-util.cpp
@@ -21,6 +21,8 @@ IRType* getVectorElementType(IRType* type)
{
if (auto vectorType = as<IRVectorType>(type))
return vectorType->getElementType();
+ if (auto coopVecType = as<IRCoopVectorType>(type))
+ return coopVecType->getElementType();
return type;
}
@@ -1388,6 +1390,7 @@ bool isZero(IRInst* inst)
return as<IRFloatLit>(inst)->getValue() == 0.0;
case kIROp_BoolLit:
return as<IRBoolLit>(inst)->getValue() == false;
+ case kIROp_MakeCoopVector:
case kIROp_MakeVector:
case kIROp_MakeVectorFromScalar:
case kIROp_MakeMatrix:
@@ -1422,6 +1425,7 @@ bool isOne(IRInst* inst)
return as<IRFloatLit>(inst)->getValue() == 1.0;
case kIROp_BoolLit:
return as<IRBoolLit>(inst)->getValue();
+ case kIROp_MakeCoopVector:
case kIROp_MakeVector:
case kIROp_MakeVectorFromScalar:
case kIROp_MakeMatrix:
diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp
index 32acc7baa..6a7564e67 100644
--- a/source/slang/slang-ir.cpp
+++ b/source/slang/slang-ir.cpp
@@ -2964,6 +2964,13 @@ IRVectorType* IRBuilder::getVectorType(IRType* elementType, IRIntegerValue eleme
return getVectorType(elementType, getIntValue(getIntType(), elementCount));
}
+// Returns the cooperative-vector type whose operands are `elementType`
+// (operand 0) and `elementCount` (operand 1).
+IRCoopVectorType* IRBuilder::getCoopVectorType(IRType* elementType, IRInst* elementCount)
+{
+    IRInst* typeOperands[2] = {elementType, elementCount};
+    auto coopVecType = getType(kIROp_CoopVectorType, 2, typeOperands);
+    return (IRCoopVectorType*)coopVecType;
+}
+
IRMatrixType* IRBuilder::getMatrixType(
IRType* elementType,
IRInst* rowCount,
@@ -3887,6 +3894,26 @@ IRInst* IRBuilder::emitDefaultConstruct(IRType* type, bool fallback)
return nullptr;
return emitIntrinsicInst(type, kIROp_MakeVectorFromScalar, 1, &inner);
}
+ case kIROp_CoopVectorType:
+ {
+ auto coopVecType = as<IRCoopVectorType>(actualType);
+ if (auto count = as<IRIntLit>(coopVecType->getElementCount()))
+ {
+ auto element = emitDefaultConstruct(coopVecType->getElementType(), fallback);
+ if (!element)
+ return nullptr;
+ List<IRInst*> elements;
+ constexpr int maxCount = 4096;
+ if (count->getValue() > maxCount)
+ break;
+ for (IRIntegerValue i = 0; i < count->getValue(); i++)
+ {
+ elements.add(element);
+ }
+ return emitMakeCoopVector(type, elements.getCount(), elements.getBuffer());
+ }
+ break;
+ }
case kIROp_MatrixType:
{
auto inner =
@@ -4171,6 +4198,18 @@ IRInst* IRBuilder::emitGetNativeString(IRInst* str)
return emitIntrinsicInst(getNativeStringType(), kIROp_getNativeStr, 1, &str);
}
+// Emits a `GetElement` of result type `type` on `arrayLikeType`, using
+// `element` converted to an integer constant of the builder's int type as the
+// index operand.
+IRInst* IRBuilder::emitGetElement(IRType* type, IRInst* arrayLikeType, IRIntegerValue element)
+{
+    IRInst* indexLiteral = getIntValue(getIntType(), element);
+    IRInst* operands[] = {arrayLikeType, indexLiteral};
+    return emitIntrinsicInst(type, kIROp_GetElement, 2, operands);
+}
+
+// Emits a `GetElementPtr` of result type `type` on `arrayLikeType`, using
+// `element` converted to an integer constant of the builder's int type as the
+// index operand.
+IRInst* IRBuilder::emitGetElementPtr(IRType* type, IRInst* arrayLikeType, IRIntegerValue element)
+{
+    IRInst* indexLiteral = getIntValue(getIntType(), element);
+    IRInst* operands[] = {arrayLikeType, indexLiteral};
+    return emitIntrinsicInst(type, kIROp_GetElementPtr, 2, operands);
+}
+
IRInst* IRBuilder::emitGetTupleElement(IRType* type, IRInst* tuple, IRInst* element)
{
IRInst* args[] = {tuple, element};
@@ -4345,6 +4384,11 @@ IRInst* IRBuilder::emitMakeMatrixFromScalar(IRType* type, IRInst* scalarValue)
return emitIntrinsicInst(type, kIROp_MakeMatrixFromScalar, 1, &scalarValue);
}
+// Emits a `MakeCoopVector` instruction of result type `type` whose operands
+// are the `argCount` values pointed to by `args`.
+IRInst* IRBuilder::emitMakeCoopVector(IRType* type, UInt argCount, IRInst* const* args)
+{
+ return emitIntrinsicInst(type, kIROp_MakeCoopVector, argCount, args);
+}
+
IRInst* IRBuilder::emitMakeArray(IRType* type, UInt argCount, IRInst* const* args)
{
return emitIntrinsicInst(type, kIROp_MakeArray, argCount, args);
@@ -5187,6 +5231,10 @@ IRInst* IRBuilder::emitElementAddress(IRInst* basePtr, IRInst* index)
{
type = vectorType->getElementType();
}
+ else if (auto coopVecType = as<IRCoopVectorType>(valueType))
+ {
+ type = coopVecType->getElementType();
+ }
else if (auto matrixType = as<IRMatrixType>(valueType))
{
type = getVectorType(matrixType->getElementType(), matrixType->getColumnCount());
@@ -8143,6 +8191,7 @@ bool IRInst::mightHaveSideEffects(SideEffectAnalysisOptions options)
case kIROp_GetAddr:
case kIROp_GetValueFromBoundInterface:
case kIROp_MakeUInt64:
+ case kIROp_MakeCoopVector:
case kIROp_MakeVector:
case kIROp_MakeMatrix:
case kIROp_MakeMatrixFromScalar:
@@ -8627,6 +8676,7 @@ bool isMovableInst(IRInst* inst)
switch (inst->getOp())
{
+ case kIROp_MakeCoopVector:
case kIROp_Add:
case kIROp_Sub:
case kIROp_Mul:
diff --git a/source/slang/slang-ir.h b/source/slang/slang-ir.h
index b29b3b815..8f53c9f14 100644
--- a/source/slang/slang-ir.h
+++ b/source/slang/slang-ir.h
@@ -1862,6 +1862,14 @@ struct IRHitObjectType : IRType
IR_LEAF_ISA(HitObjectType)
};
+// IR type node for a cooperative vector. It has two operands: the element
+// type and the element count.
+struct IRCoopVectorType : IRType
+{
+ // Operand 0, viewed as a type.
+ IRType* getElementType() { return (IRType*)getOperand(0); }
+ // Operand 1; returned as a general instruction rather than an integer.
+ IRInst* getElementCount() { return getOperand(1); }
+
+ IR_LEAF_ISA(CoopVectorType)
+};
+
bool isDefinition(IRInst* inVal);
// A structure type is represented as a parent instruction,
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index ba6b3413d..5c0c3edfb 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -4852,6 +4852,29 @@ struct ExprLoweringVisitorBase : public ExprVisitor<Derived, LoweredValInfo>
return LoweredValInfo::simple(
getBuilder()->emitMakeMatrix(irType, args.getCount(), args.getBuffer()));
}
+ else if (auto coopVecType = as<CoopVectorExpressionType>(type))
+ {
+ UInt elementCount = (UInt)getIntVal(coopVecType->getElementCount());
+
+ for (UInt ee = 0; ee < argCount; ++ee)
+ {
+ auto argExpr = expr->args[ee];
+ LoweredValInfo argVal = lowerRValueExpr(context, argExpr);
+ args.add(getSimpleVal(context, argVal));
+ }
+ if (elementCount > argCount)
+ {
+ auto irDefaultValue =
+ getSimpleVal(context, getDefaultVal(coopVecType->getElementType()));
+ for (UInt ee = argCount; ee < elementCount; ++ee)
+ {
+ args.add(irDefaultValue);
+ }
+ }
+
+ return LoweredValInfo::simple(
+ getBuilder()->emitMakeCoopVector(irType, args.getCount(), args.getBuffer()));
+ }
else if (auto declRefType = as<DeclRefType>(type))
{
DeclRef<Decl> declRef = declRefType->getDeclRef();
diff --git a/source/slang/slang-profile-defs.h b/source/slang/slang-profile-defs.h
index 8328bf063..57ad3ab25 100644
--- a/source/slang/slang-profile-defs.h
+++ b/source/slang/slang-profile-defs.h
@@ -104,6 +104,7 @@ PROFILE_VERSION(DX_6_4, DX)
PROFILE_VERSION(DX_6_5, DX)
PROFILE_VERSION(DX_6_6, DX)
PROFILE_VERSION(DX_6_7, DX)
+PROFILE_VERSION(DX_6_8, DX)
PROFILE_VERSION(GLSL_150, GLSL)
PROFILE_VERSION(GLSL_330, GLSL)
@@ -139,6 +140,7 @@ PROFILE(DX_Compute_6_4, cs_6_4, Compute, DX_6_4)
PROFILE(DX_Compute_6_5, cs_6_5, Compute, DX_6_5)
PROFILE(DX_Compute_6_6, cs_6_6, Compute, DX_6_6)
PROFILE(DX_Compute_6_7, cs_6_7, Compute, DX_6_7)
+PROFILE(DX_Compute_6_8, cs_6_8, Compute, DX_6_8)
PROFILE(DX_Domain_5_0, ds_5_0, Domain, DX_5_0)
PROFILE(DX_Domain_5_1, ds_5_1, Domain, DX_5_1)
@@ -150,6 +152,7 @@ PROFILE(DX_Domain_6_4, ds_6_4, Domain, DX_6_4)
PROFILE(DX_Domain_6_5, ds_6_5, Domain, DX_6_5)
PROFILE(DX_Domain_6_6, ds_6_6, Domain, DX_6_6)
PROFILE(DX_Domain_6_7, ds_6_7, Domain, DX_6_7)
+PROFILE(DX_Domain_6_8, ds_6_8, Domain, DX_6_8)
PROFILE(DX_Geometry_4_0, gs_4_0, Geometry, DX_4_0)
PROFILE(DX_Geometry_4_1, gs_4_1, Geometry, DX_4_1)
@@ -163,6 +166,7 @@ PROFILE(DX_Geometry_6_4, gs_6_4, Geometry, DX_6_4)
PROFILE(DX_Geometry_6_5, gs_6_5, Geometry, DX_6_5)
PROFILE(DX_Geometry_6_6, gs_6_6, Geometry, DX_6_6)
PROFILE(DX_Geometry_6_7, gs_6_7, Geometry, DX_6_7)
+PROFILE(DX_Geometry_6_8, gs_6_8, Geometry, DX_6_8)
PROFILE(DX_Hull_5_0, hs_5_0, Hull, DX_5_0)
PROFILE(DX_Hull_5_1, hs_5_1, Hull, DX_5_1)
@@ -174,6 +178,7 @@ PROFILE(DX_Hull_6_4, hs_6_4, Hull, DX_6_4)
PROFILE(DX_Hull_6_5, hs_6_5, Hull, DX_6_5)
PROFILE(DX_Hull_6_6, hs_6_6, Hull, DX_6_6)
PROFILE(DX_Hull_6_7, hs_6_7, Hull, DX_6_7)
+PROFILE(DX_Hull_6_8, hs_6_8, Hull, DX_6_8)
PROFILE(DX_Fragment_4_0, ps_4_0, Fragment, DX_4_0)
PROFILE(DX_Fragment_4_1, ps_4_1, Fragment, DX_4_1)
@@ -187,6 +192,7 @@ PROFILE(DX_Fragment_6_4, ps_6_4, Fragment, DX_6_4)
PROFILE(DX_Fragment_6_5, ps_6_5, Fragment, DX_6_5)
PROFILE(DX_Fragment_6_6, ps_6_6, Fragment, DX_6_6)
PROFILE(DX_Fragment_6_7, ps_6_7, Fragment, DX_6_7)
+PROFILE(DX_Fragment_6_8, ps_6_8, Fragment, DX_6_8)
PROFILE(DX_Vertex_4_0, vs_4_0, Vertex, DX_4_0)
PROFILE(DX_Vertex_4_1, vs_4_1, Vertex, DX_4_1)
@@ -200,14 +206,17 @@ PROFILE(DX_Vertex_6_4, vs_6_4, Vertex, DX_6_4)
PROFILE(DX_Vertex_6_5, vs_6_5, Vertex, DX_6_5)
PROFILE(DX_Vertex_6_6, vs_6_6, Vertex, DX_6_6)
PROFILE(DX_Vertex_6_7, vs_6_7, Vertex, DX_6_7)
+PROFILE(DX_Vertex_6_8, vs_6_8, Vertex, DX_6_8)
PROFILE(DX_Mesh_6_5, ms_6_5, Mesh, DX_6_5)
PROFILE(DX_Mesh_6_6, ms_6_6, Mesh, DX_6_6)
PROFILE(DX_Mesh_6_7, ms_6_7, Mesh, DX_6_7)
+PROFILE(DX_Mesh_6_8, ms_6_8, Mesh, DX_6_8)
PROFILE(DX_Amplification_6_5, as_6_5, Amplification, DX_6_5)
PROFILE(DX_Amplification_6_6, as_6_6, Amplification, DX_6_6)
PROFILE(DX_Amplification_6_7, as_6_7, Amplification, DX_6_7)
+PROFILE(DX_Amplification_6_8, as_6_8, Amplification, DX_6_8)
// TODO: consider making `lib_*_*` alias these...
PROFILE(DX_None_4_0, sm_4_0, Unknown, DX_4_0)
@@ -234,6 +243,7 @@ PROFILE(DX_Lib_6_4, lib_6_4, Unknown, DX_6_4)
PROFILE(DX_Lib_6_5, lib_6_5, Unknown, DX_6_5)
PROFILE(DX_Lib_6_6, lib_6_6, Unknown, DX_6_6)
PROFILE(DX_Lib_6_7, lib_6_7, Unknown, DX_6_7)
+PROFILE(DX_Lib_6_8, lib_6_8, Unknown, DX_6_8)
PROFILE_ALIAS(DX_None_6_1, DX_Lib_6_1, sm_6_1)
PROFILE_ALIAS(DX_None_6_2, DX_Lib_6_2, sm_6_2)
@@ -242,6 +252,7 @@ PROFILE_ALIAS(DX_None_6_4, DX_Lib_6_4, sm_6_4)
PROFILE_ALIAS(DX_None_6_5, DX_Lib_6_5, sm_6_5)
PROFILE_ALIAS(DX_None_6_6, DX_Lib_6_6, sm_6_6)
PROFILE_ALIAS(DX_None_6_7, DX_Lib_6_7, sm_6_7)
+PROFILE_ALIAS(DX_None_6_8, DX_Lib_6_8, sm_6_8)
PROFILE(METAL_LIB_2_3, metallib_2_3, Unknown, METAL_2_3)
PROFILE(METAL_LIB_2_4, metallib_2_4, Unknown, METAL_2_4)