summaryrefslogtreecommitdiffstats
path: root/source/slang/diff.meta.slang
diff options
context:
space:
mode:
Diffstat (limited to 'source/slang/diff.meta.slang')
-rw-r--r--source/slang/diff.meta.slang84
1 files changed, 83 insertions, 1 deletions
diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang
index bbe94dbc2..d5b70bbb3 100644
--- a/source/slang/diff.meta.slang
+++ b/source/slang/diff.meta.slang
@@ -9,7 +9,6 @@ attribute_syntax [BackwardDerivative(function)] : BackwardDerivativeAttribute;
__attributeTarget(FunctionDeclBase)
attribute_syntax [PrimalSubstitute(function)] : PrimalSubstituteAttribute;
-
__attributeTarget(FunctionDeclBase)
attribute_syntax [ForwardDerivativeOf(function)] : ForwardDerivativeOfAttribute;
@@ -26,6 +25,89 @@ attribute_syntax [DerivativeMember(memberName)] : DerivativeMemberAttribute;
__attributeTarget(FunctionDeclBase)
attribute_syntax [NoDiffThis] : NoDiffThisAttribute;
+__generic<T>
+__magic_type(TensorViewType)
+__intrinsic_type($(kIROp_TensorViewType))
+struct TensorView
+{
+ __target_intrinsic(cuda, "$0.data_ptr<$G0>()")
+ Ptr<T> data_ptr();
+
+ __implicit_conversion($(kConversionCost_ImplicitDereference))
+ __intrinsic_op($(kIROp_TorchTensorGetView))
+ __init(TorchTensor<T> t);
+
+ __target_intrinsic(cuda, "$0.load<$G0>($1)")
+ T load(uint x);
+ __target_intrinsic(cuda, "$0.load<$G0>($1, $2)")
+ T load(uint x, uint y);
+ __target_intrinsic(cuda, "$0.load<$G0>($1, $2, $3)")
+ T load(uint x, uint y, uint z);
+ __target_intrinsic(cuda, "$0.load<$G0>($1, $2, $3, $4)")
+ T load(uint x, uint y, uint z, uint w);
+ __target_intrinsic(cuda, "$0.load<$G0>($1, $2, $3, $4, $5)")
+ T load(uint i0, uint i1, uint i2, uint i3, uint i4);
+
+ __target_intrinsic(cuda, "$0.store<$G0>($1, $2)")
+ void store(uint x, T val);
+ __target_intrinsic(cuda, "$0.store<$G0>($1, $2, $3)")
+ void store(uint x, uint y, T val);
+ __target_intrinsic(cuda, "$0.store<$G0>($1, $2, $3, $4)")
+ void store(uint x, uint y, uint z, T val);
+ __target_intrinsic(cuda, "$0.store<$G0>($1, $2, $3, $4, $5)")
+ void store(uint x, uint y, uint z, uint w, T val);
+ __target_intrinsic(cuda, "$0.store<$G0>($1, $2, $3, $4, $5, $6)")
+ void store(uint i0, uint i1, uint i2, uint i3, uint i4, T val);
+
+ __target_intrinsic(cuda, "$0.dimensionCount")
+ uint dims();
+
+ __target_intrinsic(cuda, "$0.sizes[$1]")
+ uint size(uint i);
+
+ __target_intrinsic(cuda, "$0.strides[$1]")
+ uint stride(uint i);
+}
+
+__generic<T>
+__intrinsic_type($(kIROp_TorchTensorType))
+struct TorchTensor
+{
+ __intrinsic_op($(kIROp_TorchTensorGetView))
+ TensorView<T> getView();
+
+ __target_intrinsic(cuda, "$0.dims()")
+ __target_intrinsic(cpp, "$0.dims()")
+ uint dims();
+
+ __target_intrinsic(cuda, "$0.size($1)")
+ __target_intrinsic(cpp, "$0.size($1)")
+ uint size(uint i);
+
+ __target_intrinsic(cuda, "$0.stride($1)")
+ __target_intrinsic(cpp, "$0.stride($1)")
+ uint stride(uint i);
+
+ __target_intrinsic(cuda, "$0.data_ptr<$G0>()")
+ __target_intrinsic(cpp, "$0.data_ptr<$G0>()")
+ Ptr<T> data_ptr();
+
+ __intrinsic_op($(kIROp_AllocateTorchTensor))
+ static TorchTensor<T> alloc(uint x);
+
+ __intrinsic_op($(kIROp_AllocateTorchTensor))
+ static TorchTensor<T> alloc(uint x, uint y);
+
+ __intrinsic_op($(kIROp_AllocateTorchTensor))
+ static TorchTensor<T> alloc(uint x, uint y, uint z);
+
+ __intrinsic_op($(kIROp_AllocateTorchTensor))
+ static TorchTensor<T> alloc(uint x, uint y, uint z, uint w);
+
+ __intrinsic_op($(kIROp_AllocateTorchTensor))
+ static TorchTensor<T> alloc(uint i0, uint i1, uint i2, uint i3, uint i4);
+}
+
__generic<T: IDifferentiable>
__intrinsic_op($(kIROp_MakeDifferentialPairUserCode))
DifferentialPair<T> diffPair(T primal, T.Differential diff);