summaryrefslogtreecommitdiffstats
path: root/prelude
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-05-09 18:00:48 -0700
committerGitHub <noreply@github.com>2023-05-09 18:00:48 -0700
commitddebd60853b3f34bfd8e89de804fd15808abf75d (patch)
treed5d686843bc2c67e493693376a0170857998c077 /prelude
parent38ed03a7203baacf36fca62539ac74fd45ed42d2 (diff)
Various fixes for autodiff and slangpy. (#2876)
* Various fixes for autodiff and slangpy. * Fix cuda code gen for `select`. * Fix getBuildTagString(). * Fix. --------- Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'prelude')
-rw-r--r--prelude/slang-cpp-host-prelude.h16
-rw-r--r--prelude/slang-cpp-types-core.h17
-rw-r--r--prelude/slang-cuda-prelude.h32
-rw-r--r--prelude/slang-torch-prelude.h17
4 files changed, 82 insertions, 0 deletions
diff --git a/prelude/slang-cpp-host-prelude.h b/prelude/slang-cpp-host-prelude.h
index 0a026471c..f69d03eed 100644
--- a/prelude/slang-cpp-host-prelude.h
+++ b/prelude/slang-cpp-host-prelude.h
@@ -28,6 +28,22 @@
# include <stdint.h>
#endif // SLANG_LLVM
+#if defined(_MSC_VER)
+# define SLANG_PRELUDE_SHARED_LIB_EXPORT __declspec(dllexport)
+#else
+# define SLANG_PRELUDE_SHARED_LIB_EXPORT __attribute__((__visibility__("default")))
+//# define SLANG_PRELUDE_SHARED_LIB_EXPORT __attribute__ ((dllexport)) __attribute__((__visibility__("default")))
+#endif
+
+#ifdef __cplusplus
+# define SLANG_PRELUDE_EXTERN_C extern "C"
+# define SLANG_PRELUDE_EXTERN_C_START extern "C" {
+# define SLANG_PRELUDE_EXTERN_C_END }
+#else
+# define SLANG_PRELUDE_EXTERN_C
+# define SLANG_PRELUDE_EXTERN_C_START
+# define SLANG_PRELUDE_EXTERN_C_END
+#endif
#include "slang-cpp-scalar-intrinsics.h"
diff --git a/prelude/slang-cpp-types-core.h b/prelude/slang-cpp-types-core.h
index c49ee013c..25fe47202 100644
--- a/prelude/slang-cpp-types-core.h
+++ b/prelude/slang-cpp-types-core.h
@@ -203,6 +203,23 @@ struct Vector<T, 4>
};
template<typename T, int N>
+SLANG_FORCE_INLINE Vector<T, N> _slang_select(Vector<bool, N> condition, Vector<T, N> v0, Vector<T, N> v1)
+{
+ Vector<T, N> result;
+ for (int i = 0; i < N; i++)
+ {
+ result[i] = condition[i] ? v0[i] : v1[i];
+ }
+ return result;
+}
+
+template<typename T>
+SLANG_FORCE_INLINE T _slang_select(bool condition, T v0, T v1)
+{
+ return condition ? v0 : v1;
+}
+
+template<typename T, int N>
SLANG_FORCE_INLINE T _slang_vector_get_element(Vector<T, N> x, int index)
{
return x[index];
diff --git a/prelude/slang-cuda-prelude.h b/prelude/slang-cuda-prelude.h
index 0e0349bd7..0362a3a67 100644
--- a/prelude/slang-cuda-prelude.h
+++ b/prelude/slang-cuda-prelude.h
@@ -743,11 +743,43 @@ SLANG_FLOAT_MATRIX_OPS(__half)
#undef SLANG_MATRIX_INT_NEG_OP
#undef SLANG_FLOAT_MATRIX_MOD
+#define SLANG_SELECT_IMPL(T, N)\
+SLANG_FORCE_INLINE SLANG_CUDA_CALL Vector<T, N> _slang_select(bool##N condition, Vector<T, N> v0, Vector<T, N> v1) \
+{ \
+ Vector<T, N> result; \
+ for (int i = 0; i < N; i++) \
+ { \
+ *_slang_vector_get_element_ptr(&result, i) = _slang_vector_get_element(condition, i) ? _slang_vector_get_element(v0, i) : _slang_vector_get_element(v1, i); \
+ } \
+ return result; \
+}
+#define SLANG_SELECT_T(T)\
+ SLANG_SELECT_IMPL(T, 2)\
+ SLANG_SELECT_IMPL(T, 3)\
+ SLANG_SELECT_IMPL(T, 4)
+
+SLANG_SELECT_T(int)
+SLANG_SELECT_T(uint)
+SLANG_SELECT_T(short)
+SLANG_SELECT_T(ushort)
+SLANG_SELECT_T(char)
+SLANG_SELECT_T(uchar)
+SLANG_SELECT_T(float)
+SLANG_SELECT_T(double)
+
+template<typename T>
+SLANG_FORCE_INLINE SLANG_CUDA_CALL T _slang_select(bool condition, T v0, T v1)
+{
+ return condition ? v0 : v1;
+}
+
//
// Half support
//
#if SLANG_CUDA_ENABLE_HALF
+SLANG_SELECT_T(__half)
+
// Convenience functions ushort -> half
SLANG_FORCE_INLINE SLANG_CUDA_CALL __half2 __ushort_as_half(const ushort2& i) { return __halves2half2(__ushort_as_half(i.x), __ushort_as_half(i.y)); }
diff --git a/prelude/slang-torch-prelude.h b/prelude/slang-torch-prelude.h
index 2f5273e1f..0a69b8305 100644
--- a/prelude/slang-torch-prelude.h
+++ b/prelude/slang-torch-prelude.h
@@ -38,6 +38,23 @@
# include <stdint.h>
#endif // SLANG_LLVM
+#if defined(_MSC_VER)
+# define SLANG_PRELUDE_SHARED_LIB_EXPORT __declspec(dllexport)
+#else
+# define SLANG_PRELUDE_SHARED_LIB_EXPORT __attribute__((__visibility__("default")))
+//# define SLANG_PRELUDE_SHARED_LIB_EXPORT __attribute__ ((dllexport)) __attribute__((__visibility__("default")))
+#endif
+
+#ifdef __cplusplus
+# define SLANG_PRELUDE_EXTERN_C extern "C"
+# define SLANG_PRELUDE_EXTERN_C_START extern "C" {
+# define SLANG_PRELUDE_EXTERN_C_END }
+#else
+# define SLANG_PRELUDE_EXTERN_C
+# define SLANG_PRELUDE_EXTERN_C_START
+# define SLANG_PRELUDE_EXTERN_C_END
+#endif
+
#define SLANG_PRELUDE_NAMESPACE