summaryrefslogtreecommitdiff
path: root/source/slang/diff.meta.slang
diff options
context:
space:
mode:
Diffstat (limited to 'source/slang/diff.meta.slang')
-rw-r--r--source/slang/diff.meta.slang9
1 files changed, 9 insertions, 0 deletions
diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang
index b39d91494..6042ff5cc 100644
--- a/source/slang/diff.meta.slang
+++ b/source/slang/diff.meta.slang
@@ -855,17 +855,20 @@ struct DiffTensorView
return diffPair(primal.load(x), reinterpret<T.Differential, T>(diff.load_forward<T>(x)));
}
+ [ForceInline]
__generic<let N : int>
DifferentialPair<T> __load_forward(vector<uint, N> x)
{
return diffPair(primal.load(x), reinterpret<T.Differential, T>(diff.load_forward<T, N>(x)));
}
+ [ForceInline]
void __load_backward(uint x, T.Differential dOut)
{
diff.load_backward<T>(x, reinterpret<T, T.Differential>(dOut));
}
+ [ForceInline]
__generic<let N : int>
void __load_backward(vector<uint, N> x, T.Differential dOut)
{
@@ -894,11 +897,13 @@ struct DiffTensorView
diff.store_forward<T, N>(x, reinterpret<T, T.Differential>(dpval.d));
}
+ [ForceInline]
void __store_backward(uint x, inout DifferentialPair<T> dpval)
{
dpval = diffPair(dpval.p, reinterpret<T.Differential, T>(diff.store_backward<T>(x)));
}
+ [ForceInline]
__generic<let N : int>
void __store_backward(vector<uint, N> x, inout DifferentialPair<T> dpval)
{
@@ -999,11 +1004,13 @@ struct DiffTensorView
return diffPair(primal.load(x), reinterpret<T.Differential, T>(diff.loadOnce_forward<T, N>(x)));
}
+ [ForceInline]
void __loadOnce_backward(uint x, T.Differential dOut)
{
diff.loadOnce_backward<T>(x, reinterpret<T, T.Differential>(dOut));
}
+ [ForceInline]
__generic<let N : int>
void __loadOnce_backward(vector<uint, N> x, T.Differential dOut)
{
@@ -1032,11 +1039,13 @@ struct DiffTensorView
diff.storeOnce_forward<T, N>(x, reinterpret<T, T.Differential>(dpval.d));
}
+ [ForceInline]
void __storeOnce_backward(uint x, inout DifferentialPair<T> dpval)
{
dpval = diffPair(dpval.p, reinterpret<T.Differential, T>(diff.storeOnce_backward<T>(x)));
}
+ [ForceInline]
__generic<let N : int>
void __storeOnce_backward(vector<uint, N> x, inout DifferentialPair<T> dpval)
{