From 4f967dbdb7972ec52039bd3e3ce3e1e4cbcf6545 Mon Sep 17 00:00:00 2001
From: yum
Date: Mon, 27 Feb 2023 18:03:02 -0800
Subject: Begin work on beam search decoding

* ContextImpl.h puts prompts, previous prompts, probabilities, and probability IDs into vectors of size 1 or N_BEAMS, depending on the decoding strategy.
* Extend sampleBest and friends to return top N tokens, instead of just the top 1 token.
---
 Designs/.BeamSearch.md.swp | Bin 0 -> 12288 bytes
 Designs/BeamSearch.md      |  33 +++++++++++++++++++++++++++++++++
 2 files changed, 33 insertions(+)
 create mode 100644 Designs/.BeamSearch.md.swp
 create mode 100644 Designs/BeamSearch.md

(limited to 'Designs')

diff --git a/Designs/.BeamSearch.md.swp b/Designs/.BeamSearch.md.swp
new file mode 100644
index 0000000..8102d31
Binary files /dev/null and b/Designs/.BeamSearch.md.swp differ
diff --git a/Designs/BeamSearch.md b/Designs/BeamSearch.md
new file mode 100644
index 0000000..4902b45
--- /dev/null
+++ b/Designs/BeamSearch.md
@@ -0,0 +1,33 @@
+This is the design for the beam search decoding algorithm.
+
+ContextImpl.cpp is where greedy decoding is implemented.
+
+sampleBest() does the following:
+
+1. Populate `probs_ids` with (probability, idx).
+2. Find the `top_k` most likely tokens.
+3. Filter out tokens matching `token_sot`, `token_solm`, `token_not`.
+   3.1. sot == start of text, solm = ???, not == not timestamps
+4. Return the result.
+
+We should modify this to return the most likely N tokens. We'll call this
+sampleBestN().
+
+Greedy search works like this:
+
+1. Call decode(prompt, tokens).
+   1.1. This populates `probs`. Need to wrap this in a vector of size `N_BEAMS`.
+   1.2. Need to wrap `prompt` in a vector of size `N_BEAMS`.
+2. Extract the most likely token and add it to the prompt.
+3. Repeat until EOT (end of transcript) or max tokens or end of audio stream.
+
+Beam search will work like this:
+
+1. Initialize `prompts` as a vector containing `prompt`.
+2. Call decode(prompt, tokens) for each prompt in `prompts`.
+3. Extract the most likely `N_BEAMS` tokens in each result.
+4. Compute joint probabilities for each token.
+5. Extract the most likely `N_BEAMS` prompts using joint probabilities.
+6. Update `prompts`.
+7. Repeat until end condition is reached for all prompts.
+8. Return most likely prompt.
--
cgit v1.2.3
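Reviewer sketch, not part of the patch: the eight beam search steps in Designs/BeamSearch.md could look roughly like the C++ below. Only the name `sampleBestN` comes from the design; its signature here is a guess. `Token`, `Scored`, `Beam`, `DecodeFn`, and `beamSearch` are hypothetical stand-ins, and the real decode() in ContextImpl takes different arguments. The sketch sums log-probabilities instead of multiplying probabilities, which is an assumption made to avoid underflow on long prompts, and it skips the `token_sot` / `token_solm` / `token_not` filtering.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <functional>
#include <utility>
#include <vector>

// All names below are hypothetical stand-ins, not the real ContextImpl API.
using Token = int;
using Scored = std::pair<double, Token>;  // (probability, token id)

struct Beam {
    std::vector<Token> prompt;  // tokens accumulated so far
    double logProb;             // joint log-probability of `prompt`
    bool done;                  // true once EOT has been appended
};

// Stand-in for decode(): returns one probability per vocabulary entry.
using DecodeFn = std::function<std::vector<double>(const std::vector<Token>&)>;

// Return the n most likely (probability, token) pairs, most likely first.
std::vector<Scored> sampleBestN(const std::vector<double>& probs, std::size_t n) {
    std::vector<Scored> ids;
    ids.reserve(probs.size());
    for (std::size_t i = 0; i < probs.size(); ++i)
        ids.emplace_back(probs[i], static_cast<Token>(i));
    n = std::min(n, ids.size());
    std::partial_sort(ids.begin(), ids.begin() + static_cast<std::ptrdiff_t>(n), ids.end(),
                      [](const Scored& a, const Scored& b) {
                          // Higher probability first; lower token id breaks ties.
                          return a.first != b.first ? a.first > b.first : a.second < b.second;
                      });
    ids.resize(n);
    return ids;
}

std::vector<Token> beamSearch(const DecodeFn& decode, std::vector<Token> prompt,
                              std::size_t nBeams, Token eot, std::size_t maxTokens) {
    assert(nBeams > 0);
    std::vector<Beam> beams;
    beams.push_back(Beam{std::move(prompt), 0.0, false});  // step 1
    for (std::size_t step = 0; step < maxTokens; ++step) {
        std::vector<Beam> candidates;
        bool expanded = false;
        for (const Beam& beam : beams) {
            if (beam.done) {  // finished prompts compete unchanged
                candidates.push_back(beam);
                continue;
            }
            expanded = true;
            // Steps 2-3: decode this prompt and take its N_BEAMS best tokens.
            const std::vector<Scored> best = sampleBestN(decode(beam.prompt), nBeams);
            for (const Scored& s : best) {
                if (s.first <= 0.0) continue;  // impossible token; log(0) is -inf
                Beam next = beam;
                next.prompt.push_back(s.second);
                next.logProb += std::log(s.first);  // step 4, in log space
                next.done = (s.second == eot);
                candidates.push_back(std::move(next));
            }
        }
        if (!expanded || candidates.empty()) break;  // step 7: every prompt ended
        // Step 5: keep only the N_BEAMS most likely prompts.
        const std::size_t keep = std::min(nBeams, candidates.size());
        std::partial_sort(candidates.begin(),
                          candidates.begin() + static_cast<std::ptrdiff_t>(keep),
                          candidates.end(),
                          [](const Beam& a, const Beam& b) { return a.logProb > b.logProb; });
        candidates.resize(keep);
        beams = std::move(candidates);  // step 6
    }
    // Step 8: return the most likely prompt.
    return std::max_element(beams.begin(), beams.end(),
                            [](const Beam& a, const Beam& b) { return a.logProb < b.logProb; })
        ->prompt;
}
```

With `nBeams == 1` this reduces to the greedy loop described in the design, which may make it easier to keep both strategies behind the "vector of size 1 or N_BEAMS" layout mentioned in the commit message. The loop here stops only on EOT or `maxTokens`; the end-of-audio condition from the greedy description is left out.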