From 41eb19e65a0974e23048bd7b3b1eb1e2f569b1d0 Mon Sep 17 00:00:00 2001 From: Yong He Date: Thu, 8 Dec 2022 14:56:20 -0800 Subject: Auto-diff for matrix operations. (#2559) Co-authored-by: Yong He --- source/slang/core.meta.slang | 6 +- source/slang/diff.meta.slang | 32 +++++++++++ source/slang/hlsl.meta.slang | 111 +++++++++++++++++++++++++++++++++++-- source/slang/slang-emit-c-like.cpp | 20 ++++++- 4 files changed, 163 insertions(+), 6 deletions(-) (limited to 'source') diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang index ed80b3730..edb98b3c2 100644 --- a/source/slang/core.meta.slang +++ b/source/slang/core.meta.slang @@ -93,7 +93,11 @@ interface __BuiltinArithmeticType : __BuiltinType /// A type that can be used for logical/bitwise operations [sealed] [builtin] -interface __BuiltinLogicalType : __BuiltinType {} +interface __BuiltinLogicalType : __BuiltinType +{ + /// Initialize from a 32-bit signed integer value. + __init(int value); +} /// A type that logically has a sign (positive/negative/zero) [sealed] diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang index a97ab9eaf..248112810 100644 --- a/source/slang/diff.meta.slang +++ b/source/slang/diff.meta.slang @@ -110,6 +110,38 @@ void updatePair(inout DifferentialPair p, T newPrimal, T p = DifferentialPair(newPrimal, newDiff); } +// vector-matrix +__generic +__target_intrinsic(hlsl) +__target_intrinsic(glsl, "($1 * $0)") +DifferentialPair> mul(DifferentialPair> left, DifferentialPair> right) +{ + let primal = mul(left.p, right.p); + let diff = mul(left.d, right.p) + mul(left.p, right.d); + return DifferentialPair>(primal, diff); +} + +// matrix-vector +__generic +[ForceInline] +[ForwardDerivativeOf(mul)] +DifferentialPair> mul(DifferentialPair> left, DifferentialPair> right) +{ + let primal = mul(left.p, right.p); + let diff = mul(left.d, right.p) + mul(left.p, right.d); + return DifferentialPair>(primal, diff); +} + + +// matrix-matrix +__generic +[ForwardDerivativeOf(mul)] +DifferentialPair> mul(DifferentialPair> right, DifferentialPair> left) +{ + let primal = mul(right.p, left.p); + let diff = mul(right.d, left.p) + mul(right.p, left.d); + return DifferentialPair>(primal, diff); +} #define VECTOR_MAP_D_UNARY(TYPE, COUNT, D_FUNC, VALUE) \ vector result; \ diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang index 0232129ee..ed150a4c0 100644 --- a/source/slang/hlsl.meta.slang +++ b/source/slang/hlsl.meta.slang @@ -2715,7 +2715,24 @@ T mul(vector x, vector y) } // vector-matrix -__generic +__generic +__target_intrinsic(hlsl) +__target_intrinsic(glsl, "($1 * $0)") +vector mul(vector left, matrix right) +{ + vector result; + for( int j = 0; j < M; ++j ) + { + T sum = T(0); + for( int i = 0; i < N; ++i ) + { + sum += left[i] * right[i][j]; + } + result[j] = sum; + } + return result; +} +__generic __target_intrinsic(hlsl) __target_intrinsic(glsl, "($1 * $0)") vector mul(vector left, matrix right) @@ -2732,9 +2749,26 @@ vector mul(vector left, matrix right) } return result; } +__generic +__target_intrinsic(hlsl) +__target_intrinsic(glsl, "($1 * $0)") +vector mul(vector left, matrix right) +{ + vector result; + for( int j = 0; j < M; ++j ) + { + T sum = T(0); + for( int i = 0; i < N; ++i ) + { + sum |= left[i] & right[i][j]; + } + result[j] = sum; + } + return result; +} // matrix-vector -__generic +__generic __target_intrinsic(hlsl) __target_intrinsic(glsl, "($1 * $0)") vector mul(matrix left, vector right) @@ -2751,10 +2785,43 @@ vector mul(matrix left, vector right) } return result; } - +__generic +__target_intrinsic(hlsl) +__target_intrinsic(glsl, "($1 * $0)") +vector mul(matrix left, vector right) +{ + vector result; + for( int i = 0; i < N; ++i ) + { + T sum = T(0); + for( int j = 0; j < M; ++j ) + { + sum += left[i][j] * right[j]; + } + result[i] = sum; + } + return result; +} +__generic +__target_intrinsic(hlsl) +__target_intrinsic(glsl, "($1 * $0)") +vector mul(matrix left, vector right) +{ + vector result; + for( int i = 0; i < N; ++i ) + { + T sum = T(0); + for( int j = 0; j < M; ++j ) + { + sum |= left[i][j] & right[j]; + } + result[i] = sum; + } + return result; +} // matrix-matrix -__generic +__generic __target_intrinsic(hlsl) __target_intrinsic(glsl, "($1 * $0)") matrix mul(matrix right, matrix left) @@ -2772,6 +2839,42 @@ matrix mul(matrix right, matrix left) } return result; } +__generic +__target_intrinsic(hlsl) +__target_intrinsic(glsl, "($1 * $0)") +matrix mul(matrix right, matrix left) +{ + matrix result; + for( int r = 0; r < R; ++r) + for( int c = 0; c < C; ++c) + { + T sum = T(0); + for( int i = 0; i < N; ++i ) + { + sum += left[r][i] * right[i][c]; + } + result[r][c] = sum; + } + return result; +} +__generic +__target_intrinsic(hlsl) +__target_intrinsic(glsl, "($1 * $0)") +matrix mul(matrix right, matrix left) +{ + matrix result; + for( int r = 0; r < R; ++r) + for( int c = 0; c < C; ++c) + { + T sum = T(0); + for( int i = 0; i < N; ++i ) + { + sum |= left[r][i] & right[i][c]; + } + result[r][c] = sum; + } + return result; +} // noise (deprecated) diff --git a/source/slang/slang-emit-c-like.cpp b/source/slang/slang-emit-c-like.cpp index cadab5690..dd40b7856 100644 --- a/source/slang/slang-emit-c-like.cpp +++ b/source/slang/slang-emit-c-like.cpp @@ -1778,7 +1778,6 @@ void CLikeSourceEmitter::defaultEmitInstExpr(IRInst* inst, const EmitOpInfo& inO case kIROp_MakeVector: case kIROp_MakeMatrix: - case kIROp_MakeMatrixFromScalar: case kIROp_MatrixReshape: case kIROp_VectorReshape: case kIROp_CastFloatToInt: @@ -1789,6 +1788,25 @@ void CLikeSourceEmitter::defaultEmitInstExpr(IRInst* inst, const EmitOpInfo& inO emitType(inst->getDataType()); emitArgs(inst); break; + case kIROp_MakeMatrixFromScalar: + { + emitType(inst->getDataType()); + auto matrixType = as(inst->getDataType()); + SLANG_RELEASE_ASSERT(matrixType); + auto columnCount = as(matrixType->getColumnCount()); + SLANG_RELEASE_ASSERT(columnCount); + auto rowCount = as(matrixType->getRowCount()); + SLANG_RELEASE_ASSERT(rowCount); + m_writer->emit("("); + for (IRIntegerValue i = 0; i < rowCount->getValue() * columnCount->getValue(); i++) + { + if (i != 0) + m_writer->emit(", "); + emitOperand(inst->getOperand(0), getInfo(EmitOp::General)); + } + m_writer->emit(")"); + } + break; case kIROp_AllocObj: m_writer->emit("new "); m_writer->emit(getName(inst->getDataType())); -- cgit v1.2.3