diff options
| author | Yong He <yonghe@outlook.com> | 2023-05-09 18:00:48 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-05-09 18:00:48 -0700 |
| commit | ddebd60853b3f34bfd8e89de804fd15808abf75d (patch) | |
| tree | d5d686843bc2c67e493693376a0170857998c077 /prelude | |
| parent | 38ed03a7203baacf36fca62539ac74fd45ed42d2 (diff) | |
Various fixes for autodiff and slangpy. (#2876)
* Various fixes for autodiff and slangpy.
* Fix cuda code gen for `select`.
* Fix getBuildTagString().
* Fix.
---------
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'prelude')
| -rw-r--r-- | prelude/slang-cpp-host-prelude.h | 16 | ||||
| -rw-r--r-- | prelude/slang-cpp-types-core.h | 17 | ||||
| -rw-r--r-- | prelude/slang-cuda-prelude.h | 32 | ||||
| -rw-r--r-- | prelude/slang-torch-prelude.h | 17 |
4 files changed, 82 insertions, 0 deletions
diff --git a/prelude/slang-cpp-host-prelude.h b/prelude/slang-cpp-host-prelude.h index 0a026471c..f69d03eed 100644 --- a/prelude/slang-cpp-host-prelude.h +++ b/prelude/slang-cpp-host-prelude.h @@ -28,6 +28,22 @@ # include <stdint.h> #endif // SLANG_LLVM +#if defined(_MSC_VER) +# define SLANG_PRELUDE_SHARED_LIB_EXPORT __declspec(dllexport) +#else +# define SLANG_PRELUDE_SHARED_LIB_EXPORT __attribute__((__visibility__("default"))) +//# define SLANG_PRELUDE_SHARED_LIB_EXPORT __attribute__ ((dllexport)) __attribute__((__visibility__("default"))) +#endif + +#ifdef __cplusplus +# define SLANG_PRELUDE_EXTERN_C extern "C" +# define SLANG_PRELUDE_EXTERN_C_START extern "C" { +# define SLANG_PRELUDE_EXTERN_C_END } +#else +# define SLANG_PRELUDE_EXTERN_C +# define SLANG_PRELUDE_EXTERN_C_START +# define SLANG_PRELUDE_EXTERN_C_END +#endif #include "slang-cpp-scalar-intrinsics.h" diff --git a/prelude/slang-cpp-types-core.h b/prelude/slang-cpp-types-core.h index c49ee013c..25fe47202 100644 --- a/prelude/slang-cpp-types-core.h +++ b/prelude/slang-cpp-types-core.h @@ -203,6 +203,23 @@ struct Vector<T, 4> }; template<typename T, int N> +SLANG_FORCE_INLINE Vector<T, N> _slang_select(Vector<bool, N> condition, Vector<T, N> v0, Vector<T, N> v1) +{ + Vector<T, N> result; + for (int i = 0; i < N; i++) + { + result[i] = condition[i] ? v0[i] : v1[i]; + } + return result; +} + +template<typename T> +SLANG_FORCE_INLINE T _slang_select(bool condition, T v0, T v1) +{ + return condition ? v0 : v1; +} + +template<typename T, int N> SLANG_FORCE_INLINE T _slang_vector_get_element(Vector<T, N> x, int index) { return x[index]; diff --git a/prelude/slang-cuda-prelude.h b/prelude/slang-cuda-prelude.h index 0e0349bd7..0362a3a67 100644 --- a/prelude/slang-cuda-prelude.h +++ b/prelude/slang-cuda-prelude.h @@ -743,11 +743,43 @@ SLANG_FLOAT_MATRIX_OPS(__half) #undef SLANG_MATRIX_INT_NEG_OP #undef SLANG_FLOAT_MATRIX_MOD +#define SLANG_SELECT_IMPL(T, N)\ +SLANG_FORCE_INLINE SLANG_CUDA_CALL Vector<T, N> _slang_select(bool##N condition, Vector<T, N> v0, Vector<T, N> v1) \ +{ \ + Vector<T, N> result; \ + for (int i = 0; i < N; i++) \ + { \ + *_slang_vector_get_element_ptr(&result, i) = _slang_vector_get_element(condition, i) ? _slang_vector_get_element(v0, i) : _slang_vector_get_element(v1, i); \ + } \ + return result; \ +} +#define SLANG_SELECT_T(T)\ + SLANG_SELECT_IMPL(T, 2)\ + SLANG_SELECT_IMPL(T, 3)\ + SLANG_SELECT_IMPL(T, 4) + +SLANG_SELECT_T(int) +SLANG_SELECT_T(uint) +SLANG_SELECT_T(short) +SLANG_SELECT_T(ushort) +SLANG_SELECT_T(char) +SLANG_SELECT_T(uchar) +SLANG_SELECT_T(float) +SLANG_SELECT_T(double) + +template<typename T> +SLANG_FORCE_INLINE SLANG_CUDA_CALL T _slang_select(bool condition, T v0, T v1) +{ + return condition ? v0 : v1; +} + // // Half support // #if SLANG_CUDA_ENABLE_HALF +SLANG_SELECT_T(__half) + // Convenience functions ushort -> half SLANG_FORCE_INLINE SLANG_CUDA_CALL __half2 __ushort_as_half(const ushort2& i) { return __halves2half2(__ushort_as_half(i.x), __ushort_as_half(i.y)); } diff --git a/prelude/slang-torch-prelude.h b/prelude/slang-torch-prelude.h index 2f5273e1f..0a69b8305 100644 --- a/prelude/slang-torch-prelude.h +++ b/prelude/slang-torch-prelude.h @@ -38,6 +38,23 @@ # include <stdint.h> #endif // SLANG_LLVM +#if defined(_MSC_VER) +# define SLANG_PRELUDE_SHARED_LIB_EXPORT __declspec(dllexport) +#else +# define SLANG_PRELUDE_SHARED_LIB_EXPORT __attribute__((__visibility__("default"))) +//# define SLANG_PRELUDE_SHARED_LIB_EXPORT __attribute__ ((dllexport)) __attribute__((__visibility__("default"))) +#endif + +#ifdef __cplusplus +# define SLANG_PRELUDE_EXTERN_C extern "C" +# define SLANG_PRELUDE_EXTERN_C_START extern "C" { +# define SLANG_PRELUDE_EXTERN_C_END } +#else +# define SLANG_PRELUDE_EXTERN_C +# define SLANG_PRELUDE_EXTERN_C_START +# define SLANG_PRELUDE_EXTERN_C_END +#endif + #define SLANG_PRELUDE_NAMESPACE |
