From d1fc4273e8ae684181206329da7294bbca6ee159 Mon Sep 17 00:00:00 2001 From: Ian Homer Date: Tue, 24 Sep 2024 12:33:39 +0100 Subject: [PATCH] =?UTF-8?q?=E2=99=BB=EF=B8=8F=20move=20input=20handler=20l?= =?UTF-8?q?ogic=20into=20separate=20class?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ask/ask.py | 20 ++++++++----- ask/{service.py => bot_service.py} | 8 ++--- ask/gemini.py | 31 ++------------------ ask/handler.py | 47 ++++++++++++++++++++++++++++++ ask/tests/test_ask.py | 2 +- ask/tests/test_ask_gemini.py | 37 ++++++++++++++++------- 6 files changed, 94 insertions(+), 51 deletions(-) rename ask/{service.py => bot_service.py} (56%) create mode 100644 ask/handler.py diff --git a/ask/ask.py b/ask/ask.py index 3b3ae7f..596c0e9 100644 --- a/ask/ask.py +++ b/ask/ask.py @@ -13,7 +13,8 @@ from .config import load_config from .gemini import Gemini from .renderer import RichRenderer, AbstractRenderer -from .service import BotService +from .bot_service import BotService +from .handler import InputHandler transcribe_thread: Optional[threading.Thread] = None @@ -24,8 +25,8 @@ def signal_handler(sig: int, frame: Optional[object]) -> None: quit() -def quit() -> None: - print("\nBye ...") +def quit(renderer: AbstractRenderer) -> None: + renderer.print("\nBye ...") if transcribe_thread: stop_transcribe() @@ -69,6 +70,7 @@ def main( args = parse_args() renderer = Renderer(pretty_markdown=not args.no_markdown) + input_handler = InputHandler(renderer=renderer) file_input = False @@ -81,9 +83,9 @@ def main( service = Service(renderer=renderer, prompt=prompt, line_target=args.line_target) - def process(user_input, response_text: Optional[str] = None) -> Optional[str]: + def process(user_input) -> Optional[str]: renderer.print_processing() - response_text = service.process(user_input, response_text) + response_text = service.process(user_input) renderer.print_response(response_text) return response_text @@ -101,10 +103,12 @@ def process(user_input, response_text: Optional[str] = None) -> Optional[str]: try: user_input = inputter() except InputInterrupt: - quit() + quit(renderer) break - - response_text = process(user_input, response_text) + if user_input and len(user_input) > 0: + input_handler_response = input_handler.handle(user_input, response_text) + if input_handler_response.process: + response_text = process(user_input) return renderer diff --git a/ask/service.py b/ask/bot_service.py similarity index 56% rename from ask/service.py rename to ask/bot_service.py index fffcd2f..1a7a7a6 100644 --- a/ask/service.py +++ b/ask/bot_service.py @@ -6,13 +6,13 @@ class BotService: @abstractmethod - def __init__(self, prompt: str, line_target: int, renderer: AbstractRenderer) -> None: + def __init__( + self, prompt: str, line_target: int, renderer: AbstractRenderer + ) -> None: pass @abstractmethod - def process( - self, user_input, previous_response_text: Optional[str] = None - ) -> Optional[str]: + def process(self, user_input: Optional[str]) -> Optional[str]: pass @property diff --git a/ask/gemini.py b/ask/gemini.py index ff0b90b..02c4a23 100644 --- a/ask/gemini.py +++ b/ask/gemini.py @@ -4,11 +4,9 @@ from google.generativeai.types import content_types from typing import Optional -from ask.renderer import AbstractRenderer -from ask.service import BotService +from .renderer import AbstractRenderer +from .bot_service import BotService -from .save import save -from .copy import copy_code API_KEY_NAME = "GEMINI_API_KEY" @@ -52,7 +50,7 @@ def __init__(self, prompt, renderer: AbstractRenderer, line_target=0) -> None: def available(self) -> bool: return self._available - def process_user_input(self, user_input: str) -> Optional[str]: + def process(self, user_input: Optional[str]) -> Optional[str]: try: response = self.chat.send_message(user_input) return response.text @@ -60,26 +58,3 @@ def process_user_input(self, user_input: str) -> Optional[str]: print(f"\nCannot process prompt \n{user_input}\n", e) return None - - def process( - self, user_input, previous_response_text: Optional[str] = None - ) -> Optional[str]: - user_input_lower = user_input.lower() - if user_input_lower == "save": - save(previous_response_text) - return previous_response_text - - if previous_response_text and ( - ("copy code" in user_input_lower and len(user_input) < 12) - or ("copy" in user_input_lower and len(user_input) < 7) - ): - copy_code(self.renderer, previous_response_text) - return previous_response_text - - if user_input_lower.endswith("ignore"): - return None - - if len(user_input) > 0: - return self.process_user_input(user_input) - - return None diff --git a/ask/handler.py b/ask/handler.py new file mode 100644 index 0000000..ca9e8a8 --- /dev/null +++ b/ask/handler.py @@ -0,0 +1,47 @@ +from .renderer import AbstractRenderer +from typing import Optional + +from .save import save +from .copy import copy_code + + +class InputHandlerResponse: + def __init__(self, ignore=False, process=True) -> None: + self._ignore = ignore + self._process = process + + @property + def ignore(self): + return self._ignore + + @property + def process(self): + return self._process + + def __str__(self) -> str: + return f"ignore:{self._ignore}, process:{self._process}" + + +class InputHandler: + def __init__(self, renderer: AbstractRenderer) -> None: + self.renderer = renderer + + def handle( + self, input: str, previous_response_text: Optional[str] + ) -> InputHandlerResponse: + input_lower = input.lower() + if input_lower == "save": + save(previous_response_text) + return InputHandlerResponse(process=False) + + if previous_response_text and ( + ("copy code" in input_lower and len(input) < 12) + or ("copy" in input_lower and len(input) < 7) + ): + copy_code(self.renderer, previous_response_text) + return InputHandlerResponse(process=False) + + if input_lower.endswith("ignore"): + return InputHandlerResponse(process=False, ignore=True) + + return InputHandlerResponse() diff --git a/ask/tests/test_ask.py b/ask/tests/test_ask.py index 30ec1c2..55f060f 100644 --- a/ask/tests/test_ask.py +++ b/ask/tests/test_ask.py @@ -2,7 +2,7 @@ from unittest.mock import patch from typing import Optional -from ask.service import BotService +from ask.bot_service import BotService from .e2e_utils import parse_args, create_inputter from ..ask import main diff --git a/ask/tests/test_ask_gemini.py b/ask/tests/test_ask_gemini.py index a6e6d5a..6b71687 100644 --- a/ask/tests/test_ask_gemini.py +++ b/ask/tests/test_ask_gemini.py @@ -48,17 +48,34 @@ def test_ask_gemini_copy_code(GenerativeModel, clipboard_copy): ``` """ - with patch("sys.stdout", new=StringIO()) as captured_output: - main( - inputter=create_inputter(inputs=["mock input 1", "copy code"]), - parse_args=parse_args, - ) - lines = [line for line in captured_output.getvalue().split("\n") if line] - assert lines[0] == " -) ... ..." - assert lines[1].split("\n")[0] == "mock-response" - assert lines[-1] == "Bye ..." - assert len(lines) == 12 + renderer = main( + inputter=create_inputter(inputs=["mock input 1", "copy code"]), + Renderer=CapturingRenderer, + parse_args=parse_args, + ) + lines = [line for line in renderer.body.split("\n") if line] + assert lines[0] == "..." + assert lines[1].split("\n")[0] == "mock-response" + assert lines[-1] == "Bye ..." + assert len(lines) == 7 copies = clipboard_copy.call_args_list assert len(copies) == 1 assert "const a = 1" == copies[0][0][0] + + +@patch("google.generativeai.GenerativeModel") +@patch.dict(os.environ, {"GEMINI_API_KEY": "mock-api-key"}) +def test_ask_gemini_empty_inputs(GenerativeModel): + mock = GenerativeModel() + mock.start_chat().send_message().text = "mock-response" + renderer = main( + inputter=create_inputter(inputs=["mock input 1", "", "", ""]), + Renderer=CapturingRenderer, + parse_args=parse_args, + ) + lines = [line for line in renderer.body.split("\n") if line] + assert lines[0] == "..." + assert lines[1] == "mock-response" + assert lines[-1] == "Bye ..." + assert len(lines) == 3