From 739c3a7b53dc6489065fcd5e9f0a04370c5f9c8f Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Tue, 19 Sep 2023 18:51:24 -0400 Subject: Added `[AutoPyBindCUDA]` for automatic kernel binding + `[PyExport]` for exporting type information (#3209) * Initial: add a DiffTensor impl * Auto-binding and diff tensor implementations now work * Refactored diff-tensor implementation + added py-export for struct types * Cleanup * Update slang-ir-pytorch-cpp-binding.cpp * Updated test names * Update autodiff-data-flow.slang.expected * Add more versions of load/store & default generic args for DiffTensorView. * Add diagnostic for default generic arg and more tests * Add more `[AutoPyBind]` tests --- source/slang/diff.meta.slang | 292 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 292 insertions(+) (limited to 'source/slang/diff.meta.slang') diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang index 423b6bfd0..495b6b989 100644 --- a/source/slang/diff.meta.slang +++ b/source/slang/diff.meta.slang @@ -264,6 +264,298 @@ extension TensorView void InterlockedCompareExchange(vector index, float compare, float val); } +interface IDiffTensorWrapper +{ + __generic + T load_forward(uint offset); + + __generic + T load_forward_2(uint2 offset); + + __generic + T load_forward_3(uint3 offset); + + __generic + T load_forward_4(uint4 offset); + + __generic + void load_backward(uint offset, T dOut); + + __generic + void load_backward_2(uint2 offset, T dOut); + + __generic + void load_backward_3(uint3 offset, T dOut); + + __generic + void load_backward_4(uint4 offset, T dOut); + + __generic + void store_forward(uint offset, T dx); + + __generic + void store_forward_2(uint2 offset, T dx); + + __generic + void store_forward_3(uint3 offset, T dx); + + __generic + void store_forward_4(uint4 offset, T dx); + + __generic + T store_backward(uint offset); + + __generic + T store_backward_2(uint2 offset); + + __generic + T store_backward_3(uint3 offset); + + __generic + T store_backward_4(uint4 offset); +}; + +struct AtomicAdd : IDiffTensorWrapper +{ + TensorView diff; + + __generic + T load_forward(uint i) + { + return __realCast(diff.load(i)); + } + + __generic + T load_forward_2(uint2 i) + { + return __realCast(diff.load(i.x, i.y)); + } + + __generic + T load_forward_3(uint3 i) + { + return __realCast(diff.load(i.x, i.y, i.z)); + } + + __generic + T load_forward_4(uint4 i) + { + return __realCast(diff.load(i.x, i.y, i.z, i.w)); + } + + __generic + void load_backward(uint i, T dOut) + { + float oldVal; + diff.InterlockedAdd(i, __realCast(dOut), oldVal); + } + + __generic + void load_backward_2(uint2 i, T dOut) + { + float oldVal; + diff.InterlockedAdd(i, __realCast(dOut), oldVal); + } + + __generic + void load_backward_3(uint3 i, T dOut) + { + float oldVal; + diff.InterlockedAdd(i, __realCast(dOut), oldVal); + } + + __generic + void load_backward_4(uint4 i, T dOut) + { + float oldVal; + diff.InterlockedAdd(i, __realCast(dOut), oldVal); + } + + __generic + void store_forward(uint i, T dx) + { + diff.store(i, __realCast(dx)); + } + + __generic + void store_forward_2(uint2 i, T dx) + { + diff.store(i.x, i.y, __realCast(dx)); + } + + __generic + void store_forward_3(uint3 i, T dx) + { + diff.store(i.x, i.y, i.z, __realCast(dx)); + } + + __generic + void store_forward_4(uint4 i, T dx) + { + diff.store(i.x, i.y, i.z, i.w, __realCast(dx)); + } + + __generic + T store_backward(uint i) + { + float oldVal; + diff.InterlockedExchange(i, (float)0, oldVal); + return __realCast(oldVal); + } + + __generic + T store_backward_2(uint2 i) + { + float oldVal; + diff.InterlockedExchange(i, (float)0, oldVal); + return __realCast(oldVal); + } + + __generic + T store_backward_3(uint3 i) + { + float oldVal; + diff.InterlockedExchange(i, (float)0, oldVal); + return __realCast(oldVal); + } + + __generic + T store_backward_4(uint4 i) + { + float oldVal; + diff.InterlockedExchange(i, (float)0, oldVal); + return __realCast(oldVal); + } +}; + +__generic +struct DiffTensorView +{ + TensorView primal; + A diff; + + uint size(uint i) + { + return primal.size(i); + } + + [BackwardDerivative(load_backward)] + [ForwardDerivative(load_forward)] + T load(uint i) { return primal.load(i); } + + [BackwardDerivative(load_backward)] + [ForwardDerivative(load_forward)] + T load(uint2 i) { return primal.load(i.x, i.y); } + + [BackwardDerivative(load_backward)] + [ForwardDerivative(load_forward)] + T load(uint3 i) { return primal.load(i.x, i.y, i.z); } + + [BackwardDerivative(load_backward)] + [ForwardDerivative(load_forward)] + T load(uint4 i) { return primal.load(i.x, i.y, i.z, i.w); } + + DifferentialPair load_forward(uint x) + { + return diffPair(primal.load(x), reinterpret(diff.load_forward(x))); + } + + DifferentialPair load_forward(uint2 x) + { + return diffPair(primal.load(x.x, x.y), reinterpret(diff.load_forward_2(x))); + } + + DifferentialPair load_forward(uint3 x) + { + return diffPair(primal.load(x.x, x.y, x.z), reinterpret(diff.load_forward_3(x))); + } + + DifferentialPair load_forward(uint4 x) + { + return diffPair(primal.load(x.x, x.y, x.z, x.w), reinterpret(diff.load_forward_4(x))); + } + + void load_backward(uint x, T.Differential dOut) + { + diff.load_backward(x, reinterpret(dOut)); + } + + void load_backward(uint2 x, T.Differential dOut) + { + diff.load_backward_2(x, reinterpret(dOut)); + } + + void load_backward(uint3 x, T.Differential dOut) + { + diff.load_backward_3(x, reinterpret(dOut)); + } + + void load_backward(uint4 x, T.Differential dOut) + { + diff.load_backward_4(x, reinterpret(dOut)); + } + + [BackwardDerivative(store_backward)] + [ForwardDerivative(store_forward)] + void store(uint x, T val) { primal.store(x, val); } + + [BackwardDerivative(store_backward)] + [ForwardDerivative(store_forward)] + void store(uint2 x, T val) { primal.store(x.x, x.y, val); } + + [BackwardDerivative(store_backward)] + [ForwardDerivative(store_forward)] + void store(uint3 x, T val) { primal.store(x.x, x.y, x.z, val); } + + [BackwardDerivative(store_backward)] + [ForwardDerivative(store_forward)] + void store(uint4 x, T val) { primal.store(x.x, x.y, x.z, x.w, val); } + + void store_forward(uint x, DifferentialPair dpval) + { + primal.store(x, dpval.p); + diff.store_forward(x, reinterpret(dpval.d)); + } + + void store_forward(uint2 x, DifferentialPair dpval) + { + primal.store(x.x, x.y, dpval.p); + diff.store_forward_2(x, reinterpret(dpval.d)); + } + + void store_forward(uint3 x, DifferentialPair dpval) + { + primal.store(x.x, x.y, x.z, dpval.p); + diff.store_forward_3(x, reinterpret(dpval.d)); + } + + void store_forward(uint4 x, DifferentialPair dpval) + { + primal.store(x.x, x.y, x.z, x.w, dpval.p); + diff.store_forward_4(x, reinterpret(dpval.d)); + } + + void store_backward(uint x, inout DifferentialPair dpval) + { + dpval = diffPair(dpval.p, reinterpret(diff.store_backward(x))); + } + + void store_backward(uint2 x, inout DifferentialPair dpval) + { + dpval = diffPair(dpval.p, reinterpret(diff.store_backward_2(x))); + } + + void store_backward(uint3 x, inout DifferentialPair dpval) + { + dpval = diffPair(dpval.p, reinterpret(diff.store_backward_3(x))); + } + + void store_backward(uint4 x, inout DifferentialPair dpval) + { + dpval = diffPair(dpval.p, reinterpret(diff.store_backward_4(x))); + } +}; + /// Represents the handle of a Torch tensor object. __generic __intrinsic_type($(kIROp_TorchTensorType)) -- cgit v1.2.3