diff options
| author | Yong He <yonghe@outlook.com> | 2023-03-22 21:16:35 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-03-22 21:16:35 -0700 |
| commit | 259a015feb9d4ab65e8fbba32f6c777e92780cc7 (patch) | |
| tree | 45bd4cb9217325c67f5a27d8562b0e7e6b79bb77 /source/slang/diff.meta.slang | |
| parent | d4f99c8bac8b28f18c864a717d8833db6a1c872d (diff) | |
Type legalization and autodiff bug fixes. (#2722)
* Bug fixes.
* Fix.
* Only perform autodiff for functions whose derivative is actually used.
* Fix loop optimize bug.
* Fix high order diff.
* Fix trivial diff func generation.
* Fixes.
* Cleanup.
---------
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/diff.meta.slang')
| -rw-r--r-- | source/slang/diff.meta.slang | 28 |
1 files changed, 16 insertions, 12 deletions
diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang index a9b8209f3..26a673512 100644 --- a/source/slang/diff.meta.slang +++ b/source/slang/diff.meta.slang @@ -1022,31 +1022,35 @@ __generic<T : __BuiltinFloatingPointType, let N : int> [__readNone] T __determinant_impl(matrix<T,N,N> m) { - if (N == 1) - return m[0][0]; - else if (N == 2) - return m[0][0] * m[1][1] - m[0][1] * m[1][0]; - else if (N == 3) + T result = T(0); + switch (N) { - return m[0][0] * (m[1][1] * m[2][2] - m[2][1] * m[1][2]) + case 1: + result = m[0][0]; + break; + case 2: + result = m[0][0] * m[1][1] - m[0][1] * m[1][0]; + break; + case 3: + result = m[0][0] * (m[1][1] * m[2][2] - m[2][1] * m[1][2]) - m[1][0] * (m[0][1] * m[2][2] - m[2][1] * m[0][2]) + m[2][0] * (m[0][1] * m[1][2] - m[1][1] * m[0][2]); - } - else if (N == 4) - { - T s00 = m[2][2] * m[3][3] - m[3][2] * m[2][3]; + break; + case 4: + T s00 = m[2][2] * m[3][3] - m[3][2] * m[2][3]; T s01 = m[2][1] * m[3][3] - m[3][1] * m[2][3]; T s02 = m[2][1] * m[3][2] - m[3][1] * m[2][2]; T s03 = m[2][0] * m[3][3] - m[3][0] * m[2][3]; T s04 = m[2][0] * m[3][2] - m[3][0] * m[2][2]; T s05 = m[2][0] * m[3][1] - m[3][0] * m[2][1]; - return m[0][0] * (m[1][1] * s00 - m[1][2] * s01 + m[1][3] * s02) + result = m[0][0] * (m[1][1] * s00 - m[1][2] * s01 + m[1][3] * s02) - m[0][1] * (m[1][0] * s00 - m[1][2] * s03 + m[1][3] * s04) + m[0][2] * (m[1][0] * s01 - m[1][1] * s03 + m[1][3] * s05) - m[0][3] * (m[1][0] * s02 - m[1][1] * s04 + m[1][2] * s05); + break; } - return T(0.0); + return result; } __generic<T : __BuiltinFloatingPointType, let N : int> [ForwardDerivativeOf(determinant)] |
