diff options
| author | Yong He <yonghe@outlook.com> | 2023-03-26 13:59:11 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-03-26 13:59:11 -0700 |
| commit | d64ee86a3130f8eeb75d09193c38c621d7565eba (patch) | |
| tree | fed25a0cc2a7372d26175774f5983bed693e6b64 /source/slang/diff.meta.slang | |
| parent | 666af0962b6ab41489a3a3287db83f77c2f6461a (diff) | |
Add PyTorch C++ binding generation. (#2734)
* Add PyTorch C++ binding generation.
* fix
---------
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/diff.meta.slang')
| -rw-r--r-- | source/slang/diff.meta.slang | 84 |
1 files changed, 83 insertions, 1 deletions
diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang index bbe94dbc2..d5b70bbb3 100644 --- a/source/slang/diff.meta.slang +++ b/source/slang/diff.meta.slang @@ -9,7 +9,6 @@ attribute_syntax [BackwardDerivative(function)] : BackwardDerivativeAttribute; __attributeTarget(FunctionDeclBase) attribute_syntax [PrimalSubstitute(function)] : PrimalSubstituteAttribute; - __attributeTarget(FunctionDeclBase) attribute_syntax [ForwardDerivativeOf(function)] : ForwardDerivativeOfAttribute; @@ -26,6 +25,89 @@ attribute_syntax [DerivativeMember(memberName)] : DerivativeMemberAttribute; __attributeTarget(FunctionDeclBase) attribute_syntax [NoDiffThis] : NoDiffThisAttribute; +__generic<T> +__magic_type(TensorViewType) +__intrinsic_type($(kIROp_TensorViewType)) +struct TensorView +{ + __target_intrinsic(cuda, "$0.data_ptr<$G0>()") + Ptr<T> data_ptr(); + + __implicit_conversion($(kConversionCost_ImplicitDereference)) + __intrinsic_op($(kIROp_TorchTensorGetView)) + __init(TorchTensor<T> t); + + __target_intrinsic(cuda, "$0.load<$G0>($1)") + T load(uint x); + __target_intrinsic(cuda, "$0.load<$G0>($1, $2)") + T load(uint x, uint y); + __target_intrinsic(cuda, "$0.load<$G0>($1, $2, $3)") + T load(uint x, uint y, uint z); + __target_intrinsic(cuda, "$0.load<$G0>($1, $2, $3, $4)") + T load(uint x, uint y, uint z, uint w); + __target_intrinsic(cuda, "$0.load<$G0>($1, $2, $3, $4, $5)") + T load(uint i0, uint i1, uint i2, uint i3, uint i4); + + __target_intrinsic(cuda, "$0.store<$G0>($1, $2)") + void store(uint x, T val); + __target_intrinsic(cuda, "$0.store<$G0>($1, $2, $3)") + void store(uint x, uint y, T val); + __target_intrinsic(cuda, "$0.store<$G0>($1, $2, $3, $4)") + void store(uint x, uint y, uint z, T val); + __target_intrinsic(cuda, "$0.store<$G0>($1, $2, $3, $4, $5)") + void store(uint x, uint y, uint z, uint w, T val); + __target_intrinsic(cuda, "$0.store<$G0>($1, $2, $3, $4, $5, $6)") + void store(uint i0, uint i1, uint i2, uint i3, uint i4, T val); + + __target_intrinsic(cuda, "$0.dimensionCount") + uint dims(); + + __target_intrinsic(cuda, "$0.sizes[$1]") + uint size(uint i); + + __target_intrinsic(cuda, "$0.strides[$1]") + uint stride(uint i); +} + +__generic<T> +__intrinsic_type($(kIROp_TorchTensorType)) +struct TorchTensor +{ + __intrinsic_op($(kIROp_TorchTensorGetView)) + TensorView<T> getView(); + + __target_intrinsic(cuda, "$0.dims()") + __target_intrinsic(cpp, "$0.dims()") + uint dims(); + + __target_intrinsic(cuda, "$0.size($1)") + __target_intrinsic(cpp, "$0.size($1)") + uint size(uint i); + + __target_intrinsic(cuda, "$0.stride($1)") + __target_intrinsic(cpp, "$0.stride($1)") + uint stride(uint i); + + __target_intrinsic(cuda, "$0.data_ptr<$G0>()") + __target_intrinsic(cpp, "$0.data_ptr<$G0>()") + Ptr<T> data_ptr(); + + __intrinsic_op($(kIROp_AllocateTorchTensor)) + static TorchTensor<T> alloc(uint x); + + __intrinsic_op($(kIROp_AllocateTorchTensor)) + static TorchTensor<T> alloc(uint x, uint y); + + __intrinsic_op($(kIROp_AllocateTorchTensor)) + static TorchTensor<T> alloc(uint x, uint y, uint z); + + __intrinsic_op($(kIROp_AllocateTorchTensor)) + static TorchTensor<T> alloc(uint x, uint y, uint z, uint w); + + __intrinsic_op($(kIROp_AllocateTorchTensor)) + static TorchTensor<T> alloc(uint i0, uint i1, uint i2, uint i3, uint i4); +} + __generic<T: IDifferentiable> __intrinsic_op($(kIROp_MakeDifferentialPairUserCode)) DifferentialPair<T> diffPair(T primal, T.Differential diff); |
