diff options
| author | kaizhangNV <149626564+kaizhangNV@users.noreply.github.com> | 2025-02-05 12:12:49 -0600 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-02-05 10:12:49 -0800 |
| commit | 4b350645042b8e8fbdad19784ee745d11c7bc616 (patch) | |
| tree | acc7cb6b3cc82d71acf66ab2d4d052598aa72535 /source | |
| parent | d3e5f39cafbf1d65cf93cdd42c20c472c68197a2 (diff) | |
Fix autodiff issue for vector<T, N> (#6275)
* Fix autodiff issue for vector<T, N>
Close #6154
We didn't implement correctly for vector<T, N> regarding the
differentiablity. As we check differentiable before specialization,
however according to the definition of vector, it has to be specialized
to IFloat to know it's conformed to IDifferential type. Therefore for
parameter type vector<T, N> will become no_diff.
Therefore, we change the implementation a to make it explicit conform
to IDifferential type.
* fix typo
Diffstat (limited to 'source')
| -rw-r--r-- | source/slang/core.meta.slang | 26 |
1 files changed, 20 insertions, 6 deletions
diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang index 36e9d6885..19421735c 100644 --- a/source/slang/core.meta.slang +++ b/source/slang/core.meta.slang @@ -1988,8 +1988,18 @@ extension vector<T,N> : IFloat [OverloadRank(-1)] [__unsafeForceInlineEarly] __init(float v) { this = vector<T,N>(T(v)); } - // IDifferentiable +} + +__intrinsic_op($(kIROp_Add)) +T __internal_add<T>(T a, T b); +__intrinsic_op($(kIROp_Mul)) +T __internal_mul<T, U>(U a, T b); + +__generic<T:IDifferentiable, let N : int> +extension vector<T,N> : IDifferentiable +{ + // IDifferentiable typedef vector<T, N> Differential; [__unsafeForceInlineEarly] @@ -2003,7 +2013,7 @@ extension vector<T,N> : IFloat [BackwardDifferentiable] static Differential dadd(Differential a, Differential b) { - return a + b; + return __internal_add(a, b); } __generic<U : __BuiltinRealType> @@ -2011,7 +2021,7 @@ extension vector<T,N> : IFloat [BackwardDifferentiable] static Differential dmul(U a, Differential b) { - return __realCast<T, U>(a) * b; + return __internal_mul(__realCast<float>(a), b); } } @@ -2042,7 +2052,11 @@ extension matrix<T,N,M,L> : IFloat [__unsafeForceInlineEarly] __implicit_conversion($(kConversionCost_ScalarToMatrix)) __init(float v) { this = matrix<T,N,M>(T(v)); } +} +__generic<T:IDifferentiable, let N : int, let M : int, let L : int> +extension matrix<T,N,M,L> : IDifferentiable +{ // IDifferentiable. typedef matrix<T, N,M,L> Differential; @@ -2057,15 +2071,15 @@ extension matrix<T,N,M,L> : IFloat [BackwardDifferentiable] static Differential dadd(Differential a, Differential b) { - return a + b; + return __internal_add(a, b); } - + __generic<U : __BuiltinRealType> [__unsafeForceInlineEarly] [BackwardDifferentiable] static Differential dmul(U a, Differential b) { - return __realCast<T, U>(a) * b; + return __internal_mul(__realCast<float>(a), b); } } |
