diff options
Diffstat (limited to 'source/slang/diff.meta.slang')
| -rw-r--r-- | source/slang/diff.meta.slang | 18 |
1 files changed, 17 insertions, 1 deletions
diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang index 0725103da..f2fd8e3b0 100644 --- a/source/slang/diff.meta.slang +++ b/source/slang/diff.meta.slang @@ -246,52 +246,66 @@ __intrinsic_type($(kIROp_TorchTensorType)) struct TorchTensor { __intrinsic_op($(kIROp_TorchTensorGetView)) + [CudaHost] TensorView<T> getView(); __target_intrinsic(cuda, "$0.dims()") __target_intrinsic(cpp, "$0.dims()") [__readNone] + [CudaHost] uint dims(); __target_intrinsic(cuda, "$0.size($1)") __target_intrinsic(cpp, "$0.size($1)") [__readNone] + [CudaHost] uint size(uint i); __target_intrinsic(cuda, "$0.stride($1)") __target_intrinsic(cpp, "$0.stride($1)") [__readNone] + [CudaHost] uint stride(uint i); __target_intrinsic(cuda, "$0.data_ptr<$G0>()") __target_intrinsic(cpp, "$0.data_ptr<$G0>()") [__readNone] + [CudaHost] Ptr<T> data_ptr(); __intrinsic_op($(kIROp_AllocateTorchTensor)) + [CudaHost] static TorchTensor<T> alloc(uint x); __intrinsic_op($(kIROp_AllocateTorchTensor)) + [CudaHost] static TorchTensor<T> alloc(uint x, uint y); __intrinsic_op($(kIROp_AllocateTorchTensor)) + [CudaHost] static TorchTensor<T> alloc(uint x, uint y, uint z); __intrinsic_op($(kIROp_AllocateTorchTensor)) + [CudaHost] static TorchTensor<T> alloc(uint x, uint y, uint z, uint w); __intrinsic_op($(kIROp_AllocateTorchTensor)) + [CudaHost] static TorchTensor<T> alloc(uint i0, uint i1, uint i2, uint i3, uint i4); __intrinsic_op($(kIROp_AllocateTorchTensor)) + [CudaHost] static TorchTensor<T> emptyLike(TorchTensor<T> other); __target_intrinsic(cpp, "$0.zero_()") + [CudaHost] void fillZero(); __target_intrinsic(cpp, "$0.fill_($1)") + [CudaHost] void fillValue(T val); + [CudaHost] static TorchTensor<T> zerosLike(TorchTensor<T> other) { var result = emptyLike(other); @@ -854,8 +868,10 @@ T detach(T x); #define SLANG_SQR(x) ((x)*(x)) +#define SLANG_SIGN(x) select(((x)>T(0.0)), ReturnType(T(1.0)), select(((x)==T(0.0)), ReturnType(T(0.0)), ReturnType(T(-1.0)))) + // Absolute value -UNARY_DERIVATIVE_IMPL(abs, select(dpx.p > T(0.0), dpx.d, ReturnType.dmul(T(-1.0), dpx.d)), (ReturnType.dmul(__slang_noop_cast<ReturnType>(sign(dpx.p)), dOut))) +UNARY_DERIVATIVE_IMPL(abs, ReturnType.dmul(SLANG_SIGN(dpx.p), dpx.d), ReturnType.dmul(SLANG_SIGN(dpx.p), dOut)) // Saturate UNARY_DERIVATIVE_IMPL(saturate, select(dpx.p < T(0.0) || dpx.p > T(1.0), ReturnType.dzero(), dpx.d), select(dpx.p < T(0.0) || dpx.p > T(1.0), ReturnType.dzero(), dOut)) // frac |
