diff options
Diffstat (limited to 'Scripts/transcribe.py')
| -rw-r--r-- | Scripts/transcribe.py | 28 |
1 files changed, 18 insertions, 10 deletions
diff --git a/Scripts/transcribe.py b/Scripts/transcribe.py index 8491e4d..4d36e53 100644 --- a/Scripts/transcribe.py +++ b/Scripts/transcribe.py @@ -202,8 +202,7 @@ def resetAudio(audio_state): audio_state.transcribe_lock.release() # Transcribe the audio recorded in a file. -def transcribe(audio_state, model, frames): - +def transcribe(audio_state, model, frames, use_cpu: bool): start_time = time.time() frames = audio_state.frames @@ -223,8 +222,10 @@ def transcribe(audio_state, model, frames): #for temp in (0.00, 0.05, 0.10, 0.15, 0.20): #for temp in (0.00, 0.05): for temp in (0.00,): + use_gpu = not use_cpu options = whisper.DecodingOptions(language = audio_state.language, - beam_size = 5, temperature = temp, without_timestamps = True) + beam_size = 5, temperature = temp, without_timestamps = True, + fp16 = use_gpu) result = whisper.decode(model, mel, options) if result.avg_logprob < -1.0: @@ -247,7 +248,7 @@ def transcribe(audio_state, model, frames): return result -def transcribeAudio(audio_state, model): +def transcribeAudio(audio_state, model, use_cpu: bool): last_transcribe_time = time.time() while audio_state.run_app == True: # Pace this out @@ -266,7 +267,7 @@ def transcribeAudio(audio_state, model): audio_state.transcribe_sleep_duration_max_s, longer_sleep_dur) - text = transcribe(audio_state, model, audio_state.frames) + text = transcribe(audio_state, model, audio_state.frames, use_cpu) if not text: print("no transcription, spin ({} seconds)".format(time.time() - last_transcribe_time)) last_transcribe_time = time.time() @@ -373,7 +374,7 @@ def readControllerInput(audio_state, enable_local_beep): # model should correspond to one of the Whisper models defined in # whisper/__init__.py. Examples: tiny, base, small, medium. -def transcribeLoop(mic: str, language: str, model: str, enable_local_beep: bool): +def transcribeLoop(mic: str, language: str, model: str, enable_local_beep: bool, use_cpu: bool): audio_state = getMicStream(mic) audio_state.language = whisper.tokenizer.TO_LANGUAGE_CODE[language] @@ -386,7 +387,7 @@ def transcribeLoop(mic: str, language: str, model: str, enable_local_beep: bool) print("Model {} will be saved to {}".format(model, model_root)) model = whisper.load_model(model, download_root=model_root) - transcribe_audio_thd = threading.Thread(target = transcribeAudio, args = [audio_state, model]) + transcribe_audio_thd = threading.Thread(target = transcribeAudio, args = [audio_state, model, use_cpu]) transcribe_audio_thd.daemon = True transcribe_audio_thd.start() @@ -432,10 +433,11 @@ if __name__ == "__main__": parser.add_argument("--model", type=str, help="Which AI model to use. Ex: tiny, base, small, medium") parser.add_argument("--bytes_per_char", type=str, help="The number of bytes to use to represent each character") parser.add_argument("--chars_per_sync", type=str, help="The number of characters to send on each sync event") - parser.add_argument("--enable_local_beep", type=int, help="Whether to play a local auditory indicator when transcription starts/stops."); + parser.add_argument("--enable_local_beep", type=int, help="Whether to play a local auditory indicator when transcription starts/stops.") parser.add_argument("--rows", type=int, help="The number of rows on the board") parser.add_argument("--cols", type=int, help="The number of columns on the board") - parser.add_argument("--window_duration_s", type=int, help="The length in seconds of the audio recording handed to the transcription algorithm"); + parser.add_argument("--window_duration_s", type=int, help="The length in seconds of the audio recording handed to the transcription algorithm") + parser.add_argument("--cpu", type=int, help="If set to 1, use CPU instead of GPU") args = parser.parse_args() if not args.mic: @@ -458,10 +460,16 @@ if __name__ == "__main__": if args.window_duration_s: config.MAX_LENGTH_S = int(args.window_duration_s) + if args.cpu == 1: + args.cpu = True + else: + args.cpu = False + generate_utils.config.BYTES_PER_CHAR = int(args.bytes_per_char) generate_utils.config.CHARS_PER_SYNC = int(args.chars_per_sync) generate_utils.config.BOARD_ROWS = int(args.rows) generate_utils.config.BOARD_COLS = int(args.cols) - transcribeLoop(args.mic, args.language, args.model, args.enable_local_beep) + transcribeLoop(args.mic, args.language, args.model, args.enable_local_beep, + args.cpu) |
