diff --git a/poetry.lock b/poetry.lock index ba41a48..43b135e 100644 --- a/poetry.lock +++ b/poetry.lock @@ -2901,13 +2901,13 @@ files = [ [[package]] name = "requests" -version = "2.30.0" +version = "2.31.0" description = "Python HTTP for Humans." optional = false python-versions = ">=3.7" files = [ - {file = "requests-2.30.0-py3-none-any.whl", hash = "sha256:10e94cc4f3121ee6da529d358cdaeaff2f1c409cd377dbc72b825852f2f7e294"}, - {file = "requests-2.30.0.tar.gz", hash = "sha256:239d7d4458afcb28a692cdd298d87542235f4ca8d36d03a15bfc128a6559a2f4"}, + {file = "requests-2.31.0-py3-none-any.whl", hash = "sha256:58cd2187c01e70e6e26505bca751777aa9f2ee0b7f4300988b709f44e013003f"}, + {file = "requests-2.31.0.tar.gz", hash = "sha256:942c5a758f98d790eaed1a29cb6eefc7ffb0d1cf7af05c3d2791656dbd6ad1e1"}, ] [package.dependencies] diff --git a/pyproject.toml b/pyproject.toml index a6bda40..0cfdcc5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,5 +1,5 @@ [tool.poetry] -name = "kobold_assistant" +name = "kobold-assistant" version = "0.1.1" description = "" authors = ["Lee Braiden "] diff --git a/src/kobold_assistant/__main__.py b/src/kobold_assistant/__main__.py index 63a47dd..bbe573c 100644 --- a/src/kobold_assistant/__main__.py +++ b/src/kobold_assistant/__main__.py @@ -13,10 +13,10 @@ from tempfile import NamedTemporaryFile from typing import List, Optional, Tuple +import pyaudio +import speech_recognition as stt import torch import torchaudio - -import speech_recognition as stt from TTS.api import TTS from pydub import AudioSegment from pydub.playback import play as play_audio_segment @@ -34,9 +34,6 @@ settings = default_settings -print(f"Loaded settings from {settings.__file__}") - - def text_to_phonemes(text: str) -> str: # passthrough, since we (try to) get around this by prompting # the AI to spell-out any abbreviations instead, for now. @@ -44,25 +41,59 @@ def text_to_phonemes(text: str) -> str: return text +def get_microphone_device_id(Microphone) -> int: + """ + Returns the users chosen settings.MICROPHONE_DEVICE_INDEX, + or if that is None, returns the microphone device index + for the first working microphone, as a best guess + """ + if settings.MICROPHONE_DEVICE_INDEX is not None: + return settings.MICROPHONE_DEVICE_INDEX + + working_mics = Microphone.list_working_microphones() + if working_mics: + return list(working_mics.keys())[0] + + return None + + def prompt_ai(prompt: str) -> str: - post_data = json.dumps({ + post_data = { 'prompt': prompt, 'temperature': settings.GENERATE_TEMPERATURE, - }).encode('utf8') + } + + post_json = json.dumps(post_data) + post_json_bytes = post_json.encode('utf8') + + print(f"settings.GENERATE_URL is {settings.GENERATE_URL!r}") + req = urllib.request.Request(settings.GENERATE_URL, method="POST") + req.add_header('Content-Type', 'application/json; charset=utf-8') + + try: + response_obj = urllib.request.urlopen(req, data=post_json_bytes) + response_charset = response_obj.info().get_param('charset') or 'utf-8' + json_response = json.loads(response_obj.read().decode(response_charset)) - with urllib.request.urlopen(settings.GENERATE_URL, post_data) as f: - response_json = f.read().decode('utf-8') - response = json.loads(response_json)['results'][0]['text'] + try: + return json_response['results'][0]['text'] + except (KeyError, IndexError) as e: + print("ERROR: KoboldAI API returned an unexpected response format!", file=sys.stderr) + return None - print(f"The AI returned {response!r}") + except urllib.error.URLError as e: + print(f"ERROR: the KoboldAI API returned {e!r}!", file=sys.stderr) + json_response = None - return response + return json_response +# horrible global hack for now; will get fixed temp_audio_files = {} def say(tts_engine, text, cache=False, warmup_only=False): + # horrible global hack for now; will get fixed global temp_audio_files # TODO: Choose (or obtain from config) the best speaker @@ -75,7 +106,6 @@ def say(tts_engine, text, cache=False, warmup_only=False): if tts_engine.languages is not None and len(tts_engine.languages) > 0: params['language'] = tts_engine.languages[0] - if text in temp_audio_files: audio = temp_audio_files[text] else: @@ -97,8 +127,9 @@ def say(tts_engine, text, cache=False, warmup_only=False): tts_done = True except BaseException as e: err_msg = f"WARNING: TTS model {settings.TTS_MODEL_NAME!r} threw error {e}. Retrying. If this keeps failing, override the TTS_MODEL_NAME setting, and/or file a bug if it's the default setting." - print(err_msg, file=sys.stderr) - print(err_msg, file=sys.stdout) + for out_fp in (sys.stderr, sys.stdout): + # TODO: proper logging :D + print(err_msg, file=out_fp) time.sleep(1) # because the loop doesn't respond to Ctrl-C otherwise audio = AudioSegment.from_wav(audio_file.name) @@ -125,20 +156,30 @@ def strip_stop_words(response: str) -> Optional[str]: def warm_up_stt_engine(stt_engine, source): # warm up / initialize the speech-to-text engine - audio = stt_engine.listen(source, 0) + if source.stream is None: + print("ERROR: SpeechRecognition/pyaudio microphone failed to initialize. This seems to be a bug in the pyaudio or SpeechRecognition libraries, but check your MICROPHONE_DEVICE_INDEX setting?", file=sys.stderr) - recognize = lambda audio: stt_engine.recognize_whisper(audio, model=settings.WHISPER_MODEL) + class DummyStream: + def close(self): + pass - recognize(audio) + source.stream = DummyStream() # workaround to allow us to exit cleanly despite the ContextManager bug + return False + + audio = stt_engine.listen(source, 0) + stt_engine.recognize_whisper(audio, model=settings.WHISPER_MODEL) + + return True def warm_up_tts_engine(tts_engine): # warm up / initialize the text-to-speech engine - common_responses_to_cache = ( + common_responses_to_cache = [ settings.SILENT_PERIOD_PROMPT, - settings.THINKING, settings.NON_COMMITTAL_RESPONSE, - ) + ] + if settings.SLOW_AI_RESPONSES: + common_responses_to_cache.append(settings.THINKING) for common_response_to_cache in common_responses_to_cache: say(tts_engine, common_response_to_cache, cache=True, warmup_only=True) @@ -149,10 +190,21 @@ def get_assistant_response(tts_engine, context: str, chat_log: List[str], assist stripped_response_text = None while not stripped_response_text: - # Try to naturally let the user know that this will take a while - say(tts_engine, settings.THINKING, cache=True) + response_text = None + while True: + if settings.SLOW_AI_RESPONSES: + # Try to naturally let the user know that this will take a while + say(tts_engine, settings.THINKING, cache=True) + + response_text = prompt_ai(conversation_so_far) + + if response_text is None: + time.sleep(2) + print("Retrying request to KoboldAI API", file=sys.stderr) + continue + else: + break - response_text = prompt_ai(conversation_so_far) stripped_response_text = strip_stop_words(response_text) # TODO: Handle bad responses by looping with varying @@ -232,20 +284,37 @@ def serve(): # set up microphone and speech recognition stt_engine = stt.Recognizer() - mic = stt.Microphone(device_index=settings.MICROPHONE_DEVICE_INDEX) - - # configure speech recognition stt_engine.energy_threshold = settings.STT_ENERGY_THRESHOLD + mic_device_index = get_microphone_device_id(stt.Microphone) + mic = stt.Microphone(device_index=mic_device_index) + + if mic is None: + print("ERROR: couldn't find a working microphone on this system! Connect/enable one, or set MICROPHONE_DEVICE_INDEX in the settings to force its selection.", file=sys.stderr) + return 1 # error exit code + chat_log = [] with mic as source: - warm_up_stt_engine(stt_engine, source) + if settings.AUTO_CALIBRATE_MIC is True: + print(f"Calibrating microphone; please wait {settings.AUTO_CALIBRATE_MIC_SECONDS} seconds (warning: this doesn't seem to work, and might result in the AI not hearing your speech!) ...") + stt_engine.adjust_for_ambient_noise(source, duration=settings.AUTO_CALIBRATE_MIC_SECONDS) + print(f"Calibration complete.") + + print("Initializing models and caching some data. Please wait, it could take a few minutes.") + + if not warm_up_stt_engine(stt_engine, source): + print("ERROR: couldn't initialise the speech-to-text engine! Check previous error messages.", file=sys.stderr) + return 1 # error exit code + warm_up_tts_engine(tts_engine) initial_log_line = f"{settings.ASSISTANT_NAME}: {settings.FULL_ASSISTANT_GREETING}" print(initial_log_line) chat_log.append(initial_log_line) + + print("Ready to go.") + say(tts_engine, settings.FULL_ASSISTANT_GREETING) # main dialog loop @@ -268,6 +337,8 @@ def serve(): def main(): + print(f"Loaded settings from {settings.__file__}") + parser = argparse.ArgumentParser() parser.add_argument('mode', choices=('serve', 'list-mics',)) diff --git a/src/kobold_assistant/default_settings.py b/src/kobold_assistant/default_settings.py index ddd8792..f942c82 100644 --- a/src/kobold_assistant/default_settings.py +++ b/src/kobold_assistant/default_settings.py @@ -1,8 +1,10 @@ -MICROPHONE_DEVICE_INDEX = 0 +MICROPHONE_DEVICE_INDEX = None USER_NAME = "User" -GENERATE_URL = "http://localhost:5001/api/v1/generate" +SLOW_AI_RESPONSES = False + +GENERATE_URL = "http://localhost:5000/api/v1/generate" GENERATE_TEMPERATURE = 0.7 ASSISTANT_NAME = "Jenny" @@ -23,6 +25,20 @@ {ASSISTANT_NAME}: It's 4. It's a multiplication; pronounced "two times two". Would you like to know more about multiplication? +{USER_NAME}: Why did the chicken cross the road? + +{ASSISTANT_NAME}: I don't know, why did the chicken cross the road? + +{USER_NAME}: To get to the other side! + +{ASSISTANT_NAME}: [laugh] very funny, {USER_NAME}. Here's another: why did the chicken cross the road? + +{USER_NAME}: I don't know, why? + +{ASSISTANT_NAME}: No one knows. But the road will have its vengeance!! [laugh] + +{USER_NAME}: ha ha ha + """ AI_MODEL_STOP_WORDS = ( @@ -74,3 +90,7 @@ 'Good boy.', ) +# this doesn't seem to work well +AUTO_CALIBRATE_MIC = True +AUTO_CALIBRATE_MIC_SECONDS = 5 +