summary refs log tree commit diff stats
path: root/source/slang/diff.meta.slang
diff options
context:
space:
mode:
author Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> 2023-09-19 18:51:24 -0400
committer GitHub <noreply@github.com> 2023-09-19 18:51:24 -0400
commit 739c3a7b53dc6489065fcd5e9f0a04370c5f9c8f (patch)
tree 593c86cbc184476479c66554cc6784b454bdec66 /source/slang/diff.meta.slang
parent 359fdc9d556b4c493c588c5b8f93df85933634f8 (diff)
Added `[AutoPyBindCUDA]` for automatic kernel binding + `[PyExport]` for exporting type information (#3209)
* Initial: add a DiffTensor impl
* Auto-binding and diff tensor implementations now work
* Refactored diff-tensor implementation + added py-export for struct types
* Cleanup
* Update slang-ir-pytorch-cpp-binding.cpp
* Updated test names
* Update autodiff-data-flow.slang.expected
* Add more versions of load/store & default generic args for DiffTensorView.
* Add diagnostic for default generic arg and more tests
* Add more `[AutoPyBind]` tests
Diffstat (limited to 'source/slang/diff.meta.slang')
-rw-r--r-- source/slang/diff.meta.slang 292
1 files changed, 292 insertions, 0 deletions
diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang
index 423b6bfd0..495b6b989 100644
--- a/source/slang/diff.meta.slang
+++ b/source/slang/diff.meta.slang
@@ -264,6 +264,298 @@ extension TensorView<float>
void InterlockedCompareExchange(vector<uint, N> index, float compare, float val);
}
+interface IDiffTensorWrapper // Contract for the derivative accessor stored in DiffTensorView's `diff` member (generic parameter A).
+{ // Every method is generic over a floating-point T; the _2/_3/_4 variants take a uint2/uint3/uint4 offset instead of a uint.
+ __generic<T : __BuiltinFloatingPointType>
+ T load_forward(uint offset); // Called by DiffTensorView.load_forward; the result becomes the differential part of the returned pair.
+ // load_forward for a uint2 offset.
+ __generic<T : __BuiltinFloatingPointType>
+ T load_forward_2(uint2 offset);
+ // load_forward for a uint3 offset.
+ __generic<T : __BuiltinFloatingPointType>
+ T load_forward_3(uint3 offset);
+ // load_forward for a uint4 offset.
+ __generic<T : __BuiltinFloatingPointType>
+ T load_forward_4(uint4 offset);
+ // Called by DiffTensorView.load_backward, which passes the `dOut` differential it received, reinterpreted as T.
+ __generic<T : __BuiltinFloatingPointType>
+ void load_backward(uint offset, T dOut);
+ // load_backward for a uint2 offset.
+ __generic<T : __BuiltinFloatingPointType>
+ void load_backward_2(uint2 offset, T dOut);
+ // load_backward for a uint3 offset.
+ __generic<T : __BuiltinFloatingPointType>
+ void load_backward_3(uint3 offset, T dOut);
+ // load_backward for a uint4 offset.
+ __generic<T : __BuiltinFloatingPointType>
+ void load_backward_4(uint4 offset, T dOut);
+ // Called by DiffTensorView.store_forward with the differential part of the stored pair, reinterpreted as T.
+ __generic<T : __BuiltinFloatingPointType>
+ void store_forward(uint offset, T dx);
+ // store_forward for a uint2 offset.
+ __generic<T : __BuiltinFloatingPointType>
+ void store_forward_2(uint2 offset, T dx);
+ // store_forward for a uint3 offset.
+ __generic<T : __BuiltinFloatingPointType>
+ void store_forward_3(uint3 offset, T dx);
+ // store_forward for a uint4 offset.
+ __generic<T : __BuiltinFloatingPointType>
+ void store_forward_4(uint4 offset, T dx);
+ // Called by DiffTensorView.store_backward; the result becomes the differential part of the updated `dpval`.
+ __generic<T : __BuiltinFloatingPointType>
+ T store_backward(uint offset);
+ // store_backward for a uint2 offset.
+ __generic<T : __BuiltinFloatingPointType>
+ T store_backward_2(uint2 offset);
+ // store_backward for a uint3 offset.
+ __generic<T : __BuiltinFloatingPointType>
+ T store_backward_3(uint3 offset);
+ // store_backward for a uint4 offset.
+ __generic<T : __BuiltinFloatingPointType>
+ T store_backward_4(uint4 offset);
+};
+
+/// IDiffTensorWrapper implementation that keeps derivative values in a
+/// TensorView<float>. Every access converts between the caller's floating-point
+/// type T and float with __realCast. It is the default for DiffTensorView's
+/// generic parameter A.
+struct AtomicAdd : IDiffTensorWrapper
+{
+    /// Tensor that stores the derivative values (always float).
+    TensorView<float> diff;
+
+    /// Reads the value stored at `i` and returns it converted to T.
+    __generic<T : __BuiltinFloatingPointType>
+    T load_forward(uint i)
+    {
+        float stored = diff.load(i);
+        return __realCast<T, float>(stored);
+    }
+
+    /// Two-component variant of load_forward; reads element (i.x, i.y).
+    __generic<T : __BuiltinFloatingPointType>
+    T load_forward_2(uint2 i)
+    {
+        float stored = diff.load(i.x, i.y);
+        return __realCast<T, float>(stored);
+    }
+
+    /// Three-component variant of load_forward; reads element (i.x, i.y, i.z).
+    __generic<T : __BuiltinFloatingPointType>
+    T load_forward_3(uint3 i)
+    {
+        float stored = diff.load(i.x, i.y, i.z);
+        return __realCast<T, float>(stored);
+    }
+
+    /// Four-component variant of load_forward; reads element (i.x, i.y, i.z, i.w).
+    __generic<T : __BuiltinFloatingPointType>
+    T load_forward_4(uint4 i)
+    {
+        float stored = diff.load(i.x, i.y, i.z, i.w);
+        return __realCast<T, float>(stored);
+    }
+
+    /// Adds `dOut` (converted to float) to the value at `i` through the
+    /// tensor's InterlockedAdd.
+    __generic<T : __BuiltinFloatingPointType>
+    void load_backward(uint i, T dOut)
+    {
+        // InterlockedAdd reports the prior value through its last argument; it is not used here.
+        float previous;
+        float contribution = __realCast<float, T>(dOut);
+        diff.InterlockedAdd(i, contribution, previous);
+    }
+
+    /// load_backward for a uint2 index.
+    __generic<T : __BuiltinFloatingPointType>
+    void load_backward_2(uint2 i, T dOut)
+    {
+        float previous;
+        float contribution = __realCast<float, T>(dOut);
+        diff.InterlockedAdd(i, contribution, previous);
+    }
+
+    /// load_backward for a uint3 index.
+    __generic<T : __BuiltinFloatingPointType>
+    void load_backward_3(uint3 i, T dOut)
+    {
+        float previous;
+        float contribution = __realCast<float, T>(dOut);
+        diff.InterlockedAdd(i, contribution, previous);
+    }
+
+    /// load_backward for a uint4 index.
+    __generic<T : __BuiltinFloatingPointType>
+    void load_backward_4(uint4 i, T dOut)
+    {
+        float previous;
+        float contribution = __realCast<float, T>(dOut);
+        diff.InterlockedAdd(i, contribution, previous);
+    }
+
+    /// Writes `dx` (converted to float) to the element at `i`.
+    __generic<T : __BuiltinFloatingPointType>
+    void store_forward(uint i, T dx)
+    {
+        float converted = __realCast<float, T>(dx);
+        diff.store(i, converted);
+    }
+
+    /// store_forward for a uint2 index; writes element (i.x, i.y).
+    __generic<T : __BuiltinFloatingPointType>
+    void store_forward_2(uint2 i, T dx)
+    {
+        float converted = __realCast<float, T>(dx);
+        diff.store(i.x, i.y, converted);
+    }
+
+    /// store_forward for a uint3 index; writes element (i.x, i.y, i.z).
+    __generic<T : __BuiltinFloatingPointType>
+    void store_forward_3(uint3 i, T dx)
+    {
+        float converted = __realCast<float, T>(dx);
+        diff.store(i.x, i.y, i.z, converted);
+    }
+
+    /// store_forward for a uint4 index; writes element (i.x, i.y, i.z, i.w).
+    __generic<T : __BuiltinFloatingPointType>
+    void store_forward_4(uint4 i, T dx)
+    {
+        float converted = __realCast<float, T>(dx);
+        diff.store(i.x, i.y, i.z, i.w, converted);
+    }
+
+    /// Swaps 0 into the element at `i` through the tensor's InterlockedExchange
+    /// and returns the value that was there, converted to T.
+    __generic<T : __BuiltinFloatingPointType>
+    T store_backward(uint i)
+    {
+        float previous;
+        diff.InterlockedExchange(i, (float)0, previous);
+        T result = __realCast<T, float>(previous);
+        return result;
+    }
+
+    /// store_backward for a uint2 index.
+    __generic<T : __BuiltinFloatingPointType>
+    T store_backward_2(uint2 i)
+    {
+        float previous;
+        diff.InterlockedExchange(i, (float)0, previous);
+        T result = __realCast<T, float>(previous);
+        return result;
+    }
+
+    /// store_backward for a uint3 index.
+    __generic<T : __BuiltinFloatingPointType>
+    T store_backward_3(uint3 i)
+    {
+        float previous;
+        diff.InterlockedExchange(i, (float)0, previous);
+        T result = __realCast<T, float>(previous);
+        return result;
+    }
+
+    /// store_backward for a uint4 index.
+    __generic<T : __BuiltinFloatingPointType>
+    T store_backward_4(uint4 i)
+    {
+        float previous;
+        diff.InterlockedExchange(i, (float)0, previous);
+        T result = __realCast<T, float>(previous);
+        return result;
+    }
+};
+
+/// Tensor view that couples a primal TensorView<T> with a derivative accessor `A`.
+/// `load` and `store` operate on the primal tensor only; their forward and backward
+/// derivatives are supplied by the load_forward/load_backward and
+/// store_forward/store_backward overloads named in the attributes, which route the
+/// derivative values through `diff`.
+/// T defaults to float and A defaults to AtomicAdd.
+__generic<T: __BuiltinFloatingPointType = float, A : IDiffTensorWrapper = AtomicAdd>
+struct DiffTensorView
+{
+    /// Tensor holding the primal values.
+    TensorView<T> primal;
+    /// Accessor used for every derivative read and write.
+    A diff;
+
+    /// Returns primal.size(i).
+    uint size(uint i)
+    {
+        return primal.size(i);
+    }
+
+    // ---- Primal loads: one overload per index width (uint .. uint4). ----
+
+    [BackwardDerivative(load_backward)]
+    [ForwardDerivative(load_forward)]
+    T load(uint i)
+    {
+        return primal.load(i);
+    }
+
+    [BackwardDerivative(load_backward)]
+    [ForwardDerivative(load_forward)]
+    T load(uint2 i)
+    {
+        return primal.load(i.x, i.y);
+    }
+
+    [BackwardDerivative(load_backward)]
+    [ForwardDerivative(load_forward)]
+    T load(uint3 i)
+    {
+        return primal.load(i.x, i.y, i.z);
+    }
+
+    [BackwardDerivative(load_backward)]
+    [ForwardDerivative(load_forward)]
+    T load(uint4 i)
+    {
+        return primal.load(i.x, i.y, i.z, i.w);
+    }
+
+    // ---- Forward derivatives of load: the primal value paired with the value ----
+    // ---- read through diff.load_forward*, reinterpreted as T.Differential.   ----
+
+    DifferentialPair<T> load_forward(uint x)
+    {
+        T primalValue = primal.load(x);
+        T.Differential tangent = reinterpret<T.Differential, T>(diff.load_forward<T>(x));
+        return diffPair(primalValue, tangent);
+    }
+
+    DifferentialPair<T> load_forward(uint2 x)
+    {
+        T primalValue = primal.load(x.x, x.y);
+        T.Differential tangent = reinterpret<T.Differential, T>(diff.load_forward_2<T>(x));
+        return diffPair(primalValue, tangent);
+    }
+
+    DifferentialPair<T> load_forward(uint3 x)
+    {
+        T primalValue = primal.load(x.x, x.y, x.z);
+        T.Differential tangent = reinterpret<T.Differential, T>(diff.load_forward_3<T>(x));
+        return diffPair(primalValue, tangent);
+    }
+
+    DifferentialPair<T> load_forward(uint4 x)
+    {
+        T primalValue = primal.load(x.x, x.y, x.z, x.w);
+        T.Differential tangent = reinterpret<T.Differential, T>(diff.load_forward_4<T>(x));
+        return diffPair(primalValue, tangent);
+    }
+
+    // ---- Backward derivatives of load: hand dOut (reinterpreted as T) to ----
+    // ---- diff.load_backward* for the same index.                        ----
+
+    void load_backward(uint x, T.Differential dOut)
+    {
+        T gradient = reinterpret<T, T.Differential>(dOut);
+        diff.load_backward<T>(x, gradient);
+    }
+
+    void load_backward(uint2 x, T.Differential dOut)
+    {
+        T gradient = reinterpret<T, T.Differential>(dOut);
+        diff.load_backward_2<T>(x, gradient);
+    }
+
+    void load_backward(uint3 x, T.Differential dOut)
+    {
+        T gradient = reinterpret<T, T.Differential>(dOut);
+        diff.load_backward_3<T>(x, gradient);
+    }
+
+    void load_backward(uint4 x, T.Differential dOut)
+    {
+        T gradient = reinterpret<T, T.Differential>(dOut);
+        diff.load_backward_4<T>(x, gradient);
+    }
+
+    // ---- Primal stores: one overload per index width (uint .. uint4). ----
+
+    [BackwardDerivative(store_backward)]
+    [ForwardDerivative(store_forward)]
+    void store(uint x, T val)
+    {
+        primal.store(x, val);
+    }
+
+    [BackwardDerivative(store_backward)]
+    [ForwardDerivative(store_forward)]
+    void store(uint2 x, T val)
+    {
+        primal.store(x.x, x.y, val);
+    }
+
+    [BackwardDerivative(store_backward)]
+    [ForwardDerivative(store_forward)]
+    void store(uint3 x, T val)
+    {
+        primal.store(x.x, x.y, x.z, val);
+    }
+
+    [BackwardDerivative(store_backward)]
+    [ForwardDerivative(store_forward)]
+    void store(uint4 x, T val)
+    {
+        primal.store(x.x, x.y, x.z, x.w, val);
+    }
+
+    // ---- Forward derivatives of store: write dpval.p to the primal tensor, ----
+    // ---- then pass dpval.d (reinterpreted as T) to diff.store_forward*.    ----
+
+    void store_forward(uint x, DifferentialPair<T> dpval)
+    {
+        primal.store(x, dpval.p);
+        T tangent = reinterpret<T, T.Differential>(dpval.d);
+        diff.store_forward<T>(x, tangent);
+    }
+
+    void store_forward(uint2 x, DifferentialPair<T> dpval)
+    {
+        primal.store(x.x, x.y, dpval.p);
+        T tangent = reinterpret<T, T.Differential>(dpval.d);
+        diff.store_forward_2<T>(x, tangent);
+    }
+
+    void store_forward(uint3 x, DifferentialPair<T> dpval)
+    {
+        primal.store(x.x, x.y, x.z, dpval.p);
+        T tangent = reinterpret<T, T.Differential>(dpval.d);
+        diff.store_forward_3<T>(x, tangent);
+    }
+
+    void store_forward(uint4 x, DifferentialPair<T> dpval)
+    {
+        primal.store(x.x, x.y, x.z, x.w, dpval.p);
+        T tangent = reinterpret<T, T.Differential>(dpval.d);
+        diff.store_forward_4<T>(x, tangent);
+    }
+
+    // ---- Backward derivatives of store: keep dpval.p and replace the         ----
+    // ---- differential with the value returned by diff.store_backward*.       ----
+
+    void store_backward(uint x, inout DifferentialPair<T> dpval)
+    {
+        T primalPart = dpval.p;
+        T.Differential gradient = reinterpret<T.Differential, T>(diff.store_backward<T>(x));
+        dpval = diffPair(primalPart, gradient);
+    }
+
+    void store_backward(uint2 x, inout DifferentialPair<T> dpval)
+    {
+        T primalPart = dpval.p;
+        T.Differential gradient = reinterpret<T.Differential, T>(diff.store_backward_2<T>(x));
+        dpval = diffPair(primalPart, gradient);
+    }
+
+    void store_backward(uint3 x, inout DifferentialPair<T> dpval)
+    {
+        T primalPart = dpval.p;
+        T.Differential gradient = reinterpret<T.Differential, T>(diff.store_backward_3<T>(x));
+        dpval = diffPair(primalPart, gradient);
+    }
+
+    void store_backward(uint4 x, inout DifferentialPair<T> dpval)
+    {
+        T primalPart = dpval.p;
+        T.Differential gradient = reinterpret<T.Differential, T>(diff.store_backward_4<T>(x));
+        dpval = diffPair(primalPart, gradient);
+    }
+};
+
/// Represents the handle of a Torch tensor object.
__generic<T>
__intrinsic_type($(kIROp_TorchTensorType))