diff options
| author | Yong He <yonghe@outlook.com> | 2023-03-26 13:59:11 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-03-26 13:59:11 -0700 |
| commit | d64ee86a3130f8eeb75d09193c38c621d7565eba (patch) | |
| tree | fed25a0cc2a7372d26175774f5983bed693e6b64 /prelude | |
| parent | 666af0962b6ab41489a3a3287db83f77c2f6461a (diff) | |
Add PyTorch C++ binding generation. (#2734)
* Add PyTorch C++ binding generation.
* fix
---------
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'prelude')
| -rw-r--r-- | prelude/slang-cpp-types-core.h | 561 | ||||
| -rw-r--r-- | prelude/slang-cpp-types.h | 554 | ||||
| -rw-r--r-- | prelude/slang-cuda-prelude.h | 71 | ||||
| -rw-r--r-- | prelude/slang-torch-prelude.h | 126 |
4 files changed, 759 insertions, 553 deletions
diff --git a/prelude/slang-cpp-types-core.h b/prelude/slang-cpp-types-core.h new file mode 100644 index 000000000..c49ee013c --- /dev/null +++ b/prelude/slang-cpp-types-core.h @@ -0,0 +1,561 @@ +#ifndef SLANG_PRELUDE_CPP_TYPES_CORE_H +#define SLANG_PRELUDE_CPP_TYPES_CORE_H + +#ifndef SLANG_PRELUDE_ASSERT +# ifdef SLANG_PRELUDE_ENABLE_ASSERT +# define SLANG_PRELUDE_ASSERT(VALUE) assert(VALUE) +# else +# define SLANG_PRELUDE_ASSERT(VALUE) +# endif +#endif + +// Since we are using unsigned arithmatic care is need in this comparison. +// It is *assumed* that sizeInBytes >= elemSize. Which means (sizeInBytes >= elemSize) >= 0 +// Which means only a single test is needed + +// Asserts for bounds checking. +// It is assumed index/count are unsigned types. +#define SLANG_BOUND_ASSERT(index, count) SLANG_PRELUDE_ASSERT(index < count); +#define SLANG_BOUND_ASSERT_BYTE_ADDRESS(index, elemSize, sizeInBytes) SLANG_PRELUDE_ASSERT(index <= (sizeInBytes - elemSize) && (index & 3) == 0); + +// Macros to zero index if an access is out of range +#define SLANG_BOUND_ZERO_INDEX(index, count) index = (index < count) ? index : 0; +#define SLANG_BOUND_ZERO_INDEX_BYTE_ADDRESS(index, elemSize, sizeInBytes) index = (index <= (sizeInBytes - elemSize)) ? index : 0; + +// The 'FIX' macro define how the index is fixed. The default is to do nothing. If SLANG_ENABLE_BOUND_ZERO_INDEX +// the fix macro will zero the index, if out of range +#ifdef SLANG_ENABLE_BOUND_ZERO_INDEX +# define SLANG_BOUND_FIX(index, count) SLANG_BOUND_ZERO_INDEX(index, count) +# define SLANG_BOUND_FIX_BYTE_ADDRESS(index, elemSize, sizeInBytes) SLANG_BOUND_ZERO_INDEX_BYTE_ADDRESS(index, elemSize, sizeInBytes) +# define SLANG_BOUND_FIX_FIXED_ARRAY(index, count) SLANG_BOUND_ZERO_INDEX(index, count) +#else +# define SLANG_BOUND_FIX(index, count) +# define SLANG_BOUND_FIX_BYTE_ADDRESS(index, elemSize, sizeInBytes) +# define SLANG_BOUND_FIX_FIXED_ARRAY(index, count) +#endif + +#ifndef SLANG_BOUND_CHECK +# define SLANG_BOUND_CHECK(index, count) SLANG_BOUND_ASSERT(index, count) SLANG_BOUND_FIX(index, count) +#endif + +#ifndef SLANG_BOUND_CHECK_BYTE_ADDRESS +# define SLANG_BOUND_CHECK_BYTE_ADDRESS(index, elemSize, sizeInBytes) SLANG_BOUND_ASSERT_BYTE_ADDRESS(index, elemSize, sizeInBytes) SLANG_BOUND_FIX_BYTE_ADDRESS(index, elemSize, sizeInBytes) +#endif + +#ifndef SLANG_BOUND_CHECK_FIXED_ARRAY +# define SLANG_BOUND_CHECK_FIXED_ARRAY(index, count) SLANG_BOUND_ASSERT(index, count) SLANG_BOUND_FIX_FIXED_ARRAY(index, count) +#endif + +struct TypeInfo +{ + size_t typeSize; +}; + +template <typename T, size_t SIZE> +struct FixedArray +{ + const T& operator[](size_t index) const { SLANG_BOUND_CHECK_FIXED_ARRAY(index, SIZE); return m_data[index]; } + T& operator[](size_t index) { SLANG_BOUND_CHECK_FIXED_ARRAY(index, SIZE); return m_data[index]; } + + T m_data[SIZE]; +}; + +// An array that has no specified size, becomes a 'Array'. This stores the size so it can potentially +// do bounds checking. +template <typename T> +struct Array +{ + const T& operator[](size_t index) const { SLANG_BOUND_CHECK(index, count); return data[index]; } + T& operator[](size_t index) { SLANG_BOUND_CHECK(index, count); return data[index]; } + + T* data; + size_t count; +}; + +/* Constant buffers become a pointer to the contained type, so ConstantBuffer<T> becomes T* in C++ code. +*/ + +template <typename T, int COUNT> +struct Vector; + +template <typename T> +struct Vector<T, 1> +{ + T x; + const T& operator[](size_t /*index*/) const { return x; } + T& operator[](size_t /*index*/) { return x; } + operator T() const { return x; } + Vector() = default; + Vector(T scalar) + { + x = scalar; + } + template <typename U> + Vector(Vector<U, 1> other) + { + x = (T)other.x; + } + template <typename U, int otherSize> + Vector(Vector<U, otherSize> other) + { + int minSize = 1; + if (otherSize < minSize) minSize = otherSize; + for (int i = 0; i < minSize; i++) + (*this)[i] = (T)other[i]; + } +}; + +template <typename T> +struct Vector<T, 2> +{ + T x, y; + const T& operator[](size_t index) const { return index == 0 ? x : y; } + T& operator[](size_t index) { return index == 0 ? x : y; } + Vector() = default; + Vector(T scalar) + { + x = y = scalar; + } + Vector(T _x, T _y) + { + x = _x; + y = _y; + } + template <typename U> + Vector(Vector<U, 2> other) + { + x = (T)other.x; + y = (T)other.y; + } + template <typename U, int otherSize> + Vector(Vector<U, otherSize> other) + { + int minSize = 2; + if (otherSize < minSize) minSize = otherSize; + for (int i = 0; i < minSize; i++) + (*this)[i] = (T)other[i]; + } +}; + +template <typename T> +struct Vector<T, 3> +{ + T x, y, z; + const T& operator[](size_t index) const { return *((T*)(this) + index); } + T& operator[](size_t index) { return *((T*)(this) + index); } + + Vector() = default; + Vector(T scalar) + { + x = y = z = scalar; + } + Vector(T _x, T _y, T _z) + { + x = _x; + y = _y; + z = _z; + } + template <typename U> + Vector(Vector<U, 3> other) + { + x = (T)other.x; + y = (T)other.y; + z = (T)other.z; + } + template <typename U, int otherSize> + Vector(Vector<U, otherSize> other) + { + int minSize = 3; + if (otherSize < minSize) minSize = otherSize; + for (int i = 0; i < minSize; i++) + (*this)[i] = (T)other[i]; + } +}; + +template <typename T> +struct Vector<T, 4> +{ + T x, y, z, w; + + const T& operator[](size_t index) const { return *((T*)(this) + index); } + T& operator[](size_t index) { return *((T*)(this) + index); } + Vector() = default; + Vector(T scalar) + { + x = y = z = w = scalar; + } + Vector(T _x, T _y, T _z, T _w) + { + x = _x; + y = _y; + z = _z; + w = _w; + } + template <typename U, int otherSize> + Vector(Vector<U, otherSize> other) + { + int minSize = 4; + if (otherSize < minSize) minSize = otherSize; + for (int i = 0; i < minSize; i++) + (*this)[i] = (T)other[i]; + } + +}; + +template<typename T, int N> +SLANG_FORCE_INLINE T _slang_vector_get_element(Vector<T, N> x, int index) +{ + return x[index]; +} + +template<typename T, int N> +SLANG_FORCE_INLINE const T* _slang_vector_get_element_ptr(const Vector<T, N>* x, int index) +{ + return &((*const_cast<Vector<T,N>*>(x))[index]); +} + +template<typename T, int N> +SLANG_FORCE_INLINE T* _slang_vector_get_element_ptr(Vector<T, N>* x, int index) +{ + return &((*x)[index]); +} + +template<typename T, int n, typename OtherT, int m> +SLANG_FORCE_INLINE Vector<T, n> _slang_vector_reshape(const Vector<OtherT, m> other) +{ + Vector<T, n> result; + for (int i = 0; i < n; i++) + { + OtherT otherElement = T(0); + if (i < m) + otherElement = _slang_vector_get_element(other, i); + *_slang_vector_get_element_ptr(&result, i) = (T)otherElement; + } + return result; +} + +typedef uint32_t uint; + +#define SLANG_VECTOR_BINARY_OP(T, op) \ + template<int n> \ + SLANG_FORCE_INLINE Vector<T, n> operator op(const Vector<T, n>& thisVal, const Vector<T, n>& other) \ + { \ + Vector<T, n> result;\ + for (int i = 0; i < n; i++) \ + result[i] = thisVal[i] op other[i]; \ + return result;\ + } +#define SLANG_VECTOR_BINARY_COMPARE_OP(T, op) \ + template<int n> \ + SLANG_FORCE_INLINE Vector<bool, n> operator op(const Vector<T, n>& thisVal, const Vector<T, n>& other) \ + { \ + Vector<bool, n> result;\ + for (int i = 0; i < n; i++) \ + result[i] = thisVal[i] op other[i]; \ + return result;\ + } + +#define SLANG_VECTOR_UNARY_OP(T, op) \ + template<int n> \ + SLANG_FORCE_INLINE Vector<T, n> operator op(const Vector<T, n>& thisVal) \ + { \ + Vector<T, n> result;\ + for (int i = 0; i < n; i++) \ + result[i] = op thisVal[i]; \ + return result;\ + } +#define SLANG_INT_VECTOR_OPS(T) \ + SLANG_VECTOR_BINARY_OP(T, +)\ + SLANG_VECTOR_BINARY_OP(T, -)\ + SLANG_VECTOR_BINARY_OP(T, *)\ + SLANG_VECTOR_BINARY_OP(T, / )\ + SLANG_VECTOR_BINARY_OP(T, &)\ + SLANG_VECTOR_BINARY_OP(T, |)\ + SLANG_VECTOR_BINARY_OP(T, &&)\ + SLANG_VECTOR_BINARY_OP(T, ||)\ + SLANG_VECTOR_BINARY_OP(T, ^)\ + SLANG_VECTOR_BINARY_OP(T, %)\ + SLANG_VECTOR_BINARY_OP(T, >>)\ + SLANG_VECTOR_BINARY_OP(T, <<)\ + SLANG_VECTOR_BINARY_COMPARE_OP(T, >)\ + SLANG_VECTOR_BINARY_COMPARE_OP(T, <)\ + SLANG_VECTOR_BINARY_COMPARE_OP(T, >=)\ + SLANG_VECTOR_BINARY_COMPARE_OP(T, <=)\ + SLANG_VECTOR_BINARY_COMPARE_OP(T, ==)\ + SLANG_VECTOR_BINARY_COMPARE_OP(T, !=)\ + SLANG_VECTOR_UNARY_OP(T, !)\ + SLANG_VECTOR_UNARY_OP(T, ~) +#define SLANG_FLOAT_VECTOR_OPS(T) \ + SLANG_VECTOR_BINARY_OP(T, +)\ + SLANG_VECTOR_BINARY_OP(T, -)\ + SLANG_VECTOR_BINARY_OP(T, *)\ + SLANG_VECTOR_BINARY_OP(T, /)\ + SLANG_VECTOR_UNARY_OP(T, -)\ + SLANG_VECTOR_BINARY_COMPARE_OP(T, >)\ + SLANG_VECTOR_BINARY_COMPARE_OP(T, <)\ + SLANG_VECTOR_BINARY_COMPARE_OP(T, >=)\ + SLANG_VECTOR_BINARY_COMPARE_OP(T, <=)\ + SLANG_VECTOR_BINARY_COMPARE_OP(T, ==)\ + SLANG_VECTOR_BINARY_COMPARE_OP(T, !=) + +SLANG_INT_VECTOR_OPS(bool) +SLANG_INT_VECTOR_OPS(int) +SLANG_INT_VECTOR_OPS(int8_t) +SLANG_INT_VECTOR_OPS(int16_t) +SLANG_INT_VECTOR_OPS(int64_t) +SLANG_INT_VECTOR_OPS(uint) +SLANG_INT_VECTOR_OPS(uint8_t) +SLANG_INT_VECTOR_OPS(uint16_t) +SLANG_INT_VECTOR_OPS(uint64_t) + +SLANG_FLOAT_VECTOR_OPS(float) +SLANG_FLOAT_VECTOR_OPS(double) + +#define SLANG_VECTOR_INT_NEG_OP(T) \ + template<int N>\ + Vector<T, N> operator-(const Vector<T, N>& thisVal) \ + { \ + Vector<T, N> result;\ + for (int i = 0; i < N; i++) \ + result[i] = 0 - thisVal[i]; \ + return result;\ + } +SLANG_VECTOR_INT_NEG_OP(int) +SLANG_VECTOR_INT_NEG_OP(int8_t) +SLANG_VECTOR_INT_NEG_OP(int16_t) +SLANG_VECTOR_INT_NEG_OP(int64_t) +SLANG_VECTOR_INT_NEG_OP(uint) +SLANG_VECTOR_INT_NEG_OP(uint8_t) +SLANG_VECTOR_INT_NEG_OP(uint16_t) +SLANG_VECTOR_INT_NEG_OP(uint64_t) + +#define SLANG_FLOAT_VECTOR_MOD(T)\ + template<int N> \ + Vector<T, N> operator%(const Vector<T, N>& left, const Vector<T, N>& right) \ + {\ + Vector<T, N> result;\ + for (int i = 0; i < N; i++) \ + result[i] = _slang_fmod(left[i], right[i]); \ + return result;\ + } + +SLANG_FLOAT_VECTOR_MOD(float) +SLANG_FLOAT_VECTOR_MOD(double) +#undef SLANG_FLOAT_VECTOR_MOD +#undef SLANG_VECTOR_BINARY_OP +#undef SLANG_VECTOR_UNARY_OP +#undef SLANG_INT_VECTOR_OPS +#undef SLANG_FLOAT_VECTOR_OPS +#undef SLANG_VECTOR_INT_NEG_OP +#undef SLANG_FLOAT_VECTOR_MOD + +template <typename T, int ROWS, int COLS> +struct Matrix +{ + Vector<T, COLS> rows[ROWS]; + Vector<T, COLS>& operator[](size_t index) { return rows[index]; } + Matrix() = default; + Matrix(T scalar) + { + for (int i = 0; i < ROWS; i++) + rows[i] = Vector<T, COLS>(scalar); + } + Matrix(const Vector<T, COLS>& row0) + { + rows[0] = row0; + } + Matrix(const Vector<T, COLS>& row0, const Vector<T, COLS>& row1) + { + rows[0] = row0; + rows[1] = row1; + } + Matrix(const Vector<T, COLS>& row0, const Vector<T, COLS>& row1, const Vector<T, COLS>& row2) + { + rows[0] = row0; + rows[1] = row1; + rows[2] = row2; + } + Matrix(const Vector<T, COLS>& row0, const Vector<T, COLS>& row1, const Vector<T, COLS>& row2, const Vector<T, COLS>& row3) + { + rows[0] = row0; + rows[1] = row1; + rows[2] = row2; + rows[3] = row3; + } + template<typename U, int otherRow, int otherCol> + Matrix(const Matrix<U, otherRow, otherCol>& other) + { + int minRow = ROWS; + int minCol = COLS; + if (minRow > otherRow) minRow = otherRow; + if (minCol > otherCol) minCol = otherCol; + for (int i = 0; i < minRow; i++) + for (int j = 0; j < minCol; j++) + rows[i][j] = (T)other.rows[i][j]; + } + Matrix(T v0, T v1, T v2, T v3) + { + rows[0][0] = v0; rows[0][1] = v1; + rows[1][0] = v2; rows[1][1] = v3; + } + Matrix(T v0, T v1, T v2, T v3, T v4, T v5) + { + if (COLS == 3) + { + rows[0][0] = v0; rows[0][1] = v1; rows[0][2] = v2; + rows[1][0] = v3; rows[1][1] = v4; rows[1][2] = v5; + } + else + { + rows[0][0] = v0; rows[0][1] = v1; + rows[1][0] = v2; rows[1][1] = v3; + rows[2][0] = v4; rows[2][1] = v5; + } + } + Matrix(T v0, T v1, T v2, T v3, T v4, T v5, T v6, T v7) + { + if (COLS == 4) + { + rows[0][0] = v0; rows[0][1] = v1; rows[0][2] = v2; rows[0][3] = v3; + rows[1][0] = v4; rows[1][1] = v5; rows[1][2] = v6; rows[1][3] = v7; + } + else + { + rows[0][0] = v0; rows[0][1] = v1; + rows[1][0] = v2; rows[1][1] = v3; + rows[2][0] = v4; rows[2][1] = v5; + rows[3][0] = v6; rows[3][1] = v7; + } + } + Matrix(T v0, T v1, T v2, T v3, T v4, T v5, T v6, T v7, T v8) + { + rows[0][0] = v0; rows[0][1] = v1; rows[0][2] = v2; + rows[1][0] = v3; rows[1][1] = v4; rows[1][2] = v5; + rows[2][0] = v6; rows[2][1] = v7; rows[2][2] = v8; + } + Matrix(T v0, T v1, T v2, T v3, T v4, T v5, T v6, T v7, T v8, T v9, T v10, T v11) + { + if (COLS == 4) + { + rows[0][0] = v0; rows[0][1] = v1; rows[0][2] = v2; rows[0][3] = v3; + rows[1][0] = v4; rows[1][1] = v5; rows[1][2] = v6; rows[1][3] = v7; + rows[2][0] = v8; rows[2][1] = v9; rows[2][2] = v10; rows[2][3] = v11; + } + else + { + rows[0][0] = v0; rows[0][1] = v1; rows[0][2] = v2; + rows[1][0] = v3; rows[1][1] = v4; rows[1][2] = v5; + rows[2][0] = v6; rows[2][1] = v7; rows[2][2] = v8; + rows[3][0] = v9; rows[3][1] = v10; rows[3][2] = v11; + } + } + Matrix(T v0, T v1, T v2, T v3, T v4, T v5, T v6, T v7, T v8, T v9, T v10, T v11, T v12, T v13, T v14, T v15) + { + rows[0][0] = v0; rows[0][1] = v1; rows[0][2] = v2; rows[0][3] = v3; + rows[1][0] = v4; rows[1][1] = v5; rows[1][2] = v6; rows[1][3] = v7; + rows[2][0] = v8; rows[2][1] = v9; rows[2][2] = v10; rows[2][3] = v11; + rows[3][0] = v12; rows[3][1] = v13; rows[3][2] = v14; rows[3][3] = v15; + } +}; + +#define SLANG_MATRIX_BINARY_OP(T, op) \ + template<int R, int C> \ + Matrix<T, R, C> operator op(const Matrix<T, R, C>& thisVal, const Matrix<T, R, C>& other) \ + { \ + Matrix<T, R, C> result;\ + for (int i = 0; i < R; i++) \ + for (int j = 0; j < C; j++) \ + result.rows[i][j] = thisVal.rows[i][j] op other.rows[i][j]; \ + return result;\ + } + +#define SLANG_MATRIX_UNARY_OP(T, op) \ + template<int R, int C> \ + Matrix<T, R, C> operator op(const Matrix<T, R, C>& thisVal) \ + { \ + Matrix<T, R, C> result;\ + for (int i = 0; i < R; i++) \ + for (int j = 0; j < C; j++) \ + result[i].rows[i][j] = op thisVal.rows[i][j]; \ + return result;\ + } +#define SLANG_INT_MATRIX_OPS(T) \ + SLANG_MATRIX_BINARY_OP(T, +)\ + SLANG_MATRIX_BINARY_OP(T, -)\ + SLANG_MATRIX_BINARY_OP(T, *)\ + SLANG_MATRIX_BINARY_OP(T, / )\ + SLANG_MATRIX_BINARY_OP(T, &)\ + SLANG_MATRIX_BINARY_OP(T, |)\ + SLANG_MATRIX_BINARY_OP(T, &&)\ + SLANG_MATRIX_BINARY_OP(T, ||)\ + SLANG_MATRIX_BINARY_OP(T, ^)\ + SLANG_MATRIX_BINARY_OP(T, %)\ + SLANG_MATRIX_UNARY_OP(T, !)\ + SLANG_MATRIX_UNARY_OP(T, ~) +#define SLANG_FLOAT_MATRIX_OPS(T) \ + SLANG_MATRIX_BINARY_OP(T, +)\ + SLANG_MATRIX_BINARY_OP(T, -)\ + SLANG_MATRIX_BINARY_OP(T, *)\ + SLANG_MATRIX_BINARY_OP(T, /)\ + SLANG_MATRIX_UNARY_OP(T, -) +SLANG_INT_MATRIX_OPS(int) +SLANG_INT_MATRIX_OPS(int8_t) +SLANG_INT_MATRIX_OPS(int16_t) +SLANG_INT_MATRIX_OPS(int64_t) +SLANG_INT_MATRIX_OPS(uint) +SLANG_INT_MATRIX_OPS(uint8_t) +SLANG_INT_MATRIX_OPS(uint16_t) +SLANG_INT_MATRIX_OPS(uint64_t) + +SLANG_FLOAT_MATRIX_OPS(float) +SLANG_FLOAT_MATRIX_OPS(double) + +#define SLANG_MATRIX_INT_NEG_OP(T) \ + template<int R, int C>\ + SLANG_FORCE_INLINE Matrix<T, R, C> operator-(Matrix<T, R, C> thisVal) \ + { \ + Matrix<T, R, C> result;\ + for (int i = 0; i < R; i++) \ + for (int j = 0; j < C; j++) \ + result.rows[i][j] = 0 - thisVal.rows[i][j]; \ + return result;\ + } + SLANG_MATRIX_INT_NEG_OP(int) + SLANG_MATRIX_INT_NEG_OP(int8_t) + SLANG_MATRIX_INT_NEG_OP(int16_t) + SLANG_MATRIX_INT_NEG_OP(int64_t) + SLANG_MATRIX_INT_NEG_OP(uint) + SLANG_MATRIX_INT_NEG_OP(uint8_t) + SLANG_MATRIX_INT_NEG_OP(uint16_t) + SLANG_MATRIX_INT_NEG_OP(uint64_t) + +#define SLANG_FLOAT_MATRIX_MOD(T)\ + template<int R, int C> \ + SLANG_FORCE_INLINE Matrix<T, R, C> operator%(Matrix<T, R, C> left, Matrix<T, R, C> right) \ + {\ + Matrix<T, R, C> result;\ + for (int i = 0; i < R; i++) \ + for (int j = 0; j < C; j++) \ + result.rows[i][j] = _slang_fmod(left.rows[i][j], right.rows[i][j]); \ + return result;\ + } + + SLANG_FLOAT_MATRIX_MOD(float) + SLANG_FLOAT_MATRIX_MOD(double) +#undef SLANG_FLOAT_MATRIX_MOD +#undef SLANG_MATRIX_BINARY_OP +#undef SLANG_MATRIX_UNARY_OP +#undef SLANG_INT_MATRIX_OPS +#undef SLANG_FLOAT_MATRIX_OPS +#undef SLANG_MATRIX_INT_NEG_OP +#undef SLANG_FLOAT_MATRIX_MOD + +template<typename TResult, typename TInput> +TResult slang_bit_cast(TInput val) +{ + return *(TResult*)(&val); +} + +#endif + + diff --git a/prelude/slang-cpp-types.h b/prelude/slang-cpp-types.h index 28fe3dd8d..ac66ad9f3 100644 --- a/prelude/slang-cpp-types.h +++ b/prelude/slang-cpp-types.h @@ -1,244 +1,12 @@ #ifndef SLANG_PRELUDE_CPP_TYPES_H #define SLANG_PRELUDE_CPP_TYPES_H -#ifndef SLANG_PRELUDE_ASSERT -# ifdef SLANG_PRELUDE_ENABLE_ASSERT -# define SLANG_PRELUDE_ASSERT(VALUE) assert(VALUE) -# else -# define SLANG_PRELUDE_ASSERT(VALUE) -# endif -#endif - -// Since we are using unsigned arithmatic care is need in this comparison. -// It is *assumed* that sizeInBytes >= elemSize. Which means (sizeInBytes >= elemSize) >= 0 -// Which means only a single test is needed - -// Asserts for bounds checking. -// It is assumed index/count are unsigned types. -#define SLANG_BOUND_ASSERT(index, count) SLANG_PRELUDE_ASSERT(index < count); -#define SLANG_BOUND_ASSERT_BYTE_ADDRESS(index, elemSize, sizeInBytes) SLANG_PRELUDE_ASSERT(index <= (sizeInBytes - elemSize) && (index & 3) == 0); - -// Macros to zero index if an access is out of range -#define SLANG_BOUND_ZERO_INDEX(index, count) index = (index < count) ? index : 0; -#define SLANG_BOUND_ZERO_INDEX_BYTE_ADDRESS(index, elemSize, sizeInBytes) index = (index <= (sizeInBytes - elemSize)) ? index : 0; - -// The 'FIX' macro define how the index is fixed. The default is to do nothing. If SLANG_ENABLE_BOUND_ZERO_INDEX -// the fix macro will zero the index, if out of range -#ifdef SLANG_ENABLE_BOUND_ZERO_INDEX -# define SLANG_BOUND_FIX(index, count) SLANG_BOUND_ZERO_INDEX(index, count) -# define SLANG_BOUND_FIX_BYTE_ADDRESS(index, elemSize, sizeInBytes) SLANG_BOUND_ZERO_INDEX_BYTE_ADDRESS(index, elemSize, sizeInBytes) -# define SLANG_BOUND_FIX_FIXED_ARRAY(index, count) SLANG_BOUND_ZERO_INDEX(index, count) -#else -# define SLANG_BOUND_FIX(index, count) -# define SLANG_BOUND_FIX_BYTE_ADDRESS(index, elemSize, sizeInBytes) -# define SLANG_BOUND_FIX_FIXED_ARRAY(index, count) -#endif - -#ifndef SLANG_BOUND_CHECK -# define SLANG_BOUND_CHECK(index, count) SLANG_BOUND_ASSERT(index, count) SLANG_BOUND_FIX(index, count) -#endif - -#ifndef SLANG_BOUND_CHECK_BYTE_ADDRESS -# define SLANG_BOUND_CHECK_BYTE_ADDRESS(index, elemSize, sizeInBytes) SLANG_BOUND_ASSERT_BYTE_ADDRESS(index, elemSize, sizeInBytes) SLANG_BOUND_FIX_BYTE_ADDRESS(index, elemSize, sizeInBytes) -#endif - -#ifndef SLANG_BOUND_CHECK_FIXED_ARRAY -# define SLANG_BOUND_CHECK_FIXED_ARRAY(index, count) SLANG_BOUND_ASSERT(index, count) SLANG_BOUND_FIX_FIXED_ARRAY(index, count) -#endif - #ifdef SLANG_PRELUDE_NAMESPACE namespace SLANG_PRELUDE_NAMESPACE { #endif -struct TypeInfo -{ - size_t typeSize; -}; - -template <typename T, size_t SIZE> -struct FixedArray -{ - const T& operator[](size_t index) const { SLANG_BOUND_CHECK_FIXED_ARRAY(index, SIZE); return m_data[index]; } - T& operator[](size_t index) { SLANG_BOUND_CHECK_FIXED_ARRAY(index, SIZE); return m_data[index]; } - - T m_data[SIZE]; -}; - -// An array that has no specified size, becomes a 'Array'. This stores the size so it can potentially -// do bounds checking. -template <typename T> -struct Array -{ - const T& operator[](size_t index) const { SLANG_BOUND_CHECK(index, count); return data[index]; } - T& operator[](size_t index) { SLANG_BOUND_CHECK(index, count); return data[index]; } - - T* data; - size_t count; -}; - -/* Constant buffers become a pointer to the contained type, so ConstantBuffer<T> becomes T* in C++ code. -*/ - -template <typename T, int COUNT> -struct Vector; - -template <typename T> -struct Vector<T, 1> -{ - T x; - const T& operator[](size_t /*index*/) const { return x; } - T& operator[](size_t /*index*/) { return x; } - operator T() const { return x; } - Vector() = default; - Vector(T scalar) - { - x = scalar; - } - template <typename U> - Vector(Vector<U, 1> other) - { - x = (T)other.x; - } - template <typename U, int otherSize> - Vector(Vector<U, otherSize> other) - { - int minSize = 1; - if (otherSize < minSize) minSize = otherSize; - for (int i = 0; i < minSize; i++) - (*this)[i] = (T)other[i]; - } -}; - -template <typename T> -struct Vector<T, 2> -{ - T x, y; - const T& operator[](size_t index) const { return index == 0 ? x : y; } - T& operator[](size_t index) { return index == 0 ? x : y; } - Vector() = default; - Vector(T scalar) - { - x = y = scalar; - } - Vector(T _x, T _y) - { - x = _x; - y = _y; - } - template <typename U> - Vector(Vector<U, 2> other) - { - x = (T)other.x; - y = (T)other.y; - } - template <typename U, int otherSize> - Vector(Vector<U, otherSize> other) - { - int minSize = 2; - if (otherSize < minSize) minSize = otherSize; - for (int i = 0; i < minSize; i++) - (*this)[i] = (T)other[i]; - } -}; - -template <typename T> -struct Vector<T, 3> -{ - T x, y, z; - const T& operator[](size_t index) const { return *((T*)(this) + index); } - T& operator[](size_t index) { return *((T*)(this) + index); } - - Vector() = default; - Vector(T scalar) - { - x = y = z = scalar; - } - Vector(T _x, T _y, T _z) - { - x = _x; - y = _y; - z = _z; - } - template <typename U> - Vector(Vector<U, 3> other) - { - x = (T)other.x; - y = (T)other.y; - z = (T)other.z; - } - template <typename U, int otherSize> - Vector(Vector<U, otherSize> other) - { - int minSize = 3; - if (otherSize < minSize) minSize = otherSize; - for (int i = 0; i < minSize; i++) - (*this)[i] = (T)other[i]; - } -}; -template <typename T> -struct Vector<T, 4> -{ - T x, y, z, w; - - const T& operator[](size_t index) const { return *((T*)(this) + index); } - T& operator[](size_t index) { return *((T*)(this) + index); } - Vector() = default; - Vector(T scalar) - { - x = y = z = w = scalar; - } - Vector(T _x, T _y, T _z, T _w) - { - x = _x; - y = _y; - z = _z; - w = _w; - } - template <typename U, int otherSize> - Vector(Vector<U, otherSize> other) - { - int minSize = 4; - if (otherSize < minSize) minSize = otherSize; - for (int i = 0; i < minSize; i++) - (*this)[i] = (T)other[i]; - } - -}; - -template<typename T, int N> -SLANG_FORCE_INLINE T _slang_vector_get_element(Vector<T, N> x, int index) -{ - return x[index]; -} - -template<typename T, int N> -SLANG_FORCE_INLINE const T* _slang_vector_get_element_ptr(const Vector<T, N>* x, int index) -{ - return &((*const_cast<Vector<T,N>*>(x))[index]); -} - -template<typename T, int N> -SLANG_FORCE_INLINE T* _slang_vector_get_element_ptr(Vector<T, N>* x, int index) -{ - return &((*x)[index]); -} - -template<typename T, int n, typename OtherT, int m> -SLANG_FORCE_INLINE Vector<T, n> _slang_vector_reshape(const Vector<OtherT, m> other) -{ - Vector<T, n> result; - for (int i = 0; i < n; i++) - { - OtherT otherElement = T(0); - if (i < m) - otherElement = _slang_vector_get_element(other, i); - *_slang_vector_get_element_ptr(&result, i) = (T)otherElement; - } - return result; -} - -typedef uint32_t uint; +#include "slang-cpp-types-core.h" typedef Vector<float, 2> float2; typedef Vector<float, 3> float3; @@ -252,320 +20,6 @@ typedef Vector<uint32_t, 2> uint2; typedef Vector<uint32_t, 3> uint3; typedef Vector<uint32_t, 4> uint4; -#define SLANG_VECTOR_BINARY_OP(T, op) \ - template<int n> \ - SLANG_FORCE_INLINE Vector<T, n> operator op(const Vector<T, n>& thisVal, const Vector<T, n>& other) \ - { \ - Vector<T, n> result;\ - for (int i = 0; i < n; i++) \ - result[i] = thisVal[i] op other[i]; \ - return result;\ - } -#define SLANG_VECTOR_BINARY_COMPARE_OP(T, op) \ - template<int n> \ - SLANG_FORCE_INLINE Vector<bool, n> operator op(const Vector<T, n>& thisVal, const Vector<T, n>& other) \ - { \ - Vector<bool, n> result;\ - for (int i = 0; i < n; i++) \ - result[i] = thisVal[i] op other[i]; \ - return result;\ - } - -#define SLANG_VECTOR_UNARY_OP(T, op) \ - template<int n> \ - SLANG_FORCE_INLINE Vector<T, n> operator op(const Vector<T, n>& thisVal) \ - { \ - Vector<T, n> result;\ - for (int i = 0; i < n; i++) \ - result[i] = op thisVal[i]; \ - return result;\ - } -#define SLANG_INT_VECTOR_OPS(T) \ - SLANG_VECTOR_BINARY_OP(T, +)\ - SLANG_VECTOR_BINARY_OP(T, -)\ - SLANG_VECTOR_BINARY_OP(T, *)\ - SLANG_VECTOR_BINARY_OP(T, / )\ - SLANG_VECTOR_BINARY_OP(T, &)\ - SLANG_VECTOR_BINARY_OP(T, |)\ - SLANG_VECTOR_BINARY_OP(T, &&)\ - SLANG_VECTOR_BINARY_OP(T, ||)\ - SLANG_VECTOR_BINARY_OP(T, ^)\ - SLANG_VECTOR_BINARY_OP(T, %)\ - SLANG_VECTOR_BINARY_OP(T, >>)\ - SLANG_VECTOR_BINARY_OP(T, <<)\ - SLANG_VECTOR_BINARY_COMPARE_OP(T, >)\ - SLANG_VECTOR_BINARY_COMPARE_OP(T, <)\ - SLANG_VECTOR_BINARY_COMPARE_OP(T, >=)\ - SLANG_VECTOR_BINARY_COMPARE_OP(T, <=)\ - SLANG_VECTOR_BINARY_COMPARE_OP(T, ==)\ - SLANG_VECTOR_BINARY_COMPARE_OP(T, !=)\ - SLANG_VECTOR_UNARY_OP(T, !)\ - SLANG_VECTOR_UNARY_OP(T, ~) -#define SLANG_FLOAT_VECTOR_OPS(T) \ - SLANG_VECTOR_BINARY_OP(T, +)\ - SLANG_VECTOR_BINARY_OP(T, -)\ - SLANG_VECTOR_BINARY_OP(T, *)\ - SLANG_VECTOR_BINARY_OP(T, /)\ - SLANG_VECTOR_UNARY_OP(T, -)\ - SLANG_VECTOR_BINARY_COMPARE_OP(T, >)\ - SLANG_VECTOR_BINARY_COMPARE_OP(T, <)\ - SLANG_VECTOR_BINARY_COMPARE_OP(T, >=)\ - SLANG_VECTOR_BINARY_COMPARE_OP(T, <=)\ - SLANG_VECTOR_BINARY_COMPARE_OP(T, ==)\ - SLANG_VECTOR_BINARY_COMPARE_OP(T, !=) - -SLANG_INT_VECTOR_OPS(bool) -SLANG_INT_VECTOR_OPS(int) -SLANG_INT_VECTOR_OPS(int8_t) -SLANG_INT_VECTOR_OPS(int16_t) -SLANG_INT_VECTOR_OPS(int64_t) -SLANG_INT_VECTOR_OPS(uint) -SLANG_INT_VECTOR_OPS(uint8_t) -SLANG_INT_VECTOR_OPS(uint16_t) -SLANG_INT_VECTOR_OPS(uint64_t) - -SLANG_FLOAT_VECTOR_OPS(float) -SLANG_FLOAT_VECTOR_OPS(double) - -#define SLANG_VECTOR_INT_NEG_OP(T) \ - template<int N>\ - Vector<T, N> operator-(const Vector<T, N>& thisVal) \ - { \ - Vector<T, N> result;\ - for (int i = 0; i < N; i++) \ - result[i] = 0 - thisVal[i]; \ - return result;\ - } -SLANG_VECTOR_INT_NEG_OP(int) -SLANG_VECTOR_INT_NEG_OP(int8_t) -SLANG_VECTOR_INT_NEG_OP(int16_t) -SLANG_VECTOR_INT_NEG_OP(int64_t) -SLANG_VECTOR_INT_NEG_OP(uint) -SLANG_VECTOR_INT_NEG_OP(uint8_t) -SLANG_VECTOR_INT_NEG_OP(uint16_t) -SLANG_VECTOR_INT_NEG_OP(uint64_t) - -#define SLANG_FLOAT_VECTOR_MOD(T)\ - template<int N> \ - Vector<T, N> operator%(const Vector<T, N>& left, const Vector<T, N>& right) \ - {\ - Vector<T, N> result;\ - for (int i = 0; i < N; i++) \ - result[i] = _slang_fmod(left[i], right[i]); \ - return result;\ - } - -SLANG_FLOAT_VECTOR_MOD(float) -SLANG_FLOAT_VECTOR_MOD(double) -#undef SLANG_FLOAT_VECTOR_MOD -#undef SLANG_VECTOR_BINARY_OP -#undef SLANG_VECTOR_UNARY_OP -#undef SLANG_INT_VECTOR_OPS -#undef SLANG_FLOAT_VECTOR_OPS -#undef SLANG_VECTOR_INT_NEG_OP -#undef SLANG_FLOAT_VECTOR_MOD - -template <typename T, int ROWS, int COLS> -struct Matrix -{ - Vector<T, COLS> rows[ROWS]; - Vector<T, COLS>& operator[](size_t index) { return rows[index]; } - Matrix() = default; - Matrix(T scalar) - { - for (int i = 0; i < ROWS; i++) - rows[i] = Vector<T, COLS>(scalar); - } - Matrix(const Vector<T, COLS>& row0) - { - rows[0] = row0; - } - Matrix(const Vector<T, COLS>& row0, const Vector<T, COLS>& row1) - { - rows[0] = row0; - rows[1] = row1; - } - Matrix(const Vector<T, COLS>& row0, const Vector<T, COLS>& row1, const Vector<T, COLS>& row2) - { - rows[0] = row0; - rows[1] = row1; - rows[2] = row2; - } - Matrix(const Vector<T, COLS>& row0, const Vector<T, COLS>& row1, const Vector<T, COLS>& row2, const Vector<T, COLS>& row3) - { - rows[0] = row0; - rows[1] = row1; - rows[2] = row2; - rows[3] = row3; - } - template<typename U, int otherRow, int otherCol> - Matrix(const Matrix<U, otherRow, otherCol>& other) - { - int minRow = ROWS; - int minCol = COLS; - if (minRow > otherRow) minRow = otherRow; - if (minCol > otherCol) minCol = otherCol; - for (int i = 0; i < minRow; i++) - for (int j = 0; j < minCol; j++) - rows[i][j] = (T)other.rows[i][j]; - } - Matrix(T v0, T v1, T v2, T v3) - { - rows[0][0] = v0; rows[0][1] = v1; - rows[1][0] = v2; rows[1][1] = v3; - } - Matrix(T v0, T v1, T v2, T v3, T v4, T v5) - { - if (COLS == 3) - { - rows[0][0] = v0; rows[0][1] = v1; rows[0][2] = v2; - rows[1][0] = v3; rows[1][1] = v4; rows[1][2] = v5; - } - else - { - rows[0][0] = v0; rows[0][1] = v1; - rows[1][0] = v2; rows[1][1] = v3; - rows[2][0] = v4; rows[2][1] = v5; - } - } - Matrix(T v0, T v1, T v2, T v3, T v4, T v5, T v6, T v7) - { - if (COLS == 4) - { - rows[0][0] = v0; rows[0][1] = v1; rows[0][2] = v2; rows[0][3] = v3; - rows[1][0] = v4; rows[1][1] = v5; rows[1][2] = v6; rows[1][3] = v7; - } - else - { - rows[0][0] = v0; rows[0][1] = v1; - rows[1][0] = v2; rows[1][1] = v3; - rows[2][0] = v4; rows[2][1] = v5; - rows[3][0] = v6; rows[3][1] = v7; - } - } - Matrix(T v0, T v1, T v2, T v3, T v4, T v5, T v6, T v7, T v8) - { - rows[0][0] = v0; rows[0][1] = v1; rows[0][2] = v2; - rows[1][0] = v3; rows[1][1] = v4; rows[1][2] = v5; - rows[2][0] = v6; rows[2][1] = v7; rows[2][2] = v8; - } - Matrix(T v0, T v1, T v2, T v3, T v4, T v5, T v6, T v7, T v8, T v9, T v10, T v11) - { - if (COLS == 4) - { - rows[0][0] = v0; rows[0][1] = v1; rows[0][2] = v2; rows[0][3] = v3; - rows[1][0] = v4; rows[1][1] = v5; rows[1][2] = v6; rows[1][3] = v7; - rows[2][0] = v8; rows[2][1] = v9; rows[2][2] = v10; rows[2][3] = v11; - } - else - { - rows[0][0] = v0; rows[0][1] = v1; rows[0][2] = v2; - rows[1][0] = v3; rows[1][1] = v4; rows[1][2] = v5; - rows[2][0] = v6; rows[2][1] = v7; rows[2][2] = v8; - rows[3][0] = v9; rows[3][1] = v10; rows[3][2] = v11; - } - } - Matrix(T v0, T v1, T v2, T v3, T v4, T v5, T v6, T v7, T v8, T v9, T v10, T v11, T v12, T v13, T v14, T v15) - { - rows[0][0] = v0; rows[0][1] = v1; rows[0][2] = v2; rows[0][3] = v3; - rows[1][0] = v4; rows[1][1] = v5; rows[1][2] = v6; rows[1][3] = v7; - rows[2][0] = v8; rows[2][1] = v9; rows[2][2] = v10; rows[2][3] = v11; - rows[3][0] = v12; rows[3][1] = v13; rows[3][2] = v14; rows[3][3] = v15; - } -}; - -#define SLANG_MATRIX_BINARY_OP(T, op) \ - template<int R, int C> \ - Matrix<T, R, C> operator op(const Matrix<T, R, C>& thisVal, const Matrix<T, R, C>& other) \ - { \ - Matrix<T, R, C> result;\ - for (int i = 0; i < R; i++) \ - for (int j = 0; j < C; j++) \ - result.rows[i][j] = thisVal.rows[i][j] op other.rows[i][j]; \ - return result;\ - } - -#define SLANG_MATRIX_UNARY_OP(T, op) \ - template<int R, int C> \ - Matrix<T, R, C> operator op(const Matrix<T, R, C>& thisVal) \ - { \ - Matrix<T, R, C> result;\ - for (int i = 0; i < R; i++) \ - for (int j = 0; j < C; j++) \ - result[i].rows[i][j] = op thisVal.rows[i][j]; \ - return result;\ - } -#define SLANG_INT_MATRIX_OPS(T) \ - SLANG_MATRIX_BINARY_OP(T, +)\ - SLANG_MATRIX_BINARY_OP(T, -)\ - SLANG_MATRIX_BINARY_OP(T, *)\ - SLANG_MATRIX_BINARY_OP(T, / )\ - SLANG_MATRIX_BINARY_OP(T, &)\ - SLANG_MATRIX_BINARY_OP(T, |)\ - SLANG_MATRIX_BINARY_OP(T, &&)\ - SLANG_MATRIX_BINARY_OP(T, ||)\ - SLANG_MATRIX_BINARY_OP(T, ^)\ - SLANG_MATRIX_BINARY_OP(T, %)\ - SLANG_MATRIX_UNARY_OP(T, !)\ - SLANG_MATRIX_UNARY_OP(T, ~) -#define SLANG_FLOAT_MATRIX_OPS(T) \ - SLANG_MATRIX_BINARY_OP(T, +)\ - SLANG_MATRIX_BINARY_OP(T, -)\ - SLANG_MATRIX_BINARY_OP(T, *)\ - SLANG_MATRIX_BINARY_OP(T, /)\ - SLANG_MATRIX_UNARY_OP(T, -) -SLANG_INT_MATRIX_OPS(int) -SLANG_INT_MATRIX_OPS(int8_t) -SLANG_INT_MATRIX_OPS(int16_t) -SLANG_INT_MATRIX_OPS(int64_t) -SLANG_INT_MATRIX_OPS(uint) -SLANG_INT_MATRIX_OPS(uint8_t) -SLANG_INT_MATRIX_OPS(uint16_t) -SLANG_INT_MATRIX_OPS(uint64_t) - -SLANG_FLOAT_MATRIX_OPS(float) -SLANG_FLOAT_MATRIX_OPS(double) - -#define SLANG_MATRIX_INT_NEG_OP(T) \ - template<int R, int C>\ - SLANG_FORCE_INLINE Matrix<T, R, C> operator-(Matrix<T, R, C> thisVal) \ - { \ - Matrix<T, R, C> result;\ - for (int i = 0; i < R; i++) \ - for (int j = 0; j < C; j++) \ - result.rows[i][j] = 0 - thisVal.rows[i][j]; \ - return result;\ - } - SLANG_MATRIX_INT_NEG_OP(int) - SLANG_MATRIX_INT_NEG_OP(int8_t) - SLANG_MATRIX_INT_NEG_OP(int16_t) - SLANG_MATRIX_INT_NEG_OP(int64_t) - SLANG_MATRIX_INT_NEG_OP(uint) - SLANG_MATRIX_INT_NEG_OP(uint8_t) - SLANG_MATRIX_INT_NEG_OP(uint16_t) - SLANG_MATRIX_INT_NEG_OP(uint64_t) - -#define SLANG_FLOAT_MATRIX_MOD(T)\ - template<int R, int C> \ - SLANG_FORCE_INLINE Matrix<T, R, C> operator%(Matrix<T, R, C> left, Matrix<T, R, C> right) \ - {\ - Matrix<T, R, C> result;\ - for (int i = 0; i < R; i++) \ - for (int j = 0; j < C; j++) \ - result.rows[i][j] = _slang_fmod(left.rows[i][j], right.rows[i][j]); \ - return result;\ - } - - SLANG_FLOAT_MATRIX_MOD(float) - SLANG_FLOAT_MATRIX_MOD(double) -#undef SLANG_FLOAT_MATRIX_MOD -#undef SLANG_MATRIX_BINARY_OP -#undef SLANG_MATRIX_UNARY_OP -#undef SLANG_INT_MATRIX_OPS -#undef SLANG_FLOAT_MATRIX_OPS -#undef SLANG_MATRIX_INT_NEG_OP -#undef SLANG_FLOAT_MATRIX_MOD - // We can just map `NonUniformResourceIndex` type directly to the index type on CPU, as CPU does not require // any special handling around such accesses. typedef size_t NonUniformResourceIndex; @@ -1484,12 +938,6 @@ struct ComputeVaryingInput typedef void(*ComputeThreadFunc)(ComputeThreadVaryingInput* varyingInput, void* uniformEntryPointParams, void* uniformState); typedef void(*ComputeFunc)(ComputeVaryingInput* varyingInput, void* uniformEntryPointParams, void* uniformState); -template<typename TResult, typename TInput> -TResult slang_bit_cast(TInput val) -{ - return *(TResult*)(&val); -} - #ifdef SLANG_PRELUDE_NAMESPACE } #endif diff --git a/prelude/slang-cuda-prelude.h b/prelude/slang-cuda-prelude.h index 7a4c5a918..9a55aed57 100644 --- a/prelude/slang-cuda-prelude.h +++ b/prelude/slang-cuda-prelude.h @@ -6,6 +6,8 @@ #define SLANG_CUDA_RTC 0 #endif +#include <stdio.h> + // Define SLANG_CUDA_ENABLE_HALF to use the cuda_fp16 include to add half support. // For this to work NVRTC needs to have the path to the CUDA SDK. // @@ -2080,4 +2082,73 @@ __forceinline__ __device__ void *traceOptiXRay( r0, r1 ); } + #endif + + +// TensorView +struct TensorView +{ + uint8_t* data; + uint32_t* strides; + uint32_t* sizes; + uint32_t dimensionCount; + + template<typename T> + __device__ T* data_ptr() + { + return reinterpret_cast<T*>(data); + } + + template<typename T> + __device__ T load(uint32_t x) + { + return *reinterpret_cast<T*>(data + strides[0] * x); + } + template<typename T> + __device__ T load(uint32_t x, uint32_t y) + { + return *reinterpret_cast<T*>(data + strides[0] * x + strides[1] * y); + } + template<typename T> + __device__ T load(uint32_t x, uint32_t y, uint32_t z) + { + return *reinterpret_cast<T*>(data + strides[0] * x + strides[1] * y + strides[2] * z); + } + template<typename T> + __device__ T load(uint32_t x, uint32_t y, uint32_t z, uint32_t w) + { + return *reinterpret_cast<T*>(data + strides[0] * x + strides[1] * y + strides[2] * z + strides[3] * w); + } + template<typename T> + __device__ T load(uint32_t i0, uint32_t i1, uint32_t i2, uint32_t i3, uint32_t i4) + { + return *reinterpret_cast<T*>(data + strides[0] * i0 + strides[1] * i1 + strides[2] * i2 + strides[3] * i3 + strides[4] * i4); + } + template<typename T> + __device__ void store(uint32_t x, T val) + { + *reinterpret_cast<T*>(data + strides[0] * x) = val; + } + template<typename T> + __device__ void store(uint32_t x, uint32_t y, T val) + { + *reinterpret_cast<T*>(data + strides[0] * x + strides[1] * y) = val; + } + template<typename T> + __device__ void store(uint32_t x, uint32_t y, uint32_t z, T val) + { + *reinterpret_cast<T*>(data + strides[0] * x + strides[1] * y + strides[2] * z) = val; + } + template<typename T> + __device__ void store(uint32_t x, uint32_t y, uint32_t z, uint32_t w, T val) + { + *reinterpret_cast<T*>( + data + strides[0] * x + strides[1] * y + strides[2] * z + strides[3] * w) = val; + } + template<typename T> + __device__ void store(uint32_t i0, uint32_t i1, uint32_t i2, uint32_t i3, uint32_t i4, T val) + { + *reinterpret_cast<T*>(data + strides[0] * i0 + strides[1] * i1 + strides[2] * i2 + strides[3] * i3 + strides[4] * i4) = val; + } +}; diff --git a/prelude/slang-torch-prelude.h b/prelude/slang-torch-prelude.h new file mode 100644 index 000000000..f2accc149 --- /dev/null +++ b/prelude/slang-torch-prelude.h @@ -0,0 +1,126 @@ +// Prelude for PyTorch cpp binding. + +#include <torch/extension.h> +#include <ATen/cuda/CUDAContext.h> +#include <ATen/cuda/CUDAUtils.h> +#include <vector> + +#ifndef SLANG_NO_THROW +# define SLANG_NO_THROW +#endif + +#ifndef SLANG_STDCALL +# define SLANG_STDCALL +#endif +#ifndef SLANG_MCALL +# define SLANG_MCALL SLANG_STDCALL +#endif +#ifndef SLANG_FORCE_INLINE +# define SLANG_FORCE_INLINE inline +#endif + +#ifdef SLANG_LLVM +#include "slang-llvm.h" +#else // SLANG_LLVM +# if SLANG_GCC_FAMILY && __GNUC__ < 6 +# include <cmath> +# define SLANG_PRELUDE_STD std:: +# else +# include <math.h> +# define SLANG_PRELUDE_STD +# endif + +# include <assert.h> +# include <stdlib.h> +# include <string.h> +# include <stdint.h> +#endif // SLANG_LLVM + +#include "slang-cpp-types-core.h" +#include "slang-cpp-scalar-intrinsics.h" + +struct TensorView +{ + uint8_t* data; + uint32_t* strides; + uint32_t* sizes; + uint32_t dimensionCount; +}; + +struct CudaTaskMemoryAllocator +{ + std::vector<void*> allocations; + + uint32_t* allocUIntArray(uint32_t size) + { + void* ptr = nullptr; + cudaMallocManaged(&ptr, size * sizeof(uint32_t)); + AT_CUDA_CHECK(cudaGetLastError()); + return (uint32_t*)ptr; + } + + ~CudaTaskMemoryAllocator() + { + for (auto ptr : allocations) + cudaFree(ptr); + } +}; + +TensorView make_tensor_view(CudaTaskMemoryAllocator* allocator, torch::Tensor val) +{ + val = val.to(torch::kCUDA); + TensorView res = {}; + res.dimensionCount = val.dim(); + res.strides = allocator->allocUIntArray(val.dim()); + res.sizes = allocator->allocUIntArray(val.dim()); + res.data = nullptr; + size_t elementSize = 4; + switch (val.scalar_type()) + { + case torch::kInt8: + case torch::kUInt8: + elementSize = 1; + res.data = (uint8_t*)val.data_ptr<uint8_t>(); + break; + case torch::kBFloat16: + elementSize = 2; + res.data = (uint8_t*)val.data_ptr<torch::BFloat16>(); + break; + case torch::kInt16: + elementSize = 2; + res.data = (uint8_t*)val.data_ptr<int16_t>(); + break; + case torch::kFloat32: + elementSize = 4; + res.data = (uint8_t*)val.data_ptr<float>(); + break; + case torch::kInt32: + elementSize = 4; + res.data = (uint8_t*)val.data_ptr<int32_t>(); + break; + case torch::kFloat64: + elementSize = 8; + res.data = (uint8_t*)val.data_ptr<double>(); + break; + case torch::kInt64: + elementSize = 8; + res.data = (uint8_t*)val.data_ptr<int64_t>(); + break; + } + for (int i = 0; i < val.dim(); ++i) + { + res.strides[i] = val.stride(i) * elementSize; + res.sizes[i] = val.size(i); + } + return res; +} + +size_t slangGetCudaKernelSharedMemSize(const void* func) +{ + cudaFuncAttributes attr = {}; + cudaFuncGetAttributes(&attr, func); + AT_CUDA_CHECK(cudaGetLastError()); + return attr.sharedSizeBytes; +} + +#define SLANG_PRELUDE_EXPORT |
