From d70da65a90ccd73439895a43b3958c0ea1441f35 Mon Sep 17 00:00:00 2001 From: Mukund Keshava Date: Tue, 10 Jun 2025 10:18:24 +0530 Subject: Add optix support for coopvec (#7286) * WiP: Add coopvec support for Optix * format code * fix minor issues * Fix review comments --------- Co-authored-by: slangbot <186143334+slangbot@users.noreply.github.com> --- source/slang/hlsl.meta.slang | 82 +++++++++++++++++++++++++++++++++- source/slang/slang-capabilities.capdef | 6 ++- source/slang/slang-emit-c-like.cpp | 1 + source/slang/slang-emit-c-like.h | 3 ++ source/slang/slang-emit-cpp.cpp | 12 ++++- 5 files changed, 99 insertions(+), 5 deletions(-) (limited to 'source') diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang index e00108e96..8b0bade6e 100644 --- a/source/slang/hlsl.meta.slang +++ b/source/slang/hlsl.meta.slang @@ -24382,6 +24382,7 @@ struct CoopVec : IArray, IArithmeti [ForceInline] [require(cooperative_vector)] [require(hlsl_coopvec_poc)] + [require(optix_coopvec)] __init(CoopVec other) { this.copyFrom(other); @@ -24421,6 +24422,7 @@ struct CoopVec : IArray, IArithmeti [ForceInline] [require(cooperative_vector)] [require(hlsl_coopvec_poc)] + [require(optix_coopvec)] void copyFrom(CoopVec other) { __target_switch @@ -24429,6 +24431,8 @@ struct CoopVec : IArray, IArithmeti __intrinsic_asm "$0 = $1"; case hlsl_coopvec_poc: __intrinsic_asm ".CopyFrom"; + case optix_coopvec: + __intrinsic_asm "optixCoopVecCvt<$TR>(*($0));"; default: if (__isFloat() && __isInt()) this = __int_to_float_cast(other); @@ -24438,7 +24442,7 @@ struct CoopVec : IArray, IArithmeti this = __real_cast(other); else if (__isInt() && __isInt()) this = __int_cast(other); - } + } } /// Fill all elements of this CoopVec with the specified value. @@ -24591,6 +24595,7 @@ struct CoopVec : IArray, IArithmeti [__NoSideEffect] [require(cooperative_vector)] [require(hlsl_coopvec_poc)] + [require(optix_coopvec)] static CoopVec load(ByteAddressBuffer buffer, int32_t byteOffset16ByteAligned = 0) { __target_switch @@ -24606,6 +24611,8 @@ struct CoopVec : IArray, IArithmeti CoopVec ret; ret.__Load(buffer, byteOffset16ByteAligned); return ret; + case optix_coopvec: + __intrinsic_asm "optixCoopVecLoad<$TR>((CUdeviceptr)(&($0)));"; default: var vec = CoopVec(); for(int i = 0; i < N; ++i) @@ -24618,6 +24625,7 @@ struct CoopVec : IArray, IArithmeti [__NoSideEffect] [require(cooperative_vector)] [require(hlsl_coopvec_poc)] + [require(optix_coopvec)] static CoopVec load(RWByteAddressBuffer buffer, int32_t byteOffset16ByteAligned = 0) { __target_switch @@ -24633,6 +24641,8 @@ struct CoopVec : IArray, IArithmeti CoopVec ret; ret.__Load(buffer, byteOffset16ByteAligned); return ret; + case optix_coopvec: + __intrinsic_asm "optixCoopVecLoad<$TR>((CUdeviceptr)(&($0)));"; default: var vec = CoopVec(); for(int i = 0; i < N; ++i) @@ -24702,6 +24712,7 @@ struct CoopVec : IArray, IArithmeti [__NoSideEffect] [require(cooperative_vector)] [require(hlsl_coopvec_poc)] + [require(optix_coopvec)] static CoopVec load(__constref groupshared const T[M] data, int32_t byteOffset16ByteAligned = 0) { static_assert(N <= M, "The destination vector size is smaller than the input."); @@ -24716,6 +24727,8 @@ struct CoopVec : IArray, IArithmeti CoopVec ret; ret.__Load(data, __byteToElemOffset(byteOffset16ByteAligned)); return ret; + case optix_coopvec: + __intrinsic_asm "optixCoopVecLoad<$TR>((CUdeviceptr)(&($0)));"; default: CoopVec result; for(int i = 0; i < N; ++i) @@ -24922,6 +24935,7 @@ struct CoopVec : IArray, IArithmeti [ForceInline] [require(cooperative_vector)] [require(hlsl_coopvec_poc)] + [require(optix_coopvec)] This add(This other) { __target_switch @@ -24932,6 +24946,8 @@ struct CoopVec : IArray, IArithmeti This ret = this; ret.__mutAdd(other); return ret; + case optix_coopvec: + __intrinsic_asm "optixCoopVecAdd($0, $1)"; default: return __pureAdd(other); } } @@ -24957,6 +24973,7 @@ struct CoopVec : IArray, IArithmeti [ForceInline] [require(cooperative_vector)] [require(hlsl_coopvec_poc)] + [require(optix_coopvec)] This sub(This other) { __target_switch @@ -24967,6 +24984,8 @@ struct CoopVec : IArray, IArithmeti This ret = this; ret.__mutSub(other); return ret; + case optix_coopvec: + __intrinsic_asm "optixCoopVecSub($0, $1)"; default: return __pureSub(other); } } @@ -24992,6 +25011,7 @@ struct CoopVec : IArray, IArithmeti [ForceInline] [require(cooperative_vector)] [require(hlsl_coopvec_poc)] + [require(optix_coopvec)] This mul(This other) { __target_switch @@ -25002,6 +25022,8 @@ struct CoopVec : IArray, IArithmeti This ret = this; ret.__mutMul(other); return ret; + case optix_coopvec: + __intrinsic_asm "optixCoopVecMul($0, $1)"; default: return __pureMul(other); } } @@ -25621,6 +25643,7 @@ CoopVec operator *(const T lhs, CoopVec rhs) [ForceInline] [require(cooperative_vector)] [require(hlsl_coopvec_poc)] +[require(optix_coopvec)] CoopVec min(CoopVec x, CoopVec y) { __target_switch @@ -25636,6 +25659,8 @@ CoopVec min(CoopVec x, CoopVec ret = x; ret.__mutMin(y); return ret; + case optix_coopvec: + __intrinsic_asm "optixCoopVecMin($0, $1)"; default: CoopVec ret; for(int i = 0; i < N; ++i) @@ -25648,6 +25673,7 @@ CoopVec min(CoopVec x, [ForceInline] [require(cooperative_vector)] [require(hlsl_coopvec_poc)] +[require(optix_coopvec)] CoopVec max(CoopVec x, CoopVec y) { __target_switch @@ -25663,6 +25689,8 @@ CoopVec max(CoopVec x, CoopVec ret = x; ret.__mutMax(y); return ret; + case optix_coopvec: + __intrinsic_asm "optixCoopVecMax($0, $1)"; default: CoopVec ret; for(int i = 0; i < N; ++i) @@ -25809,6 +25837,7 @@ CoopVec clamp(CoopVec x, Coop // [ForceInline] [require(cooperative_vector)] [require(hlsl_coopvec_poc)] +[require(optix_coopvec)] CoopVec step(CoopVec edge, CoopVec x) { __target_switch @@ -25825,6 +25854,8 @@ CoopVec step(CoopVec ed { result:$$CoopVec = OpExtInst glsl450 Step $edge $x; }; + case optix_coopvec: + __intrinsic_asm "optixCoopVecStep($0, $1)"; default: CoopVec ret; for(int i = 0; i < N; ++i) @@ -25890,6 +25921,43 @@ CoopVec log(CoopVec x) // [ForceInline] [require(cooperative_vector)] [require(hlsl_coopvec_poc)] +[require(optix_coopvec)] +CoopVec log2(CoopVec x) +{ + __target_switch + { + default: + CoopVec ret; + for(int i = 0; i < N; ++i) + ret[i] = log2(x[i]); + return ret; + case optix_coopvec: + __intrinsic_asm "optixCoopVecLog2($0)"; + } +} + +// [ForceInline] +[require(cooperative_vector)] +[require(hlsl_coopvec_poc)] +[require(optix_coopvec)] +CoopVec exp2(CoopVec x) +{ + __target_switch + { + default: + CoopVec ret; + for(int i = 0; i < N; ++i) + ret[i] = exp2(x[i]); + return ret; + case optix_coopvec: + __intrinsic_asm "optixCoopVecExp2($0)"; + } +} + +// [ForceInline] +[require(cooperative_vector)] +[require(hlsl_coopvec_poc)] +[require(optix_coopvec)] CoopVec tanh(CoopVec x) { __target_switch @@ -25906,6 +25974,8 @@ CoopVec tanh(CoopVec x) { result:$$CoopVec = OpExtInst glsl450 Tanh $x; }; + case optix_coopvec: + __intrinsic_asm "optixCoopVecTanh($0)"; default: CoopVec ret; for(int i = 0; i < N; ++i) @@ -25944,6 +26014,7 @@ CoopVec atan(CoopVec yO // [ForceInline] [require(cooperative_vector)] [require(hlsl_coopvec_poc)] +[require(optix_coopvec)] CoopVec fma(CoopVec a, CoopVec b, CoopVec c) { // TODO: Investigate, why does this fail if it's not inlined @@ -25963,6 +26034,8 @@ CoopVec fma(CoopVec a, { result:$$CoopVec = OpExtInst glsl450 Fma $a $b $c; }; + case optix_coopvec: + __intrinsic_asm "optixCoopVecFFMA($0, $1, $2)"; default: CoopVec ret; for(int i = 0; i < N; ++i) @@ -26695,6 +26768,7 @@ CoopVec coopVecMatMulAddPacked CoopVec coopVecMatMulAdd( CoopVec input, @@ -26746,6 +26820,7 @@ if(buffer.isRW) /// @param matrixInterpretation Specifies how to interpret the values in the matrix. [require(cooperative_vector)] [require(hlsl_coopvec_poc)] +[require(optix_coopvec)] void coopVecOuterProductAccumulate( CoopVec a, CoopVec b, @@ -26773,6 +26848,8 @@ void coopVecOuterProductAccumulate( CoopVec v, $(buffer.type) buffer, @@ -26855,6 +26933,8 @@ void coopVecReduceSumAccumulate( OpCapability CooperativeVectorTrainingNV; OpCooperativeVectorReduceSumAccumulateNV $bufferPtr $offset $v; }; + case optix_coopvec: + __intrinsic_asm "optixCoopVecReduceSumAccumulate($0, (CUdeviceptr)(&$1), $2)"; default: for (int i = 0; i < N; ++i) { diff --git a/source/slang/slang-capabilities.capdef b/source/slang/slang-capabilities.capdef index 7616cc201..343f89687 100644 --- a/source/slang/slang-capabilities.capdef +++ b/source/slang/slang-capabilities.capdef @@ -220,11 +220,11 @@ def _sm_6_9 : _sm_6_8; def hlsl_nvapi : hlsl; -/// Represet HLSL compatibility support. +/// Represent HLSL compatibility support. /// [Version] def hlsl_2018 : _sm_5_1; -/// Represet compatibility support for the deprecated POC DXC +/// Represent compatibility support for the deprecated POC DXC /// [Version] def hlsl_coopvec_poc : _sm_6_8; @@ -244,6 +244,8 @@ def _cuda_sm_6_0 : _cuda_sm_5_0; def _cuda_sm_7_0 : _cuda_sm_6_0; def _cuda_sm_8_0 : _cuda_sm_7_0; def _cuda_sm_9_0 : _cuda_sm_8_0; +/// Represents capabilities required for optix cooperative vector support. +def optix_coopvec : _cuda_sm_9_0; /// All code-gen targets /// [Compound] diff --git a/source/slang/slang-emit-c-like.cpp b/source/slang/slang-emit-c-like.cpp index 0092d159a..3fbf47bfa 100644 --- a/source/slang/slang-emit-c-like.cpp +++ b/source/slang/slang-emit-c-like.cpp @@ -112,6 +112,7 @@ CLikeSourceEmitter::CLikeSourceEmitter(const Desc& desc) auto targetCaps = getTargetReq()->getTargetCaps(); isCoopvecPoc = targetCaps.implies(CapabilityAtom::hlsl_coopvec_poc); + isOptixCoopVec = targetCaps.implies(CapabilityAtom::optix_coopvec); } SlangResult CLikeSourceEmitter::init() diff --git a/source/slang/slang-emit-c-like.h b/source/slang/slang-emit-c-like.h index 78793f655..1e9deaa0d 100644 --- a/source/slang/slang-emit-c-like.h +++ b/source/slang/slang-emit-c-like.h @@ -744,6 +744,9 @@ protected: // Indicates if we are emiting for DXC cooperative vector POC. bool isCoopvecPoc = false; + + // Indicates if we are emiting for Optix cooperative vector. + bool isOptixCoopVec = false; }; } // namespace Slang diff --git a/source/slang/slang-emit-cpp.cpp b/source/slang/slang-emit-cpp.cpp index 6f97a11da..8e95cebfb 100644 --- a/source/slang/slang-emit-cpp.cpp +++ b/source/slang/slang-emit-cpp.cpp @@ -1152,8 +1152,16 @@ void CPPSourceEmitter::_emitType(IRType* type, DeclaratorInfo* declarator) auto arrayType = static_cast(type); auto elementType = arrayType->getElementType(); int elementCount = int(getIntVal(arrayType->getElementCount())); - - m_writer->emit("FixedArray<"); + auto nameHint = arrayType->findDecoration(); + bool isCoopVec = nameHint && (nameHint->getName() == UnownedStringSlice("CoopVec")); + if (isCoopVec && isOptixCoopVec) + { + m_writer->emit("OptixCoopVec<"); + } + else + { + m_writer->emit("FixedArray<"); + } _emitType(elementType, nullptr); m_writer->emit(", "); m_writer->emit(elementCount); -- cgit v1.2.3