summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
authorkaizhangNV <149626564+kaizhangNV@users.noreply.github.com>2025-02-05 12:12:49 -0600
committerGitHub <noreply@github.com>2025-02-05 10:12:49 -0800
commit4b350645042b8e8fbdad19784ee745d11c7bc616 (patch)
treeacc7cb6b3cc82d71acf66ab2d4d052598aa72535 /source
parentd3e5f39cafbf1d65cf93cdd42c20c472c68197a2 (diff)
Fix autodiff issue for vector<T, N> (#6275)
* Fix autodiff issue for vector<T, N> Close #6154 We didn't implement correctly for vector<T, N> regarding the differentiablity. As we check differentiable before specialization, however according to the definition of vector, it has to be specialized to IFloat to know it's conformed to IDifferential type. Therefore for parameter type vector<T, N> will become no_diff. Therefore, we change the implementation a to make it explicit conform to IDifferential type. * fix typo
Diffstat (limited to 'source')
-rw-r--r--source/slang/core.meta.slang26
1 files changed, 20 insertions, 6 deletions
diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang
index 36e9d6885..19421735c 100644
--- a/source/slang/core.meta.slang
+++ b/source/slang/core.meta.slang
@@ -1988,8 +1988,18 @@ extension vector<T,N> : IFloat
[OverloadRank(-1)]
[__unsafeForceInlineEarly] __init(float v) { this = vector<T,N>(T(v)); }
- // IDifferentiable
+}
+
+__intrinsic_op($(kIROp_Add))
+T __internal_add<T>(T a, T b);
+__intrinsic_op($(kIROp_Mul))
+T __internal_mul<T, U>(U a, T b);
+
+__generic<T:IDifferentiable, let N : int>
+extension vector<T,N> : IDifferentiable
+{
+ // IDifferentiable
typedef vector<T, N> Differential;
[__unsafeForceInlineEarly]
@@ -2003,7 +2013,7 @@ extension vector<T,N> : IFloat
[BackwardDifferentiable]
static Differential dadd(Differential a, Differential b)
{
- return a + b;
+ return __internal_add(a, b);
}
__generic<U : __BuiltinRealType>
@@ -2011,7 +2021,7 @@ extension vector<T,N> : IFloat
[BackwardDifferentiable]
static Differential dmul(U a, Differential b)
{
- return __realCast<T, U>(a) * b;
+ return __internal_mul(__realCast<float>(a), b);
}
}
@@ -2042,7 +2052,11 @@ extension matrix<T,N,M,L> : IFloat
[__unsafeForceInlineEarly]
__implicit_conversion($(kConversionCost_ScalarToMatrix))
__init(float v) { this = matrix<T,N,M>(T(v)); }
+}
+__generic<T:IDifferentiable, let N : int, let M : int, let L : int>
+extension matrix<T,N,M,L> : IDifferentiable
+{
// IDifferentiable.
typedef matrix<T, N,M,L> Differential;
@@ -2057,15 +2071,15 @@ extension matrix<T,N,M,L> : IFloat
[BackwardDifferentiable]
static Differential dadd(Differential a, Differential b)
{
- return a + b;
+ return __internal_add(a, b);
}
-
+
__generic<U : __BuiltinRealType>
[__unsafeForceInlineEarly]
[BackwardDifferentiable]
static Differential dmul(U a, Differential b)
{
- return __realCast<T, U>(a) * b;
+ return __internal_mul(__realCast<float>(a), b);
}
}