summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorSai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com>2023-05-04 16:43:40 -0400
committerGitHub <noreply@github.com>2023-05-04 13:43:40 -0700
commita9444925750da2498f456a626f1b164d68efedf1 (patch)
treeffdbf4c28b4ca01f50a74f45ef27834482b209c6
parentab3ac985479132856613d6dcc8b43f4a4ef8c6b7 (diff)
Fix issue with out-of-order insts during type promotion when transposing code. (#2866)
* Bugfixes for warped-area reparameterization * Update slang-ir-autodiff-transpose.h * Update slang-ir-autodiff-transpose.h * Mark all stdlib methods backward differentiable * Update diff.meta.slang --------- Co-authored-by: Yong He <yonghe@outlook.com>
-rw-r--r--source/slang/diff.meta.slang27
-rw-r--r--source/slang/slang-ir-autodiff-transpose.h32
2 files changed, 56 insertions, 3 deletions
diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang
index f2f1d0cc3..84a72a425 100644
--- a/source/slang/diff.meta.slang
+++ b/source/slang/diff.meta.slang
@@ -367,6 +367,7 @@ __generic<T : __BuiltinFloatingPointType, let N : int, let M : int>
[ForceInline]
[ForwardDerivativeOf(transpose)]
[PreferRecompute]
+[BackwardDifferentiable]
DifferentialPair<matrix<T, M, N>> __d_transpose(DifferentialPair<matrix<T, N, M>> m)
{
return DifferentialPair<matrix<T, M, N>>(transpose(m.p), transpose(m.d));
@@ -376,6 +377,7 @@ __generic<T : __BuiltinFloatingPointType, let N : int, let M : int>
[ForceInline]
[BackwardDerivativeOf(transpose)]
[PreferRecompute]
+[BackwardDifferentiable]
void __d_transpose(inout DifferentialPair<matrix<T, N, M>> m, matrix<T, M, N>.Differential dOut)
{
m = diffPair(m.p, transpose(dOut));
@@ -386,6 +388,7 @@ __generic<T : __BuiltinFloatingPointType, let N : int, let M : int>
[ForceInline]
[ForwardDerivativeOf(mul)]
[PreferRecompute]
+[BackwardDifferentiable]
DifferentialPair<vector<T, M>> mul(DifferentialPair<vector<T, N>> left, DifferentialPair<matrix<T, N, M>> right)
{
let primal = mul(left.p, right.p);
@@ -396,13 +399,16 @@ DifferentialPair<vector<T, M>> mul(DifferentialPair<vector<T, N>> left, Differen
__generic<T : __BuiltinFloatingPointType, let N : int, let M : int>
[BackwardDerivativeOf(mul)]
[PreferRecompute]
+[BackwardDifferentiable]
void __d_mul(inout DifferentialPair<vector<T, N>> left, inout DifferentialPair<matrix<T, N, M>> right, vector<T, M>.Differential dOut)
{
vector<T, N>.Differential left_d_result;
matrix<T, N, M>.Differential right_d_result;
+ [ForceUnroll]
for (int i = 0; i < N; ++i)
{
T sum = T(0);
+ [ForceUnroll]
for (int j = 0; j < M; ++j)
{
sum += right.p[i][j] * dOut[j];
@@ -419,6 +425,7 @@ __generic<T : __BuiltinFloatingPointType, let N : int, let M : int>
[ForceInline]
[ForwardDerivativeOf(mul)]
[PreferRecompute]
+[BackwardDifferentiable]
DifferentialPair<vector<T,N>> mul(DifferentialPair<matrix<T,N,M>> left, DifferentialPair<vector<T,M>> right)
{
let primal = mul(left.p, right.p);
@@ -429,13 +436,16 @@ DifferentialPair<vector<T,N>> mul(DifferentialPair<matrix<T,N,M>> left, Differen
__generic<T : __BuiltinFloatingPointType, let N : int, let M : int>
[BackwardDerivativeOf(mul)]
[PreferRecompute]
+[BackwardDifferentiable]
void __d_mul(inout DifferentialPair<matrix<T, N, M>> left, inout DifferentialPair<vector<T, M>> right, vector<T, N>.Differential dOut)
{
matrix<T, N, M>.Differential left_d_result;
vector<T, M>.Differential right_d_result;
+ [ForceUnroll]
for (int j = 0; j < M; ++j)
{
T sum = T(0);
+ [ForceUnroll]
for (int i = 0; i < N; ++i)
{
sum += left.p[i][j] * dOut[i];
@@ -452,6 +462,7 @@ __generic<T : __BuiltinFloatingPointType, let R : int, let N : int, let C : int>
[ForceInline]
[ForwardDerivativeOf(mul)]
[PreferRecompute]
+[BackwardDifferentiable]
DifferentialPair<matrix<T,R,C>> mul(DifferentialPair<matrix<T,R,N>> left, DifferentialPair<matrix<T,N,C>> right)
{
let primal = mul(left.p, right.p);
@@ -462,22 +473,30 @@ DifferentialPair<matrix<T,R,C>> mul(DifferentialPair<matrix<T,R,N>> left, Differ
__generic<T : __BuiltinFloatingPointType, let R : int, let N : int, let C : int>
[BackwardDerivativeOf(mul)]
[PreferRecompute]
+[BackwardDifferentiable]
void mul(inout DifferentialPair<matrix<T, R, N>> left, inout DifferentialPair<matrix<T, N, C>> right, matrix<T, R, C>.Differential dOut)
{
matrix<T, R, N>.Differential left_d_result;
+ [ForceUnroll]
for (int r = 0; r < R; ++r)
+ [ForceUnroll]
for (int n = 0; n < N; ++n)
left_d_result[r][n] = T(0.0);
-
+
matrix<T, N, C>.Differential right_d_result;
+ [ForceUnroll]
for (int n = 0; n < N; ++n)
+ [ForceUnroll]
for (int c = 0; c < C; ++c)
right_d_result[n][c] = T(0.0);
-
+
+ [ForceUnroll]
for (int r = 0; r < R; ++r)
{
+ [ForceUnroll]
for (int c = 0; c < C; ++c)
{
+ [ForceUnroll]
for (int n = 0; n < N; ++n)
{
left_d_result[r][n] += right.p[n][c] * dOut[r][c];
@@ -493,6 +512,7 @@ void mul(inout DifferentialPair<matrix<T, R, N>> left, inout DifferentialPair<ma
__generic<T : __BuiltinFloatingPointType, let N : int>
[ForwardDerivativeOf(dot)]
[PreferRecompute]
+[BackwardDifferentiable]
DifferentialPair<T> __d_dot(DifferentialPair<vector<T, N>> dpx, DifferentialPair<vector<T, N>> dpy)
{
T result = T(0);
@@ -510,6 +530,7 @@ DifferentialPair<T> __d_dot(DifferentialPair<vector<T, N>> dpx, DifferentialPair
__generic<T : __BuiltinFloatingPointType, let N : int>
[BackwardDerivativeOf(dot)]
[PreferRecompute]
+[BackwardDifferentiable]
void __d_dot(inout DifferentialPair<vector<T, N>> dpx, inout DifferentialPair<vector<T, N>> dpy, T.Differential dOut)
{
vector<T, N>.Differential x_d_result, y_d_result;
@@ -527,6 +548,7 @@ void __d_dot(inout DifferentialPair<vector<T, N>> dpx, inout DifferentialPair<ve
__generic<T : __BuiltinFloatingPointType>
[ForwardDerivativeOf(cross)]
[PreferRecompute]
+[BackwardDifferentiable]
DifferentialPair<vector<T, 3>> __d_cross(DifferentialPair<vector<T, 3>> a, DifferentialPair<vector<T, 3>> b)
{
/*
@@ -555,6 +577,7 @@ DifferentialPair<vector<T, 3>> __d_cross(DifferentialPair<vector<T, 3>> a, Diffe
__generic<T : __BuiltinFloatingPointType>
[BackwardDerivativeOf(cross)]
[PreferRecompute]
+[BackwardDifferentiable]
void __d_cross(inout DifferentialPair<vector<T, 3>> a, inout DifferentialPair<vector<T, 3>> b, vector<T, 3>.Differential dOut)
{
/*
diff --git a/source/slang/slang-ir-autodiff-transpose.h b/source/slang/slang-ir-autodiff-transpose.h
index e8cb821bd..c479ea6d1 100644
--- a/source/slang/slang-ir-autodiff-transpose.h
+++ b/source/slang/slang-ir-autodiff-transpose.h
@@ -2042,6 +2042,36 @@ struct DiffTransposePass
}
}
+ void safeSetInsertAfterInst(IRBuilder* builder, IRInst* inst)
+ {
+ // If the inst is in the first or second block of the parent function, then
+ // insert into the third block, otherwise simply call setInsertAfterOrdinaryInst.
+ // The second block is the block that the first block branches into unconditionaly.
+ //
+ if (auto block = as<IRBlock>(inst->getParent()))
+ {
+ auto firstBlock = cast<IRFunc>(block->getParent())->getFirstBlock();
+ if (auto firstBranch = as<IRUnconditionalBranch>(firstBlock->getTerminator()))
+ {
+ auto secondBlock = firstBranch->getTargetBlock();
+
+ if (block == firstBlock || block == secondBlock)
+ {
+ if (auto branch = as<IRUnconditionalBranch>(secondBlock->getTerminator()))
+ {
+ if (auto ordInst = branch->getTargetBlock()->getFirstOrdinaryInst())
+ builder->setInsertAfter(ordInst);
+ else
+ builder->setInsertInto(branch->getTargetBlock());
+
+ return;
+ }
+ }
+ }
+ }
+ setInsertAfterOrdinaryInst(builder, inst);
+ }
+
IRInst* promoteOperandsToTargetType(IRBuilder* builder, IRInst* fwdInst)
{
auto oldLoc = builder->getInsertLoc();
@@ -2060,7 +2090,7 @@ struct DiffTransposePass
// Insert new operand just after the old operand, so we have the old
// operands available.
//
- builder->setInsertAfter(operand);
+ safeSetInsertAfterInst(builder, operand);
IRInst* newOperand = promoteToType(builder, targetType, operand);