diff options
| author | Yong He <yonghe@outlook.com> | 2022-12-08 14:56:20 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2022-12-08 14:56:20 -0800 |
| commit | 41eb19e65a0974e23048bd7b3b1eb1e2f569b1d0 (patch) | |
| tree | c6cde57da4d3415d86d09213936a48d3d26e07e1 | |
| parent | 468bb7ecf65c000c308adae511bf65a1ca4cc412 (diff) | |
Auto-diff for matrix operations. (#2559)
Co-authored-by: Yong He <yhe@nvidia.com>
| -rw-r--r-- | source/slang/core.meta.slang | 6 | ||||
| -rw-r--r-- | source/slang/diff.meta.slang | 32 | ||||
| -rw-r--r-- | source/slang/hlsl.meta.slang | 111 | ||||
| -rw-r--r-- | source/slang/slang-emit-c-like.cpp | 20 | ||||
| -rw-r--r-- | tests/autodiff/matrix-arithmetic-fwd.slang | 41 | ||||
| -rw-r--r-- | tests/autodiff/matrix-arithmetic-fwd.slang.expected.txt | 5 |
6 files changed, 209 insertions, 6 deletions
diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang index ed80b3730..edb98b3c2 100644 --- a/source/slang/core.meta.slang +++ b/source/slang/core.meta.slang @@ -93,7 +93,11 @@ interface __BuiltinArithmeticType : __BuiltinType /// A type that can be used for logical/bitwise operations [sealed] [builtin] -interface __BuiltinLogicalType : __BuiltinType {} +interface __BuiltinLogicalType : __BuiltinType +{ + /// Initialize from a 32-bit signed integer value. + __init(int value); +} /// A type that logically has a sign (positive/negative/zero) [sealed] diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang index a97ab9eaf..248112810 100644 --- a/source/slang/diff.meta.slang +++ b/source/slang/diff.meta.slang @@ -110,6 +110,38 @@ void updatePair<T : IDifferentiable>(inout DifferentialPair<T> p, T newPrimal, T p = DifferentialPair<T>(newPrimal, newDiff); } +// vector-matrix +__generic<T : __BuiltinFloatingPointType, let N : int, let M : int> +__target_intrinsic(hlsl) +__target_intrinsic(glsl, "($1 * $0)") +DifferentialPair<vector<T, M>> mul(DifferentialPair<vector<T, N>> left, DifferentialPair<matrix<T, N, M>> right) +{ + let primal = mul(left.p, right.p); + let diff = mul(left.d, right.p) + mul(left.p, right.d); + return DifferentialPair<vector<T,M>>(primal, diff); +} + +// matrix-vector +__generic<T : __BuiltinFloatingPointType, let N : int, let M : int> +[ForceInline] +[ForwardDerivativeOf(mul)] +DifferentialPair<vector<T,N>> mul(DifferentialPair<matrix<T,N,M>> left, DifferentialPair<vector<T,M>> right) +{ + let primal = mul(left.p, right.p); + let diff = mul(left.d, right.p) + mul(left.p, right.d); + return DifferentialPair<vector<T,N>>(primal, diff); +} + + +// matrix-matrix +__generic<T : __BuiltinFloatingPointType, let R : int, let N : int, let C : int> +[ForwardDerivativeOf(mul)] +DifferentialPair<matrix<T,R,C>> mul(DifferentialPair<matrix<T,R,N>> right, DifferentialPair<matrix<T,N,C>> left) +{ + let primal = mul(right.p, left.p); + let diff = mul(right.d, left.p) + mul(right.p, left.d); + return DifferentialPair<matrix<T,R,C>>(primal, diff); +} #define VECTOR_MAP_D_UNARY(TYPE, COUNT, D_FUNC, VALUE) \ vector<TYPE, COUNT> result; \ diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang index 0232129ee..ed150a4c0 100644 --- a/source/slang/hlsl.meta.slang +++ b/source/slang/hlsl.meta.slang @@ -2715,7 +2715,24 @@ T mul(vector<T, N> x, vector<T, N> y) } // vector-matrix -__generic<T : __BuiltinArithmeticType, let N : int, let M : int> +__generic<T : __BuiltinFloatingPointType, let N : int, let M : int> +__target_intrinsic(hlsl) +__target_intrinsic(glsl, "($1 * $0)") +vector<T, M> mul(vector<T, N> left, matrix<T, N, M> right) +{ + vector<T,M> result; + for( int j = 0; j < M; ++j ) + { + T sum = T(0); + for( int i = 0; i < N; ++i ) + { + sum += left[i] * right[i][j]; + } + result[j] = sum; + } + return result; +} +__generic<T : __BuiltinIntegerType, let N : int, let M : int> __target_intrinsic(hlsl) __target_intrinsic(glsl, "($1 * $0)") vector<T, M> mul(vector<T, N> left, matrix<T, N, M> right) @@ -2732,9 +2749,26 @@ vector<T, M> mul(vector<T, N> left, matrix<T, N, M> right) } return result; } +__generic<T : __BuiltinLogicalType, let N : int, let M : int> +__target_intrinsic(hlsl) +__target_intrinsic(glsl, "($1 * $0)") +vector<T, M> mul(vector<T, N> left, matrix<T, N, M> right) +{ + vector<T,M> result; + for( int j = 0; j < M; ++j ) + { + T sum = T(0); + for( int i = 0; i < N; ++i ) + { + sum |= left[i] & right[i][j]; + } + result[j] = sum; + } + return result; +} // matrix-vector -__generic<T : __BuiltinArithmeticType, let N : int, let M : int> +__generic<T : __BuiltinFloatingPointType, let N : int, let M : int> __target_intrinsic(hlsl) __target_intrinsic(glsl, "($1 * $0)") vector<T,N> mul(matrix<T,N,M> left, vector<T,M> right) @@ -2751,10 +2785,43 @@ vector<T,N> mul(matrix<T,N,M> left, vector<T,M> right) } return result; } - +__generic<T : __BuiltinIntegerType, let N : int, let M : int> +__target_intrinsic(hlsl) +__target_intrinsic(glsl, "($1 * $0)") +vector<T,N> mul(matrix<T,N,M> left, vector<T,M> right) +{ + vector<T,N> result; + for( int i = 0; i < N; ++i ) + { + T sum = T(0); + for( int j = 0; j < M; ++j ) + { + sum += left[i][j] * right[j]; + } + result[i] = sum; + } + return result; +} +__generic<T : __BuiltinLogicalType, let N : int, let M : int> +__target_intrinsic(hlsl) +__target_intrinsic(glsl, "($1 * $0)") +vector<T,N> mul(matrix<T,N,M> left, vector<T,M> right) +{ + vector<T,N> result; + for( int i = 0; i < N; ++i ) + { + T sum = T(0); + for( int j = 0; j < M; ++j ) + { + sum |= left[i][j] & right[j]; + } + result[i] = sum; + } + return result; +} // matrix-matrix -__generic<T : __BuiltinArithmeticType, let R : int, let N : int, let C : int> +__generic<T : __BuiltinFloatingPointType, let R : int, let N : int, let C : int> __target_intrinsic(hlsl) __target_intrinsic(glsl, "($1 * $0)") matrix<T,R,C> mul(matrix<T,R,N> right, matrix<T,N,C> left) @@ -2772,6 +2839,42 @@ matrix<T,R,C> mul(matrix<T,R,N> right, matrix<T,N,C> left) } return result; } +__generic<T : __BuiltinIntegerType, let R : int, let N : int, let C : int> +__target_intrinsic(hlsl) +__target_intrinsic(glsl, "($1 * $0)") +matrix<T,R,C> mul(matrix<T,R,N> right, matrix<T,N,C> left) +{ + matrix<T,R,C> result; + for( int r = 0; r < R; ++r) + for( int c = 0; c < C; ++c) + { + T sum = T(0); + for( int i = 0; i < N; ++i ) + { + sum += left[r][i] * right[i][c]; + } + result[r][c] = sum; + } + return result; +} +__generic<T : __BuiltinLogicalType, let R : int, let N : int, let C : int> +__target_intrinsic(hlsl) +__target_intrinsic(glsl, "($1 * $0)") +matrix<T,R,C> mul(matrix<T,R,N> right, matrix<T,N,C> left) +{ + matrix<T,R,C> result; + for( int r = 0; r < R; ++r) + for( int c = 0; c < C; ++c) + { + T sum = T(0); + for( int i = 0; i < N; ++i ) + { + sum |= left[r][i] & right[i][c]; + } + result[r][c] = sum; + } + return result; +} // noise (deprecated) diff --git a/source/slang/slang-emit-c-like.cpp b/source/slang/slang-emit-c-like.cpp index cadab5690..dd40b7856 100644 --- a/source/slang/slang-emit-c-like.cpp +++ b/source/slang/slang-emit-c-like.cpp @@ -1778,7 +1778,6 @@ void CLikeSourceEmitter::defaultEmitInstExpr(IRInst* inst, const EmitOpInfo& inO case kIROp_MakeVector: case kIROp_MakeMatrix: - case kIROp_MakeMatrixFromScalar: case kIROp_MatrixReshape: case kIROp_VectorReshape: case kIROp_CastFloatToInt: @@ -1789,6 +1788,25 @@ void CLikeSourceEmitter::defaultEmitInstExpr(IRInst* inst, const EmitOpInfo& inO emitType(inst->getDataType()); emitArgs(inst); break; + case kIROp_MakeMatrixFromScalar: + { + emitType(inst->getDataType()); + auto matrixType = as<IRMatrixType>(inst->getDataType()); + SLANG_RELEASE_ASSERT(matrixType); + auto columnCount = as<IRIntLit>(matrixType->getColumnCount()); + SLANG_RELEASE_ASSERT(columnCount); + auto rowCount = as<IRIntLit>(matrixType->getRowCount()); + SLANG_RELEASE_ASSERT(rowCount); + m_writer->emit("("); + for (IRIntegerValue i = 0; i < rowCount->getValue() * columnCount->getValue(); i++) + { + if (i != 0) + m_writer->emit(", "); + emitOperand(inst->getOperand(0), getInfo(EmitOp::General)); + } + m_writer->emit(")"); + } + break; case kIROp_AllocObj: m_writer->emit("new "); m_writer->emit(getName(inst->getDataType())); diff --git a/tests/autodiff/matrix-arithmetic-fwd.slang b/tests/autodiff/matrix-arithmetic-fwd.slang new file mode 100644 index 000000000..7a953cef8 --- /dev/null +++ b/tests/autodiff/matrix-arithmetic-fwd.slang @@ -0,0 +1,41 @@ +//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type +//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type + +//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer<float> outputBuffer; + +[ForwardDifferentiable] +float3x3 g(float3x3 x, float3x3 y) +{ + float3x3 a = x + y; + float3x3 b = x - y; + return a * b + 2 * x * y; +} + +[ForwardDifferentiable] +float h(float2x2 x, float2x2 y) +{ + let t = mul(x, y); + return t[0][0] + t[0][1] + t[1][0] + t[1][1]; +} + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + float3x3 a = float3x3(2.0); + float3x3 b = float3x3(1.5); + float3x3 da = float3x3(1.0); + + outputBuffer[0] = __fwd_diff(g)( + DifferentialPair<float3x3>(a, da), + DifferentialPair<float3x3>(b, da)).d._11; // Expect: 8 + + float2x2 l = float2x2(1.0, 2.0, 3.0, 4.0); + float2x2 r = float2x2(10.0, 11.0, 12.0, 13.0); + float2x2 d = float2x2(1.0, 0.0, 1.0, 1.0); + + //float2x2 epsilon = d * 0.001f; + //outputBuffer[1] = (h(l + epsilon, r + epsilon) - h(l - epsilon, r - epsilon)) / (epsilon[0][0] * 2.0)); + + outputBuffer[1] = __fwd_diff(h)(DifferentialPair<float2x2>(l, d), DifferentialPair<float2x2>(r, d)).d; // Expect 83.0 +} diff --git a/tests/autodiff/matrix-arithmetic-fwd.slang.expected.txt b/tests/autodiff/matrix-arithmetic-fwd.slang.expected.txt new file mode 100644 index 000000000..c595048c3 --- /dev/null +++ b/tests/autodiff/matrix-arithmetic-fwd.slang.expected.txt @@ -0,0 +1,5 @@ +type: float +8.0 +83.0 +0.0 +0.0
\ No newline at end of file |
