Skip to content

Commit

Permalink
Support KoboldAI's API, not just KoboldCPP. Improve mic handling.
Browse files Browse the repository at this point in the history
- Support the KoboldAI API, not just KoboldCPP.
- Autodetect the first suitable microphone, by default.
- Support auto-calibrating the microphone, and do that by default.
- Add a setting for SLOW_AI_RESPONSES, off by default, for quicker
models / machines where realtime conversations are possible.
- Improve the prompt to try to tell the AI how to joke around.
- Plus various other small changes

Signed-off-by: Lee Braiden <leebraid@gmail.com>
  • Loading branch information
lee-b committed May 22, 2023
1 parent 04dc255 commit 9c17aeb
Show file tree
Hide file tree
Showing 4 changed files with 125 additions and 34 deletions.
6 changes: 3 additions & 3 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[tool.poetry]
name = "kobold_assistant"
name = "kobold-assistant"
version = "0.1.1"
description = ""
authors = ["Lee Braiden <leebraid@gmail.com>"]
Expand Down
127 changes: 99 additions & 28 deletions src/kobold_assistant/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@
from tempfile import NamedTemporaryFile
from typing import List, Optional, Tuple

import pyaudio
import speech_recognition as stt
import torch
import torchaudio

import speech_recognition as stt
from TTS.api import TTS
from pydub import AudioSegment
from pydub.playback import play as play_audio_segment
Expand All @@ -34,35 +34,66 @@
settings = default_settings


print(f"Loaded settings from {settings.__file__}")


def text_to_phonemes(text: str) -> str:
    """
    Return `text` unchanged.

    This is a passthrough for now: instead of converting text to phonemes
    here, we (try to) get around the problem by prompting the AI to spell
    out any abbreviations itself. (That doesn't work with the current
    prompt, though.)
    """
    return text


def get_microphone_device_id(Microphone) -> Optional[int]:
    """
    Return the microphone device index to use, or None if there isn't one.

    Returns the user's chosen settings.MICROPHONE_DEVICE_INDEX, or if that
    is None, returns the first key of the mapping returned by
    Microphone.list_working_microphones(), as a best guess. Returns None
    when that mapping is empty, i.e. no working microphone was found.
    """
    if settings.MICROPHONE_DEVICE_INDEX is not None:
        return settings.MICROPHONE_DEVICE_INDEX

    working_mics = Microphone.list_working_microphones()
    if working_mics:
        # first key == first working device; no need to build a list of keys
        return next(iter(working_mics))

    return None


def prompt_ai(prompt: str) -> Optional[str]:
    """
    POST `prompt` to the KoboldAI API at settings.GENERATE_URL and return
    the generated text.

    The request body is JSON containing the prompt and
    settings.GENERATE_TEMPERATURE. The generated text is read from
    `results[0].text` in the JSON response.

    Returns None (after printing an error to stderr) if the request fails
    or the response doesn't have the expected format, so that the caller
    can decide whether to retry.
    """
    post_data = {
        'prompt': prompt,
        'temperature': settings.GENERATE_TEMPERATURE,
    }
    post_json_bytes = json.dumps(post_data).encode('utf8')

    print(f"settings.GENERATE_URL is {settings.GENERATE_URL!r}")
    req = urllib.request.Request(settings.GENERATE_URL, method="POST")
    req.add_header('Content-Type', 'application/json; charset=utf-8')

    try:
        # `with` makes sure the HTTP response is closed, even if read() fails
        with urllib.request.urlopen(req, data=post_json_bytes) as response_obj:
            response_charset = response_obj.info().get_param('charset') or 'utf-8'
            response_bytes = response_obj.read()
    except urllib.error.URLError as e:
        print(f"ERROR: the KoboldAI API returned {e!r}!", file=sys.stderr)
        return None

    try:
        json_response = json.loads(response_bytes.decode(response_charset))
        return json_response['results'][0]['text']
    except (KeyError, IndexError, TypeError, ValueError):
        # ValueError covers invalid JSON and undecodable bytes; TypeError
        # covers JSON that parses, but isn't shaped like a dict of results
        print("ERROR: KoboldAI API returned an unexpected response format!", file=sys.stderr)
        return None


# horrible global hack for now; will get fixed
temp_audio_files = {}


def say(tts_engine, text, cache=False, warmup_only=False):
# horrible global hack for now; will get fixed
global temp_audio_files

# TODO: Choose (or obtain from config) the best speaker
Expand All @@ -75,7 +106,6 @@ def say(tts_engine, text, cache=False, warmup_only=False):
if tts_engine.languages is not None and len(tts_engine.languages) > 0:
params['language'] = tts_engine.languages[0]


if text in temp_audio_files:
audio = temp_audio_files[text]
else:
Expand All @@ -97,8 +127,9 @@ def say(tts_engine, text, cache=False, warmup_only=False):
tts_done = True
except BaseException as e:
err_msg = f"WARNING: TTS model {settings.TTS_MODEL_NAME!r} threw error {e}. Retrying. If this keeps failing, override the TTS_MODEL_NAME setting, and/or file a bug if it's the default setting."
print(err_msg, file=sys.stderr)
print(err_msg, file=sys.stdout)
for out_fp in (sys.stderr, sys.stdout):
# TODO: proper logging :D
print(err_msg, file=out_fp)
time.sleep(1) # because the loop doesn't respond to Ctrl-C otherwise

audio = AudioSegment.from_wav(audio_file.name)
Expand All @@ -125,20 +156,30 @@ def strip_stop_words(response: str) -> Optional[str]:

def warm_up_stt_engine(stt_engine, source) -> bool:
    """
    Warm up / initialize the speech-to-text engine.

    Listens on `source` once and runs the captured audio through
    stt_engine.recognize_whisper() with settings.WHISPER_MODEL, discarding
    the result.

    Returns True on success. Returns False, without listening at all, if
    `source.stream` is None, which is treated as the microphone having
    failed to initialize.
    """
    # This check has to come before any attempt to listen on the source
    if source.stream is None:
        print("ERROR: SpeechRecognition/pyaudio microphone failed to initialize. This seems to be a bug in the pyaudio or SpeechRecognition libraries, but check your MICROPHONE_DEVICE_INDEX setting?", file=sys.stderr)

        class DummyStream:
            def close(self):
                pass

        # workaround to allow us to exit cleanly despite the ContextManager bug
        source.stream = DummyStream()

        return False

    audio = stt_engine.listen(source, 0)
    stt_engine.recognize_whisper(audio, model=settings.WHISPER_MODEL)

    return True

def warm_up_tts_engine(tts_engine):
# warm up / initialize the text-to-speech engine
common_responses_to_cache = (
common_responses_to_cache = [
settings.SILENT_PERIOD_PROMPT,
settings.THINKING,
settings.NON_COMMITTAL_RESPONSE,
)
]
if settings.SLOW_AI_RESPONSES:
common_responses_to_cache.append(settings.THINKING)

for common_response_to_cache in common_responses_to_cache:
say(tts_engine, common_response_to_cache, cache=True, warmup_only=True)
Expand All @@ -149,10 +190,21 @@ def get_assistant_response(tts_engine, context: str, chat_log: List[str], assist

stripped_response_text = None
while not stripped_response_text:
# Try to naturally let the user know that this will take a while
say(tts_engine, settings.THINKING, cache=True)
response_text = None
while True:
if settings.SLOW_AI_RESPONSES:
# Try to naturally let the user know that this will take a while
say(tts_engine, settings.THINKING, cache=True)

response_text = prompt_ai(conversation_so_far)

if response_text is None:
time.sleep(2)
print("Retrying request to KoboldAI API", file=sys.stderr)
continue
else:
break

response_text = prompt_ai(conversation_so_far)
stripped_response_text = strip_stop_words(response_text)

# TODO: Handle bad responses by looping with varying
Expand Down Expand Up @@ -232,20 +284,37 @@ def serve():

# set up microphone and speech recognition
stt_engine = stt.Recognizer()
mic = stt.Microphone(device_index=settings.MICROPHONE_DEVICE_INDEX)

# configure speech recognition
stt_engine.energy_threshold = settings.STT_ENERGY_THRESHOLD

mic_device_index = get_microphone_device_id(stt.Microphone)
mic = stt.Microphone(device_index=mic_device_index)

if mic is None:
print("ERROR: couldn't find a working microphone on this system! Connect/enable one, or set MICROPHONE_DEVICE_INDEX in the settings to force its selection.", file=sys.stderr)
return 1 # error exit code

chat_log = []

with mic as source:
warm_up_stt_engine(stt_engine, source)
if settings.AUTO_CALIBRATE_MIC is True:
print(f"Calibrating microphone; please wait {settings.AUTO_CALIBRATE_MIC_SECONDS} seconds (warning: this doesn't seem to work, and might result in the AI not hearing your speech!) ...")
stt_engine.adjust_for_ambient_noise(source, duration=settings.AUTO_CALIBRATE_MIC_SECONDS)
print(f"Calibration complete.")

print("Initializing models and caching some data. Please wait, it could take a few minutes.")

if not warm_up_stt_engine(stt_engine, source):
print("ERROR: couldn't initialise the speech-to-text engine! Check previous error messages.", file=sys.stderr)
return 1 # error exit code

warm_up_tts_engine(tts_engine)

initial_log_line = f"{settings.ASSISTANT_NAME}: {settings.FULL_ASSISTANT_GREETING}"
print(initial_log_line)
chat_log.append(initial_log_line)

print("Ready to go.")

say(tts_engine, settings.FULL_ASSISTANT_GREETING)

# main dialog loop
Expand All @@ -268,6 +337,8 @@ def serve():


def main():
print(f"Loaded settings from {settings.__file__}")

parser = argparse.ArgumentParser()
parser.add_argument('mode', choices=('serve', 'list-mics',))

Expand Down
24 changes: 22 additions & 2 deletions src/kobold_assistant/default_settings.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
MICROPHONE_DEVICE_INDEX = 0
MICROPHONE_DEVICE_INDEX = None

USER_NAME = "User"

GENERATE_URL = "http://localhost:5001/api/v1/generate"
SLOW_AI_RESPONSES = False

GENERATE_URL = "http://localhost:5000/api/v1/generate"
GENERATE_TEMPERATURE = 0.7

ASSISTANT_NAME = "Jenny"
Expand All @@ -23,6 +25,20 @@
{ASSISTANT_NAME}: It's 4. It's a multiplication; pronounced "two times two". Would you like to know more about multiplication?
{USER_NAME}: Why did the chicken cross the road?
{ASSISTANT_NAME}: I don't know, why did the chicken cross the road?
{USER_NAME}: To get to the other side!
{ASSISTANT_NAME}: [laugh] very funny, {USER_NAME}. Here's another: why did the chicken cross the road?
{USER_NAME}: I don't know, why?
{ASSISTANT_NAME}: No one knows. But the road will have its vengeance!! [laugh]
{USER_NAME}: ha ha ha
"""

AI_MODEL_STOP_WORDS = (
Expand Down Expand Up @@ -74,3 +90,7 @@
'Good boy.',
)

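# NOTE: microphone auto-calibration doesn't seem to work well yet, and might result in the AI not hearing your speech; set AUTO_CALIBRATE_MIC = False if that happens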
AUTO_CALIBRATE_MIC = True
AUTO_CALIBRATE_MIC_SECONDS = 5

0 comments on commit 9c17aeb

Please sign in to comment.