diff options
| author | Konstantin <const@const.me> | 2023-01-23 13:48:29 +0100 |
|---|---|---|
| committer | Konstantin <const@const.me> | 2023-01-23 13:48:29 +0100 |
| commit | 01aba39f15a03ed96e034ffc3b6ee9ec12294b0d (patch) | |
| tree | c8f4b5610c15404f97e01f2436febf13994ed586 | |
| parent | 99b87744ba22df78e3c476c92945b75acd267b87 (diff) | |
Minor, profiler tags
| -rw-r--r-- | Whisper/ML/Context.ops.cpp | 1 |
| -rw-r--r-- | Whisper/Whisper/WhisperContext.cpp | 3 |
2 files changed, 4 insertions, 0 deletions
```diff
diff --git a/Whisper/ML/Context.ops.cpp b/Whisper/ML/Context.ops.cpp
index 7dfca9f..a94497e 100644
--- a/Whisper/ML/Context.ops.cpp
+++ b/Whisper/ML/Context.ops.cpp
@@ -192,6 +192,7 @@ Tensor MlContext::flashAttention( const Tensor& q, const Tensor& k, const Tensor
 profiler.setNextTag( "flashAttn.1" );
 Tensor tmp = mulMat( k, q );
 
+profiler.setNextTag( "flashAttention" );
 const float tempScale = (float)( 1.0 / sqrt( (double)(int)q.ne[ 0 ] ) );
 softMax( tmp, tempScale );
 
diff --git a/Whisper/Whisper/WhisperContext.cpp b/Whisper/Whisper/WhisperContext.cpp
index d558aa6..e694930 100644
--- a/Whisper/Whisper/WhisperContext.cpp
+++ b/Whisper/Whisper/WhisperContext.cpp
@@ -457,6 +457,7 @@ Tensor WhisperContext::decodeLayer( const Tensor& inpL, size_t il, const sLayerD
 if( 0 == il ) Tracing::tensor( "dec-KQ-0", KQ );
 diagMaskInf( KQ, ldp.n_past );
 if( 0 == il ) Tracing::tensor( "dec-KQ-1", KQ );
+profiler.setNextTag( "decLayer.1" );
 softMax( KQ );
 if( 0 == il ) Tracing::tensor( "dec-KQ-2", KQ );
 
@@ -506,6 +507,7 @@ Tensor WhisperContext::decodeLayer( const Tensor& inpL, size_t il, const sLayerD
 Tensor K = permute( Kcross, 0, 2, 1, 3 );
 profiler.setNextTag( "dec.layer.8" );
 Tensor KQ = mulMat( K, Q );
+profiler.setNextTag( "decLayer.2" );
 softMax( KQ );
 Tensor V_trans = permute( Vcross, 1, 2, 0, 3 );
 profiler.setNextTag( "dec.layer.9" );
@@ -628,6 +630,7 @@ void WhisperContext::decode( const int* tokens, const int n_tokens, const sDecod
 cur = mulMat( gpuModel.dec.tokenEmbedding, cur );
 
 // logits -> probs
+profiler.setNextTag( "dec.probs" );
 softMax( cur );
 decoderOutput.copyFromVram( cur );
 
```
