From ddebd60853b3f34bfd8e89de804fd15808abf75d Mon Sep 17 00:00:00 2001 From: Yong He Date: Tue, 9 May 2023 18:00:48 -0700 Subject: Various fixes for autodiff and slangpy. (#2876) * Various fixes for autodiff and slangpy. * Fix cuda code gen for `select`. * Fix getBuildTagString(). * Fix. --------- Co-authored-by: Yong He --- source/slang/diff.meta.slang | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) (limited to 'source/slang/diff.meta.slang') diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang index 0725103da..f2fd8e3b0 100644 --- a/source/slang/diff.meta.slang +++ b/source/slang/diff.meta.slang @@ -246,52 +246,66 @@ __intrinsic_type($(kIROp_TorchTensorType)) struct TorchTensor { __intrinsic_op($(kIROp_TorchTensorGetView)) + [CudaHost] TensorView getView(); __target_intrinsic(cuda, "$0.dims()") __target_intrinsic(cpp, "$0.dims()") [__readNone] + [CudaHost] uint dims(); __target_intrinsic(cuda, "$0.size($1)") __target_intrinsic(cpp, "$0.size($1)") [__readNone] + [CudaHost] uint size(uint i); __target_intrinsic(cuda, "$0.stride($1)") __target_intrinsic(cpp, "$0.stride($1)") [__readNone] + [CudaHost] uint stride(uint i); __target_intrinsic(cuda, "$0.data_ptr<$G0>()") __target_intrinsic(cpp, "$0.data_ptr<$G0>()") [__readNone] + [CudaHost] Ptr data_ptr(); __intrinsic_op($(kIROp_AllocateTorchTensor)) + [CudaHost] static TorchTensor alloc(uint x); __intrinsic_op($(kIROp_AllocateTorchTensor)) + [CudaHost] static TorchTensor alloc(uint x, uint y); __intrinsic_op($(kIROp_AllocateTorchTensor)) + [CudaHost] static TorchTensor alloc(uint x, uint y, uint z); __intrinsic_op($(kIROp_AllocateTorchTensor)) + [CudaHost] static TorchTensor alloc(uint x, uint y, uint z, uint w); __intrinsic_op($(kIROp_AllocateTorchTensor)) + [CudaHost] static TorchTensor alloc(uint i0, uint i1, uint i2, uint i3, uint i4); __intrinsic_op($(kIROp_AllocateTorchTensor)) + [CudaHost] static TorchTensor emptyLike(TorchTensor other); __target_intrinsic(cpp, "$0.zero_()") + [CudaHost] void fillZero(); __target_intrinsic(cpp, "$0.fill_($1)") + [CudaHost] void fillValue(T val); + [CudaHost] static TorchTensor zerosLike(TorchTensor other) { var result = emptyLike(other); @@ -854,8 +868,10 @@ T detach(T x); #define SLANG_SQR(x) ((x)*(x)) +#define SLANG_SIGN(x) select(((x)>T(0.0)), ReturnType(T(1.0)), select(((x)==T(0.0)), ReturnType(T(0.0)), ReturnType(T(-1.0)))) + // Absolute value -UNARY_DERIVATIVE_IMPL(abs, select(dpx.p > T(0.0), dpx.d, ReturnType.dmul(T(-1.0), dpx.d)), (ReturnType.dmul(__slang_noop_cast(sign(dpx.p)), dOut))) +UNARY_DERIVATIVE_IMPL(abs, ReturnType.dmul(SLANG_SIGN(dpx.p), dpx.d), ReturnType.dmul(SLANG_SIGN(dpx.p), dOut)) // Saturate UNARY_DERIVATIVE_IMPL(saturate, select(dpx.p < T(0.0) || dpx.p > T(1.0), ReturnType.dzero(), dpx.d), select(dpx.p < T(0.0) || dpx.p > T(1.0), ReturnType.dzero(), dOut)) // frac -- cgit v1.2.3