summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2022-12-08 14:56:20 -0800
committerGitHub <noreply@github.com>2022-12-08 14:56:20 -0800
commit41eb19e65a0974e23048bd7b3b1eb1e2f569b1d0 (patch)
treec6cde57da4d3415d86d09213936a48d3d26e07e1
parent468bb7ecf65c000c308adae511bf65a1ca4cc412 (diff)
Auto-diff for matrix operations. (#2559)
Co-authored-by: Yong He <yhe@nvidia.com>
-rw-r--r--source/slang/core.meta.slang6
-rw-r--r--source/slang/diff.meta.slang32
-rw-r--r--source/slang/hlsl.meta.slang111
-rw-r--r--source/slang/slang-emit-c-like.cpp20
-rw-r--r--tests/autodiff/matrix-arithmetic-fwd.slang41
-rw-r--r--tests/autodiff/matrix-arithmetic-fwd.slang.expected.txt5
6 files changed, 209 insertions, 6 deletions
diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang
index ed80b3730..edb98b3c2 100644
--- a/source/slang/core.meta.slang
+++ b/source/slang/core.meta.slang
@@ -93,7 +93,11 @@ interface __BuiltinArithmeticType : __BuiltinType
/// A type that can be used for logical/bitwise operations
[sealed]
[builtin]
-interface __BuiltinLogicalType : __BuiltinType {}
+interface __BuiltinLogicalType : __BuiltinType
+{
+ /// Initialize from a 32-bit signed integer value.
+ __init(int value);
+}
/// A type that logically has a sign (positive/negative/zero)
[sealed]
diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang
index a97ab9eaf..248112810 100644
--- a/source/slang/diff.meta.slang
+++ b/source/slang/diff.meta.slang
@@ -110,6 +110,38 @@ void updatePair<T : IDifferentiable>(inout DifferentialPair<T> p, T newPrimal, T
p = DifferentialPair<T>(newPrimal, newDiff);
}
+// vector-matrix
+__generic<T : __BuiltinFloatingPointType, let N : int, let M : int>
+__target_intrinsic(hlsl)
+__target_intrinsic(glsl, "($1 * $0)")
+DifferentialPair<vector<T, M>> mul(DifferentialPair<vector<T, N>> left, DifferentialPair<matrix<T, N, M>> right)
+{
+ let primal = mul(left.p, right.p);
+ let diff = mul(left.d, right.p) + mul(left.p, right.d);
+ return DifferentialPair<vector<T,M>>(primal, diff);
+}
+
+// matrix-vector
+__generic<T : __BuiltinFloatingPointType, let N : int, let M : int>
+[ForceInline]
+[ForwardDerivativeOf(mul)]
+DifferentialPair<vector<T,N>> mul(DifferentialPair<matrix<T,N,M>> left, DifferentialPair<vector<T,M>> right)
+{
+ let primal = mul(left.p, right.p);
+ let diff = mul(left.d, right.p) + mul(left.p, right.d);
+ return DifferentialPair<vector<T,N>>(primal, diff);
+}
+
+
+// matrix-matrix
+__generic<T : __BuiltinFloatingPointType, let R : int, let N : int, let C : int>
+[ForwardDerivativeOf(mul)]
+DifferentialPair<matrix<T,R,C>> mul(DifferentialPair<matrix<T,R,N>> right, DifferentialPair<matrix<T,N,C>> left)
+{
+ let primal = mul(right.p, left.p);
+ let diff = mul(right.d, left.p) + mul(right.p, left.d);
+ return DifferentialPair<matrix<T,R,C>>(primal, diff);
+}
#define VECTOR_MAP_D_UNARY(TYPE, COUNT, D_FUNC, VALUE) \
vector<TYPE, COUNT> result; \
diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang
index 0232129ee..ed150a4c0 100644
--- a/source/slang/hlsl.meta.slang
+++ b/source/slang/hlsl.meta.slang
@@ -2715,7 +2715,24 @@ T mul(vector<T, N> x, vector<T, N> y)
}
// vector-matrix
-__generic<T : __BuiltinArithmeticType, let N : int, let M : int>
+__generic<T : __BuiltinFloatingPointType, let N : int, let M : int>
+__target_intrinsic(hlsl)
+__target_intrinsic(glsl, "($1 * $0)")
+vector<T, M> mul(vector<T, N> left, matrix<T, N, M> right)
+{
+ vector<T,M> result;
+ for( int j = 0; j < M; ++j )
+ {
+ T sum = T(0);
+ for( int i = 0; i < N; ++i )
+ {
+ sum += left[i] * right[i][j];
+ }
+ result[j] = sum;
+ }
+ return result;
+}
+__generic<T : __BuiltinIntegerType, let N : int, let M : int>
__target_intrinsic(hlsl)
__target_intrinsic(glsl, "($1 * $0)")
vector<T, M> mul(vector<T, N> left, matrix<T, N, M> right)
@@ -2732,9 +2749,26 @@ vector<T, M> mul(vector<T, N> left, matrix<T, N, M> right)
}
return result;
}
+__generic<T : __BuiltinLogicalType, let N : int, let M : int>
+__target_intrinsic(hlsl)
+__target_intrinsic(glsl, "($1 * $0)")
+vector<T, M> mul(vector<T, N> left, matrix<T, N, M> right)
+{
+ vector<T,M> result;
+ for( int j = 0; j < M; ++j )
+ {
+ T sum = T(0);
+ for( int i = 0; i < N; ++i )
+ {
+ sum |= left[i] & right[i][j];
+ }
+ result[j] = sum;
+ }
+ return result;
+}
// matrix-vector
-__generic<T : __BuiltinArithmeticType, let N : int, let M : int>
+__generic<T : __BuiltinFloatingPointType, let N : int, let M : int>
__target_intrinsic(hlsl)
__target_intrinsic(glsl, "($1 * $0)")
vector<T,N> mul(matrix<T,N,M> left, vector<T,M> right)
@@ -2751,10 +2785,43 @@ vector<T,N> mul(matrix<T,N,M> left, vector<T,M> right)
}
return result;
}
-
+__generic<T : __BuiltinIntegerType, let N : int, let M : int>
+__target_intrinsic(hlsl)
+__target_intrinsic(glsl, "($1 * $0)")
+vector<T,N> mul(matrix<T,N,M> left, vector<T,M> right)
+{
+ vector<T,N> result;
+ for( int i = 0; i < N; ++i )
+ {
+ T sum = T(0);
+ for( int j = 0; j < M; ++j )
+ {
+ sum += left[i][j] * right[j];
+ }
+ result[i] = sum;
+ }
+ return result;
+}
+__generic<T : __BuiltinLogicalType, let N : int, let M : int>
+__target_intrinsic(hlsl)
+__target_intrinsic(glsl, "($1 * $0)")
+vector<T,N> mul(matrix<T,N,M> left, vector<T,M> right)
+{
+ vector<T,N> result;
+ for( int i = 0; i < N; ++i )
+ {
+ T sum = T(0);
+ for( int j = 0; j < M; ++j )
+ {
+ sum |= left[i][j] & right[j];
+ }
+ result[i] = sum;
+ }
+ return result;
+}
// matrix-matrix
-__generic<T : __BuiltinArithmeticType, let R : int, let N : int, let C : int>
+__generic<T : __BuiltinFloatingPointType, let R : int, let N : int, let C : int>
__target_intrinsic(hlsl)
__target_intrinsic(glsl, "($1 * $0)")
matrix<T,R,C> mul(matrix<T,R,N> right, matrix<T,N,C> left)
@@ -2772,6 +2839,42 @@ matrix<T,R,C> mul(matrix<T,R,N> right, matrix<T,N,C> left)
}
return result;
}
+__generic<T : __BuiltinIntegerType, let R : int, let N : int, let C : int>
+__target_intrinsic(hlsl)
+__target_intrinsic(glsl, "($1 * $0)")
+matrix<T,R,C> mul(matrix<T,R,N> right, matrix<T,N,C> left)
+{
+ matrix<T,R,C> result;
+ for( int r = 0; r < R; ++r)
+ for( int c = 0; c < C; ++c)
+ {
+ T sum = T(0);
+ for( int i = 0; i < N; ++i )
+ {
+ sum += left[r][i] * right[i][c];
+ }
+ result[r][c] = sum;
+ }
+ return result;
+}
+__generic<T : __BuiltinLogicalType, let R : int, let N : int, let C : int>
+__target_intrinsic(hlsl)
+__target_intrinsic(glsl, "($1 * $0)")
+matrix<T,R,C> mul(matrix<T,R,N> right, matrix<T,N,C> left)
+{
+ matrix<T,R,C> result;
+ for( int r = 0; r < R; ++r)
+ for( int c = 0; c < C; ++c)
+ {
+ T sum = T(0);
+ for( int i = 0; i < N; ++i )
+ {
+ sum |= left[r][i] & right[i][c];
+ }
+ result[r][c] = sum;
+ }
+ return result;
+}
// noise (deprecated)
diff --git a/source/slang/slang-emit-c-like.cpp b/source/slang/slang-emit-c-like.cpp
index cadab5690..dd40b7856 100644
--- a/source/slang/slang-emit-c-like.cpp
+++ b/source/slang/slang-emit-c-like.cpp
@@ -1778,7 +1778,6 @@ void CLikeSourceEmitter::defaultEmitInstExpr(IRInst* inst, const EmitOpInfo& inO
case kIROp_MakeVector:
case kIROp_MakeMatrix:
- case kIROp_MakeMatrixFromScalar:
case kIROp_MatrixReshape:
case kIROp_VectorReshape:
case kIROp_CastFloatToInt:
@@ -1789,6 +1788,25 @@ void CLikeSourceEmitter::defaultEmitInstExpr(IRInst* inst, const EmitOpInfo& inO
emitType(inst->getDataType());
emitArgs(inst);
break;
+ case kIROp_MakeMatrixFromScalar:
+ {
+ emitType(inst->getDataType());
+ auto matrixType = as<IRMatrixType>(inst->getDataType());
+ SLANG_RELEASE_ASSERT(matrixType);
+ auto columnCount = as<IRIntLit>(matrixType->getColumnCount());
+ SLANG_RELEASE_ASSERT(columnCount);
+ auto rowCount = as<IRIntLit>(matrixType->getRowCount());
+ SLANG_RELEASE_ASSERT(rowCount);
+ m_writer->emit("(");
+ for (IRIntegerValue i = 0; i < rowCount->getValue() * columnCount->getValue(); i++)
+ {
+ if (i != 0)
+ m_writer->emit(", ");
+ emitOperand(inst->getOperand(0), getInfo(EmitOp::General));
+ }
+ m_writer->emit(")");
+ }
+ break;
case kIROp_AllocObj:
m_writer->emit("new ");
m_writer->emit(getName(inst->getDataType()));
diff --git a/tests/autodiff/matrix-arithmetic-fwd.slang b/tests/autodiff/matrix-arithmetic-fwd.slang
new file mode 100644
index 000000000..7a953cef8
--- /dev/null
+++ b/tests/autodiff/matrix-arithmetic-fwd.slang
@@ -0,0 +1,41 @@
+//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type
+//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type
+
+//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer
+RWStructuredBuffer<float> outputBuffer;
+
+[ForwardDifferentiable]
+float3x3 g(float3x3 x, float3x3 y)
+{
+ float3x3 a = x + y;
+ float3x3 b = x - y;
+ return a * b + 2 * x * y;
+}
+
+[ForwardDifferentiable]
+float h(float2x2 x, float2x2 y)
+{
+ let t = mul(x, y);
+ return t[0][0] + t[0][1] + t[1][0] + t[1][1];
+}
+
+[numthreads(1, 1, 1)]
+void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ float3x3 a = float3x3(2.0);
+ float3x3 b = float3x3(1.5);
+ float3x3 da = float3x3(1.0);
+
+ outputBuffer[0] = __fwd_diff(g)(
+ DifferentialPair<float3x3>(a, da),
+ DifferentialPair<float3x3>(b, da)).d._11; // Expect: 8
+
+ float2x2 l = float2x2(1.0, 2.0, 3.0, 4.0);
+ float2x2 r = float2x2(10.0, 11.0, 12.0, 13.0);
+ float2x2 d = float2x2(1.0, 0.0, 1.0, 1.0);
+
+ //float2x2 epsilon = d * 0.001f;
+ //outputBuffer[1] = (h(l + epsilon, r + epsilon) - h(l - epsilon, r - epsilon)) / (epsilon[0][0] * 2.0));
+
+ outputBuffer[1] = __fwd_diff(h)(DifferentialPair<float2x2>(l, d), DifferentialPair<float2x2>(r, d)).d; // Expect 83.0
+}
diff --git a/tests/autodiff/matrix-arithmetic-fwd.slang.expected.txt b/tests/autodiff/matrix-arithmetic-fwd.slang.expected.txt
new file mode 100644
index 000000000..c595048c3
--- /dev/null
+++ b/tests/autodiff/matrix-arithmetic-fwd.slang.expected.txt
@@ -0,0 +1,5 @@
+type: float
+8.0
+83.0
+0.0
+0.0 \ No newline at end of file