summaryrefslogtreecommitdiffstats
path: root/source/slang/diff.meta.slang
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-04-27 18:32:20 -0700
committerGitHub <noreply@github.com>2023-04-27 18:32:20 -0700
commit53793612e3a2f1cadc4f7cbf703bcd94b7121414 (patch)
treeb995fb1e7b91817439f6f51f2489362b8b027a81 /source/slang/diff.meta.slang
parent60d829091cc97eef4fd36211afe8a83ad282c4de (diff)
Embed stdlib documentation to AST. (#2851)
* Embed stdlib documentation to AST. * Extract documentation for attributes. --------- Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/diff.meta.slang')
-rw-r--r--source/slang/diff.meta.slang6
1 files changed, 5 insertions, 1 deletions
diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang
index f8b36a3ac..f2f1d0cc3 100644
--- a/source/slang/diff.meta.slang
+++ b/source/slang/diff.meta.slang
@@ -25,6 +25,7 @@ attribute_syntax [DerivativeMember(memberName)] : DerivativeMemberAttribute;
__attributeTarget(FunctionDeclBase)
attribute_syntax [NoDiffThis] : NoDiffThisAttribute;
+/// Represents a GPU view of a tensor.
__generic<T>
__magic_type(TensorViewType)
__intrinsic_type($(kIROp_TensorViewType))
@@ -231,6 +232,7 @@ extension TensorView<float>
void InterlockedCompareExchange(vector<uint, N> index, float compare, float val);
}
+/// Represents the handle of a Torch tensor object.
__generic<T>
__intrinsic_type($(kIROp_TorchTensorType))
struct TorchTensor
@@ -294,10 +296,12 @@ struct TorchTensor
__target_intrinsic(cpp, "AT_CUDA_CHECK(cudaStreamSynchronize(at::cuda::getCurrentCUDAStream()))")
void syncTorchCudaStream();
+/// Constructs a `DifferentialPair` value from a primal value and a differential value.
__generic<T: IDifferentiable>
__intrinsic_op($(kIROp_MakeDifferentialPairUserCode))
DifferentialPair<T> diffPair(T primal, T.Differential diff);
+/// Constructs a `DifferentialPair` value from a primal value and a zero differential value.
__generic<T: IDifferentiable>
[__unsafeForceInlineEarly]
DifferentialPair<T> diffPair(T primal)
@@ -812,7 +816,7 @@ void __d_cross(inout DifferentialPair<vector<T, 3>> a, inout DifferentialPair<ve
}
#define SIMPLE_UNARY_DERIVATIVE_IMPL(NAME, DIFF_FUNC) UNARY_DERIVATIVE_IMPL(NAME, ReturnType.dmul(DIFF_FUNC, dpx.d), ReturnType.dmul(DIFF_FUNC, dOut))
-// Detach and set derivatives to zero
+/// Detach and set derivatives to zero.
__generic<T : IDifferentiable>
__intrinsic_op($(kIROp_DetachDerivative))
T detach(T x);