summaryrefslogtreecommitdiffstats
path: root/prelude
diff options
context:
space:
mode:
authorjsmall-nvidia <jsmall@nvidia.com>2021-05-04 16:24:51 -0400
committerGitHub <noreply@github.com>2021-05-04 13:24:51 -0700
commit731f1fc6b26659dc8f62fbc1969c076b78ada24f (patch)
tree9fcc4d1d931049edeabe3cea46d0bd6942956042 /prelude
parentdc571f1291f6b82b189a0db52c0468ae2fc7af4b (diff)
CUDA half comparison support (#1834)
* #include an absolute path didn't work - because paths were taken to always be relative. * Split out StringEscapeUtil. * Added StringEscapeUtil. * Fix typo in unix quoting type. * Small comment improvements. * Try to fix linux linking issue. * Fix typo. * Attempt to fix linux link issue. * Update VS proj even though nothing really changed. * Fix another typo issue. * Fix for windows issue. Fixed bug. * Make separate Utils for escaping. * Fix typo. * Split out into StringEscapeHandler. * Windows shell does handle removing quotes (so remove code to remove them). * Handle unescaping if not initiating using the shell. * Slight improvement around shell like decoding. * Simplify command extraction. * Add shared-library category type. * Fix bug in command extraction. * Typo in transcendental category. * Enable unit-test on in smoke test category. * Make parsing failing output as a failing test. * Fixes for transcendental tests. Disable tests that do not work. * Changed category parsing. * Removed the TestResult parameter from _gatherTestsForFile. Made testsList only output. * Remove testing if all tests were disabled. * Make args of CommandLine always unescaped. * Add category. * Don't need escaping on unix/linux. * Remove some no longer used functions. * Add requireSMVersion to CUDAExtensionTracker. * half-calc.slang now works for CUDA. * bit-cast-16-bit works on CUDA. * WIP handling of CUDA vector<half> types. * Half swizzle CUDA. * Half vector test. * Fix swizzle half bug. * Fix compilation issue with narrowing to Index. * Add unary ops. * Add some vector scalar maths ops. * Add half vector conversions for CUDA. * Fix erroneous comment. * Support for half comparisons. * First pass test for half compare. * Fix bug in CUDA specialized emit control. Updated tests to have pre and post inc/dec. * Removed unneeded parts of the cuda prelude. * Half structured buffer works on CUDA. Co-authored-by: Tim Foley <tfoleyNV@users.noreply.github.com>
Diffstat (limited to 'prelude')
-rw-r--r--prelude/slang-cuda-prelude.h88
1 files changed, 43 insertions, 45 deletions
diff --git a/prelude/slang-cuda-prelude.h b/prelude/slang-cuda-prelude.h
index 05b978cf6..a627cc652 100644
--- a/prelude/slang-cuda-prelude.h
+++ b/prelude/slang-cuda-prelude.h
@@ -5,6 +5,9 @@
// are passed down.
#ifdef SLANG_CUDA_ENABLE_HALF
+// We don't want half2 operators, because it will implement comparison operators that return a bool(!). We want to generate
+// those functions. Doing so means that we will have to define all the other half2 operators.
+# define __CUDA_NO_HALF2_OPERATORS__
# include <cuda_fp16.h>
#endif
@@ -155,6 +158,7 @@ union Union64
struct __half3 { __half2 xy; __half z; };
struct __half4 { __half2 xy; __half2 zw; };
+// *** convert ***
// half -> other
@@ -196,7 +200,43 @@ SLANG_FORCE_INLINE SLANG_CUDA_CALL __half2 convert___half2(const double2& v) { r
SLANG_FORCE_INLINE SLANG_CUDA_CALL __half3 convert___half3(const double3& v) { return __half3{ __float22half2_rn(float2{v.x, v.y}), __float2half_rn(v.z) }; }
SLANG_FORCE_INLINE SLANG_CUDA_CALL __half4 convert___half4(const double4& v) { return __half4{ __float22half2_rn(float2{v.x, v.y}), __float22half2_rn(float2{v.z, v.w}) }; }
-// half2
+// *** make ***
+
+// Mechanism to make half vectors
+SLANG_FORCE_INLINE SLANG_CUDA_CALL __half2 make___half2(__half x, __half y) { return __halves2half2(x, y); }
+SLANG_FORCE_INLINE SLANG_CUDA_CALL __half3 make___half3(__half x, __half y, __half z) { return __half3{ __halves2half2(x, y), z }; }
+SLANG_FORCE_INLINE SLANG_CUDA_CALL __half4 make___half4(__half x, __half y, __half z, __half w) { return __half4{ __halves2half2(x, y), __halves2half2(z, w)}; }
+
+// *** constructFromScalar ***
+
+SLANG_FORCE_INLINE SLANG_CUDA_CALL __half2 constructFromScalar___half2(half x) { return __half2half2(x); }
+SLANG_FORCE_INLINE SLANG_CUDA_CALL __half3 constructFromScalar___half3(half x) { return __half3{__half2half2(x), x}; }
+SLANG_FORCE_INLINE SLANG_CUDA_CALL __half4 constructFromScalar___half4(half x) { const __half2 v = __half2half2(x); return __half4{v, v}; }
+
+// *** half2 ***
+
+// half2 maths ops
+
+// NOTE! That by default these are in cuda_fp16.hpp, but we disable them, because we need to define the comparison operators
+// as we need versions that will return vector<bool>
+
+SLANG_FORCE_INLINE SLANG_CUDA_CALL __half2 operator+(const __half2& lh, const __half2& rh) { return __hadd2(lh, rh); }
+SLANG_FORCE_INLINE SLANG_CUDA_CALL __half2 operator-(const __half2& lh, const __half2& rh) { return __hsub2(lh, rh); }
+SLANG_FORCE_INLINE SLANG_CUDA_CALL __half2 operator*(const __half2& lh, const __half2& rh) { return __hmul2(lh, rh); }
+SLANG_FORCE_INLINE SLANG_CUDA_CALL __half2 operator/(const __half2& lh, const __half2& rh) { return __h2div(lh, rh); }
+
+SLANG_FORCE_INLINE SLANG_CUDA_CALL __half2& operator+=(__half2& lh, const __half2& rh) { lh = __hadd2(lh, rh); return lh; }
+SLANG_FORCE_INLINE SLANG_CUDA_CALL __half2& operator-=(__half2& lh, const __half2& rh) { lh = __hsub2(lh, rh); return lh; }
+SLANG_FORCE_INLINE SLANG_CUDA_CALL __half2& operator*=(__half2& lh, const __half2& rh) { lh = __hmul2(lh, rh); return lh; }
+SLANG_FORCE_INLINE SLANG_CUDA_CALL __half2& operator/=(__half2& lh, const __half2& rh) { lh = __h2div(lh, rh); return lh; }
+
+SLANG_FORCE_INLINE SLANG_CUDA_CALL __half2 &operator++(__half2 &h) { __half2_raw one; one.x = 0x3C00; one.y = 0x3C00; h = __hadd2(h, one); return h; }
+SLANG_FORCE_INLINE SLANG_CUDA_CALL __half2 &operator--(__half2 &h) { __half2_raw one; one.x = 0x3C00; one.y = 0x3C00; h = __hsub2(h, one); return h; }
+SLANG_FORCE_INLINE SLANG_CUDA_CALL __half2 operator++(__half2 &h, int) { __half2 ret = h; __half2_raw one; one.x = 0x3C00; one.y = 0x3C00; h = __hadd2(h, one); return ret; }
+SLANG_FORCE_INLINE SLANG_CUDA_CALL __half2 operator--(__half2 &h, int) { __half2 ret = h; __half2_raw one; one.x = 0x3C00; one.y = 0x3C00; h = __hsub2(h, one); return ret; }
+
+SLANG_FORCE_INLINE SLANG_CUDA_CALL __half2 operator+(const __half2 &h) { return h; }
+SLANG_FORCE_INLINE SLANG_CUDA_CALL __half2 operator-(const __half2 &h) { return __hneg2(h); }
// vec op scalar
SLANG_FORCE_INLINE SLANG_CUDA_CALL __half2 operator+(const __half2& lh, __half rh) { return __hadd2(lh, __half2half2(rh)); }
@@ -210,16 +250,7 @@ SLANG_FORCE_INLINE SLANG_CUDA_CALL __half2 operator-(__half lh, const __half2& r
SLANG_FORCE_INLINE SLANG_CUDA_CALL __half2 operator*(__half lh, const __half2& rh) { return __hmul2(__half2half2(lh), rh); }
SLANG_FORCE_INLINE SLANG_CUDA_CALL __half2 operator/(__half lh, const __half2& rh) { return __h2div(__half2half2(lh), rh); }
-// Mechanism to make half vectors
-SLANG_FORCE_INLINE SLANG_CUDA_CALL __half2 make___half2(__half x, __half y) { return __halves2half2(x, y); }
-SLANG_FORCE_INLINE SLANG_CUDA_CALL __half3 make___half3(__half x, __half y, __half z) { return __half3{ __halves2half2(x, y), z }; }
-SLANG_FORCE_INLINE SLANG_CUDA_CALL __half4 make___half4(__half x, __half y, __half z, __half w) { return __half4{ __halves2half2(x, y), __halves2half2(z, w)}; }
-
-SLANG_FORCE_INLINE SLANG_CUDA_CALL __half2 constructFromScalar___half2(half x) { return __half2half2(x); }
-SLANG_FORCE_INLINE SLANG_CUDA_CALL __half3 constructFromScalar___half3(half x) { return __half3{__half2half2(x), x}; }
-SLANG_FORCE_INLINE SLANG_CUDA_CALL __half4 constructFromScalar___half4(half x) { const __half2 v = __half2half2(x); return __half4{v, v}; }
-
-// Half3 maths ops
+// *** half3 ***
SLANG_FORCE_INLINE SLANG_CUDA_CALL __half3 operator+(const __half3& lh, const __half3& rh) { return __half3{__hadd2(lh.xy, rh.xy), __hadd(lh.z, rh.z)}; }
SLANG_FORCE_INLINE SLANG_CUDA_CALL __half3 operator-(const __half3& lh, const __half3& rh) { return __half3{__hsub2(lh.xy, rh.xy), __hsub(lh.z, rh.z)}; }
@@ -241,18 +272,7 @@ SLANG_FORCE_INLINE SLANG_CUDA_CALL __half3 operator-(__half lh, const __half3& r
SLANG_FORCE_INLINE SLANG_CUDA_CALL __half3 operator*(__half lh, const __half3& rh) { return __half3{__hmul2(__half2half2(lh), rh.xy), __hmul(lh, rh.z)}; }
SLANG_FORCE_INLINE SLANG_CUDA_CALL __half3 operator/(__half lh, const __half3& rh) { return __half3{__h2div(__half2half2(lh), rh.xy), __hdiv(lh, rh.z)}; }
-
-#if 0
-// We need to return the vector<bool> type
-SLANG_FORCE_INLINE SLANG_CUDA_CALL bool operator==(const __half3& lh, const __half3& rh) { return __hbeq2(lh.xy, rh.xy) && __heq(lh.z, rh.z); }
-SLANG_FORCE_INLINE SLANG_CUDA_CALL bool operator!=(const __half3& lh, const __half3& rh) { return __hbneu2(lh.xy, rh.xy) && __hneu(lh.z, rh.z); }
-SLANG_FORCE_INLINE SLANG_CUDA_CALL bool operator>(const __half3& lh, const __half3& rh) { return __hbgt2(lh.xy, rh.xy) && __hgt(lh.z, rh.z); }
-SLANG_FORCE_INLINE SLANG_CUDA_CALL bool operator<(const __half3& lh, const __half3& rh) { return __hblt2(lh.xy, rh.xy) && __hlt(lh.z, rh.z); }
-SLANG_FORCE_INLINE SLANG_CUDA_CALL bool operator>=(const __half3& lh, const __half3& rh) { return __hbge2(lh.xy, rh.xy) && __hge(lh.z, rh.z); }
-SLANG_FORCE_INLINE SLANG_CUDA_CALL bool operator<=(const __half3& lh, const __half3& rh) { return __hble2(lh.xy, rh.xy) && __hle(lh.z, rh.z); }
-#endif
-
-// Half4 maths ops
+// *** half4 ***
SLANG_FORCE_INLINE SLANG_CUDA_CALL __half4 operator+(const __half4& lh, const __half4& rh) { return __half4{__hadd2(lh.xy, rh.xy), __hadd2(lh.zw, rh.zw)}; }
SLANG_FORCE_INLINE SLANG_CUDA_CALL __half4 operator-(const __half4& lh, const __half4& rh) { return __half4{__hsub2(lh.xy, rh.xy), __hsub2(lh.zw, rh.zw)}; }
@@ -274,28 +294,6 @@ SLANG_FORCE_INLINE SLANG_CUDA_CALL __half4 operator/(__half lh, const __half4& r
SLANG_FORCE_INLINE SLANG_CUDA_CALL __half4 operator-(const __half4& h) { return __half4{__hneg2(h.xy), __hneg2(h.zw)}; }
SLANG_FORCE_INLINE SLANG_CUDA_CALL __half4 operator+(const __half4& h) { return h; }
-#if 0
-// We need to return vector<bool> type
-SLANG_FORCE_INLINE SLANG_CUDA_CALL bool operator==(const __half4& lh, const __half4& rh) { return __hbeq2(lh.xy, rh.xy) && __hbeq2(lh.zw, rh.zw); }
-SLANG_FORCE_INLINE SLANG_CUDA_CALL bool operator!=(const __half4& lh, const __half4& rh) { return __hbneu2(lh.xy, rh.xy) && __hbneu2(lh.zw, rh.zw); }
-SLANG_FORCE_INLINE SLANG_CUDA_CALL bool operator>(const __half4& lh, const __half4& rh) { return __hbgt2(lh.xy, rh.xy) && __hbgt2(lh.zw, rh.zw); }
-SLANG_FORCE_INLINE SLANG_CUDA_CALL bool operator<(const __half4& lh, const __half4& rh) { return __hblt2(lh.xy, rh.xy) && __hblt2(lh.zw, rh.zw); }
-SLANG_FORCE_INLINE SLANG_CUDA_CALL bool operator>=(const __half4& lh, const __half4& rh) { return __hbge2(lh.xy, rh.xy) && __hbge2(lh.zw, rh.zw); }
-SLANG_FORCE_INLINE SLANG_CUDA_CALL bool operator<=(const __half4& lh, const __half4& rh) { return __hble2(lh.xy, rh.xy) && __hble2(lh.zw, rh.zw); }
-#endif
-
-// Use the round nearest as the default - it is the only one defined
-SLANG_FORCE_INLINE SLANG_CUDA_CALL __half2 __float22half2(const float2 a) { return __float22half2_rn(a); }
-
-// Implement the vector versions
-SLANG_FORCE_INLINE SLANG_CUDA_CALL __half2 __float2half(float2 a) { return __float22half2(a); }
-SLANG_FORCE_INLINE SLANG_CUDA_CALL __half3 __float2half(float3 a) { __half3 o; o.xy = __float22half2(make_float2(a.x, a.y)); o.z = __float2half(a.z); return o; }
-SLANG_FORCE_INLINE SLANG_CUDA_CALL __half4 __float2half(float4 a) { __half4 o; o.xy = __float22half2(make_float2(a.x, a.y)); o.zw = __float22half2(make_float2(a.z, a.w)); return o; }
-
-SLANG_FORCE_INLINE SLANG_CUDA_CALL float2 __half2float(__half2 a) { return __half22float2(a); }
-SLANG_FORCE_INLINE SLANG_CUDA_CALL float3 __half2float(__half3 a) { float2 xy = __half22float2(a.xy); float z = __half2float(a.z); return make_float3(xy.x, xy.y, z); }
-SLANG_FORCE_INLINE SLANG_CUDA_CALL float4 __half2float(__half4 a) { float2 xy = __half22float2(a.xy); float2 zw = __half22float2(a.zw); return make_float4(xy.x, xy.y, zw.x, zw.y); }
-
#endif
// ----------------------------- F32 -----------------------------------------