diff options
| author | yum <yum.food.vr@gmail.com> | 2023-02-27 18:03:02 -0800 |
|---|---|---|
| committer | yum <yum.food.vr@gmail.com> | 2023-02-27 18:03:02 -0800 |
| commit | 4f967dbdb7972ec52039bd3e3ce3e1e4cbcf6545 (patch) | |
| tree | 158cf93f6ec626f32b961acc6d791bd2a2040b94 /Designs/BeamSearch.md | |
| parent | 1136acfc365f357d2df13a263714e8ae0614c4f9 (diff) | |
Begin work on beam search decoding
* ContextImpl.h puts prompts, previous prompts, probabilities, and
probability IDs into vectors of size 1 or N_BEAMS, depending on
the decoding strategy.
* Extend sampleBest and friends to return top N tokens, instead of
just the top 1 token.
Diffstat (limited to 'Designs/BeamSearch.md')
| -rw-r--r-- | Designs/BeamSearch.md | 33 |
1 files changed, 33 insertions, 0 deletions
diff --git a/Designs/BeamSearch.md b/Designs/BeamSearch.md new file mode 100644 index 0000000..4902b45 --- /dev/null +++ b/Designs/BeamSearch.md @@ -0,0 +1,33 @@ +This is the design for the beam search decoding algorithm. + +ContextImpl.cpp is where greedy decoding is implemented. + +sampleBest() does the following: + +1. Populate `probs_ids` with (probability, idx). +2. Find the `top_k` most likely tokens. +3. Filter out tokens matching `token_sot`, `token_solm`, `token_not`. + 3.1. sot == start of text, solm = ???, not == not timestamps +4. Return the result. + +We should modify this to return the most likely N tokens. We'll call this +sampleBestN(). + +Greedy search works like this: + +1. Call decode(prompt, tokens). + 1.1. This populates `probs`. Need to wrap this in a vector of size `N_BEAMS`. + 1.2. Need to wrap `prompt` in a vector of size `N_BEAMS`. +2. Extract the most likely token and add it to the prompt. +3. Repeat until EOT (end of transcript) or max tokens or end of audio stream. + +Beam search will work like this: + +1. Initialize `prompts` as a vector containing `prompt`. +2. Call decode(prompt, tokens) for each prompt in `prompts`. +3. Extract the most likely `N_BEAMS` tokens in each result. +4. Compute joint probabilities for each token. +5. Extract the most likely `N_BEAMS` prompts using joint probabilities. +6. Update `prompts`. +7. Repeat until end condition is reached for all prompts. +8. Return most likely prompt. |
