summaryrefslogtreecommitdiffstats
path: root/source/slang/diff.meta.slang
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-03-22 21:16:35 -0700
committerGitHub <noreply@github.com>2023-03-22 21:16:35 -0700
commit259a015feb9d4ab65e8fbba32f6c777e92780cc7 (patch)
tree45bd4cb9217325c67f5a27d8562b0e7e6b79bb77 /source/slang/diff.meta.slang
parentd4f99c8bac8b28f18c864a717d8833db6a1c872d (diff)
Type legalization and autodiff bug fixes. (#2722)
* Bug fixes. * Fix. * Only perform autodiff for functions whose derivative is actually used. * Fix loop optimize bug. * Fix high order diff. * Fix trivial diff func generation. * Fixes. * Cleanup. --------- Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/diff.meta.slang')
-rw-r--r--source/slang/diff.meta.slang28
1 files changed, 16 insertions, 12 deletions
diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang
index a9b8209f3..26a673512 100644
--- a/source/slang/diff.meta.slang
+++ b/source/slang/diff.meta.slang
@@ -1022,31 +1022,35 @@ __generic<T : __BuiltinFloatingPointType, let N : int>
[__readNone]
T __determinant_impl(matrix<T,N,N> m)
{
- if (N == 1)
- return m[0][0];
- else if (N == 2)
- return m[0][0] * m[1][1] - m[0][1] * m[1][0];
- else if (N == 3)
+ T result = T(0);
+ switch (N)
{
- return m[0][0] * (m[1][1] * m[2][2] - m[2][1] * m[1][2])
+ case 1:
+ result = m[0][0];
+ break;
+ case 2:
+ result = m[0][0] * m[1][1] - m[0][1] * m[1][0];
+ break;
+ case 3:
+ result = m[0][0] * (m[1][1] * m[2][2] - m[2][1] * m[1][2])
- m[1][0] * (m[0][1] * m[2][2] - m[2][1] * m[0][2])
+ m[2][0] * (m[0][1] * m[1][2] - m[1][1] * m[0][2]);
- }
- else if (N == 4)
- {
- T s00 = m[2][2] * m[3][3] - m[3][2] * m[2][3];
+ break;
+ case 4:
+ T s00 = m[2][2] * m[3][3] - m[3][2] * m[2][3];
T s01 = m[2][1] * m[3][3] - m[3][1] * m[2][3];
T s02 = m[2][1] * m[3][2] - m[3][1] * m[2][2];
T s03 = m[2][0] * m[3][3] - m[3][0] * m[2][3];
T s04 = m[2][0] * m[3][2] - m[3][0] * m[2][2];
T s05 = m[2][0] * m[3][1] - m[3][0] * m[2][1];
- return m[0][0] * (m[1][1] * s00 - m[1][2] * s01 + m[1][3] * s02)
+ result = m[0][0] * (m[1][1] * s00 - m[1][2] * s01 + m[1][3] * s02)
- m[0][1] * (m[1][0] * s00 - m[1][2] * s03 + m[1][3] * s04)
+ m[0][2] * (m[1][0] * s01 - m[1][1] * s03 + m[1][3] * s05)
- m[0][3] * (m[1][0] * s02 - m[1][1] * s04 + m[1][2] * s05);
+ break;
}
- return T(0.0);
+ return result;
}
__generic<T : __BuiltinFloatingPointType, let N : int>
[ForwardDerivativeOf(determinant)]