diff options
Diffstat (limited to 'source')
| -rw-r--r-- | source/slang/diff.meta.slang | 27 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-transpose.h | 32 |
2 files changed, 56 insertions, 3 deletions
diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang index f2f1d0cc3..84a72a425 100644 --- a/source/slang/diff.meta.slang +++ b/source/slang/diff.meta.slang @@ -367,6 +367,7 @@ __generic<T : __BuiltinFloatingPointType, let N : int, let M : int> [ForceInline] [ForwardDerivativeOf(transpose)] [PreferRecompute] +[BackwardDifferentiable] DifferentialPair<matrix<T, M, N>> __d_transpose(DifferentialPair<matrix<T, N, M>> m) { return DifferentialPair<matrix<T, M, N>>(transpose(m.p), transpose(m.d)); @@ -376,6 +377,7 @@ __generic<T : __BuiltinFloatingPointType, let N : int, let M : int> [ForceInline] [BackwardDerivativeOf(transpose)] [PreferRecompute] +[BackwardDifferentiable] void __d_transpose(inout DifferentialPair<matrix<T, N, M>> m, matrix<T, M, N>.Differential dOut) { m = diffPair(m.p, transpose(dOut)); @@ -386,6 +388,7 @@ __generic<T : __BuiltinFloatingPointType, let N : int, let M : int> [ForceInline] [ForwardDerivativeOf(mul)] [PreferRecompute] +[BackwardDifferentiable] DifferentialPair<vector<T, M>> mul(DifferentialPair<vector<T, N>> left, DifferentialPair<matrix<T, N, M>> right) { let primal = mul(left.p, right.p); @@ -396,13 +399,16 @@ DifferentialPair<vector<T, M>> mul(DifferentialPair<vector<T, N>> left, Differen __generic<T : __BuiltinFloatingPointType, let N : int, let M : int> [BackwardDerivativeOf(mul)] [PreferRecompute] +[BackwardDifferentiable] void __d_mul(inout DifferentialPair<vector<T, N>> left, inout DifferentialPair<matrix<T, N, M>> right, vector<T, M>.Differential dOut) { vector<T, N>.Differential left_d_result; matrix<T, N, M>.Differential right_d_result; + [ForceUnroll] for (int i = 0; i < N; ++i) { T sum = T(0); + [ForceUnroll] for (int j = 0; j < M; ++j) { sum += right.p[i][j] * dOut[j]; @@ -419,6 +425,7 @@ __generic<T : __BuiltinFloatingPointType, let N : int, let M : int> [ForceInline] [ForwardDerivativeOf(mul)] [PreferRecompute] +[BackwardDifferentiable] DifferentialPair<vector<T,N>> mul(DifferentialPair<matrix<T,N,M>> 
left, DifferentialPair<vector<T,M>> right) { let primal = mul(left.p, right.p); @@ -429,13 +436,16 @@ DifferentialPair<vector<T,N>> mul(DifferentialPair<matrix<T,N,M>> left, Differen __generic<T : __BuiltinFloatingPointType, let N : int, let M : int> [BackwardDerivativeOf(mul)] [PreferRecompute] +[BackwardDifferentiable] void __d_mul(inout DifferentialPair<matrix<T, N, M>> left, inout DifferentialPair<vector<T, M>> right, vector<T, N>.Differential dOut) { matrix<T, N, M>.Differential left_d_result; vector<T, M>.Differential right_d_result; + [ForceUnroll] for (int j = 0; j < M; ++j) { T sum = T(0); + [ForceUnroll] for (int i = 0; i < N; ++i) { sum += left.p[i][j] * dOut[i]; @@ -452,6 +462,7 @@ __generic<T : __BuiltinFloatingPointType, let R : int, let N : int, let C : int> [ForceInline] [ForwardDerivativeOf(mul)] [PreferRecompute] +[BackwardDifferentiable] DifferentialPair<matrix<T,R,C>> mul(DifferentialPair<matrix<T,R,N>> left, DifferentialPair<matrix<T,N,C>> right) { let primal = mul(left.p, right.p); @@ -462,22 +473,30 @@ DifferentialPair<matrix<T,R,C>> mul(DifferentialPair<matrix<T,R,N>> left, Differ __generic<T : __BuiltinFloatingPointType, let R : int, let N : int, let C : int> [BackwardDerivativeOf(mul)] [PreferRecompute] +[BackwardDifferentiable] void mul(inout DifferentialPair<matrix<T, R, N>> left, inout DifferentialPair<matrix<T, N, C>> right, matrix<T, R, C>.Differential dOut) { matrix<T, R, N>.Differential left_d_result; + [ForceUnroll] for (int r = 0; r < R; ++r) + [ForceUnroll] for (int n = 0; n < N; ++n) left_d_result[r][n] = T(0.0); - + matrix<T, N, C>.Differential right_d_result; + [ForceUnroll] for (int n = 0; n < N; ++n) + [ForceUnroll] for (int c = 0; c < C; ++c) right_d_result[n][c] = T(0.0); - + + [ForceUnroll] for (int r = 0; r < R; ++r) { + [ForceUnroll] for (int c = 0; c < C; ++c) { + [ForceUnroll] for (int n = 0; n < N; ++n) { left_d_result[r][n] += right.p[n][c] * dOut[r][c]; @@ -493,6 +512,7 @@ void mul(inout DifferentialPair<matrix<T, 
R, N>> left, inout DifferentialPair<ma __generic<T : __BuiltinFloatingPointType, let N : int> [ForwardDerivativeOf(dot)] [PreferRecompute] +[BackwardDifferentiable] DifferentialPair<T> __d_dot(DifferentialPair<vector<T, N>> dpx, DifferentialPair<vector<T, N>> dpy) { T result = T(0); @@ -510,6 +530,7 @@ DifferentialPair<T> __d_dot(DifferentialPair<vector<T, N>> dpx, DifferentialPair __generic<T : __BuiltinFloatingPointType, let N : int> [BackwardDerivativeOf(dot)] [PreferRecompute] +[BackwardDifferentiable] void __d_dot(inout DifferentialPair<vector<T, N>> dpx, inout DifferentialPair<vector<T, N>> dpy, T.Differential dOut) { vector<T, N>.Differential x_d_result, y_d_result; @@ -527,6 +548,7 @@ void __d_dot(inout DifferentialPair<vector<T, N>> dpx, inout DifferentialPair<ve __generic<T : __BuiltinFloatingPointType> [ForwardDerivativeOf(cross)] [PreferRecompute] +[BackwardDifferentiable] DifferentialPair<vector<T, 3>> __d_cross(DifferentialPair<vector<T, 3>> a, DifferentialPair<vector<T, 3>> b) { /* @@ -555,6 +577,7 @@ DifferentialPair<vector<T, 3>> __d_cross(DifferentialPair<vector<T, 3>> a, Diffe __generic<T : __BuiltinFloatingPointType> [BackwardDerivativeOf(cross)] [PreferRecompute] +[BackwardDifferentiable] void __d_cross(inout DifferentialPair<vector<T, 3>> a, inout DifferentialPair<vector<T, 3>> b, vector<T, 3>.Differential dOut) { /* diff --git a/source/slang/slang-ir-autodiff-transpose.h b/source/slang/slang-ir-autodiff-transpose.h index e8cb821bd..c479ea6d1 100644 --- a/source/slang/slang-ir-autodiff-transpose.h +++ b/source/slang/slang-ir-autodiff-transpose.h @@ -2042,6 +2042,36 @@ struct DiffTransposePass } } + void safeSetInsertAfterInst(IRBuilder* builder, IRInst* inst) + { + // If the inst is in the first or second block of the parent function, then + // insert into the third block, otherwise simply call setInsertAfterOrdinaryInst. + // The second block is the block that the first block branches into unconditionally. 
+ // + if (auto block = as<IRBlock>(inst->getParent())) + { + auto firstBlock = cast<IRFunc>(block->getParent())->getFirstBlock(); + if (auto firstBranch = as<IRUnconditionalBranch>(firstBlock->getTerminator())) + { + auto secondBlock = firstBranch->getTargetBlock(); + + if (block == firstBlock || block == secondBlock) + { + if (auto branch = as<IRUnconditionalBranch>(secondBlock->getTerminator())) + { + if (auto ordInst = branch->getTargetBlock()->getFirstOrdinaryInst()) + builder->setInsertAfter(ordInst); + else + builder->setInsertInto(branch->getTargetBlock()); + + return; + } + } + } + } + setInsertAfterOrdinaryInst(builder, inst); + } + IRInst* promoteOperandsToTargetType(IRBuilder* builder, IRInst* fwdInst) { auto oldLoc = builder->getInsertLoc(); @@ -2060,7 +2090,7 @@ struct DiffTransposePass // Insert new operand just after the old operand, so we have the old // operands available. // - builder->setInsertAfter(operand); + safeSetInsertAfterInst(builder, operand); IRInst* newOperand = promoteToType(builder, targetType, operand); |
