diff options
Diffstat (limited to 'source/slang/diff.meta.slang')
| -rw-r--r-- | source/slang/diff.meta.slang | 84 |
1 files changed, 83 insertions, 1 deletions
diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang index bbe94dbc2..d5b70bbb3 100644 --- a/source/slang/diff.meta.slang +++ b/source/slang/diff.meta.slang @@ -9,7 +9,6 @@ attribute_syntax [BackwardDerivative(function)] : BackwardDerivativeAttribute; __attributeTarget(FunctionDeclBase) attribute_syntax [PrimalSubstitute(function)] : PrimalSubstituteAttribute; - __attributeTarget(FunctionDeclBase) attribute_syntax [ForwardDerivativeOf(function)] : ForwardDerivativeOfAttribute; @@ -26,6 +25,89 @@ attribute_syntax [DerivativeMember(memberName)] : DerivativeMemberAttribute; __attributeTarget(FunctionDeclBase) attribute_syntax [NoDiffThis] : NoDiffThisAttribute; +__generic<T> +__magic_type(TensorViewType) +__intrinsic_type($(kIROp_TensorViewType)) +struct TensorView +{ + __target_intrinsic(cuda, "$0.data_ptr<$G0>()") + Ptr<T> data_ptr(); + + __implicit_conversion($(kConversionCost_ImplicitDereference)) + __intrinsic_op($(kIROp_TorchTensorGetView)) + __init(TorchTensor<T> t); + + __target_intrinsic(cuda, "$0.load<$G0>($1)") + T load(uint x); + __target_intrinsic(cuda, "$0.load<$G0>($1, $2)") + T load(uint x, uint y); + __target_intrinsic(cuda, "$0.load<$G0>($1, $2, $3)") + T load(uint x, uint y, uint z); + __target_intrinsic(cuda, "$0.load<$G0>($1, $2, $3, $4)") + T load(uint x, uint y, uint z, uint w); + __target_intrinsic(cuda, "$0.load<$G0>($1, $2, $3, $4, $5)") + T load(uint i0, uint i1, uint i2, uint i3, uint i4); + + __target_intrinsic(cuda, "$0.store<$G0>($1, $2)") + void store(uint x, T val); + __target_intrinsic(cuda, "$0.store<$G0>($1, $2, $3)") + void store(uint x, uint y, T val); + __target_intrinsic(cuda, "$0.store<$G0>($1, $2, $3, $4)") + void store(uint x, uint y, uint z, T val); + __target_intrinsic(cuda, "$0.store<$G0>($1, $2, $3, $4, $5)") + void store(uint x, uint y, uint z, uint w, T val); + __target_intrinsic(cuda, "$0.store<$G0>($1, $2, $3, $4, $5, $6)") + void store(uint i0, uint i1, uint i2, uint i3, uint i4, T val); + + __target_intrinsic(cuda, "$0.dimensionCount") + uint dims(); + + __target_intrinsic(cuda, "$0.sizes[$1]") + uint size(uint i); + + __target_intrinsic(cuda, "$0.strides[$1]") + uint stride(uint i); +} + +__generic<T> +__intrinsic_type($(kIROp_TorchTensorType)) +struct TorchTensor +{ + __intrinsic_op($(kIROp_TorchTensorGetView)) + TensorView<T> getView(); + + __target_intrinsic(cuda, "$0.dims()") + __target_intrinsic(cpp, "$0.dims()") + uint dims(); + + __target_intrinsic(cuda, "$0.size($1)") + __target_intrinsic(cpp, "$0.size($1)") + uint size(uint i); + + __target_intrinsic(cuda, "$0.stride($1)") + __target_intrinsic(cpp, "$0.stride($1)") + uint stride(uint i); + + __target_intrinsic(cuda, "$0.data_ptr<$G0>()") + __target_intrinsic(cpp, "$0.data_ptr<$G0>()") + Ptr<T> data_ptr(); + + __intrinsic_op($(kIROp_AllocateTorchTensor)) + static TorchTensor<T> alloc(uint x); + + __intrinsic_op($(kIROp_AllocateTorchTensor)) + static TorchTensor<T> alloc(uint x, uint y); + + __intrinsic_op($(kIROp_AllocateTorchTensor)) + static TorchTensor<T> alloc(uint x, uint y, uint z); + + __intrinsic_op($(kIROp_AllocateTorchTensor)) + static TorchTensor<T> alloc(uint x, uint y, uint z, uint w); + + __intrinsic_op($(kIROp_AllocateTorchTensor)) + static TorchTensor<T> alloc(uint i0, uint i1, uint i2, uint i3, uint i4); +} + __generic<T: IDifferentiable> __intrinsic_op($(kIROp_MakeDifferentialPairUserCode)) DifferentialPair<T> diffPair(T primal, T.Differential diff); |
