Skip to content

Commit

Permalink
♻️ move input handler logic into separate class
Browse files Browse the repository at this point in the history
  • Loading branch information
ianhomer committed Sep 24, 2024
1 parent bbb2657 commit d1fc427
Show file tree
Hide file tree
Showing 6 changed files with 94 additions and 51 deletions.
20 changes: 12 additions & 8 deletions ask/ask.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@
from .config import load_config
from .gemini import Gemini
from .renderer import RichRenderer, AbstractRenderer
from .service import BotService
from .bot_service import BotService
from .handler import InputHandler

transcribe_thread: Optional[threading.Thread] = None

Expand All @@ -24,8 +25,8 @@ def signal_handler(sig: int, frame: Optional[object]) -> None:
quit()


def quit() -> None:
print("\nBye ...")
def quit(renderer: AbstractRenderer) -> None:
renderer.print("\nBye ...")
if transcribe_thread:
stop_transcribe()

Expand Down Expand Up @@ -69,6 +70,7 @@ def main(
args = parse_args()

renderer = Renderer(pretty_markdown=not args.no_markdown)
input_handler = InputHandler(renderer=renderer)

file_input = False

Expand All @@ -81,9 +83,9 @@ def main(

service = Service(renderer=renderer, prompt=prompt, line_target=args.line_target)

def process(user_input, response_text: Optional[str] = None) -> Optional[str]:
def process(user_input) -> Optional[str]:
renderer.print_processing()
response_text = service.process(user_input, response_text)
response_text = service.process(user_input)
renderer.print_response(response_text)
return response_text

Expand All @@ -101,10 +103,12 @@ def process(user_input, response_text: Optional[str] = None) -> Optional[str]:
try:
user_input = inputter()
except InputInterrupt:
quit()
quit(renderer)
break

response_text = process(user_input, response_text)
if user_input and len(user_input) > 0:
input_handler_response = input_handler.handle(user_input, response_text)
if input_handler_response.process:
response_text = process(user_input)

return renderer

Expand Down
8 changes: 4 additions & 4 deletions ask/service.py → ask/bot_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,13 @@

class BotService:
@abstractmethod
def __init__(self, prompt: str, line_target: int, renderer: AbstractRenderer) -> None:
def __init__(
self, prompt: str, line_target: int, renderer: AbstractRenderer
) -> None:
pass

@abstractmethod
def process(
self, user_input, previous_response_text: Optional[str] = None
) -> Optional[str]:
def process(self, user_input: Optional[str]) -> Optional[str]:
pass

@property
Expand Down
31 changes: 3 additions & 28 deletions ask/gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,9 @@
from google.generativeai.types import content_types
from typing import Optional

from ask.renderer import AbstractRenderer
from ask.service import BotService
from .renderer import AbstractRenderer
from .bot_service import BotService

from .save import save
from .copy import copy_code

API_KEY_NAME = "GEMINI_API_KEY"

Expand Down Expand Up @@ -52,34 +50,11 @@ def __init__(self, prompt, renderer: AbstractRenderer, line_target=0) -> None:
def available(self) -> bool:
return self._available

def process_user_input(self, user_input: str) -> Optional[str]:
def process(self, user_input: Optional[str]) -> Optional[str]:
try:
response = self.chat.send_message(user_input)
return response.text
except Exception as e:
print(f"\nCannot process prompt \n{user_input}\n", e)

return None

def process(
self, user_input, previous_response_text: Optional[str] = None
) -> Optional[str]:
user_input_lower = user_input.lower()
if user_input_lower == "save":
save(previous_response_text)
return previous_response_text

if previous_response_text and (
("copy code" in user_input_lower and len(user_input) < 12)
or ("copy" in user_input_lower and len(user_input) < 7)
):
copy_code(self.renderer, previous_response_text)
return previous_response_text

if user_input_lower.endswith("ignore"):
return None

if len(user_input) > 0:
return self.process_user_input(user_input)

return None
47 changes: 47 additions & 0 deletions ask/handler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from .renderer import AbstractRenderer
from typing import Optional

from .save import save
from .copy import copy_code


class InputHandlerResponse:
    """Outcome of handling one line of user input.

    Attributes:
        ignore: True when the handler decided the input (and the previous
            response) should be ignored.
        process: True when the caller should still forward the input to the
            bot service for processing.
    """

    def __init__(self, ignore: bool = False, process: bool = True) -> None:
        # Plain public attributes instead of trivial pass-through properties
        # (PEP 8 idiom: no Java-style getters for simple read access).
        # The read interface (.ignore / .process) is unchanged for callers.
        self.ignore = ignore
        self.process = process

    def __str__(self) -> str:
        return f"ignore:{self.ignore}, process:{self.process}"


class InputHandler:
    """Intercepts command-style user input ("save", "copy", trailing
    "ignore") and performs the requested side effect before the input
    would otherwise be sent on to the bot service."""

    def __init__(self, renderer: AbstractRenderer) -> None:
        self.renderer = renderer

    def handle(
        self, input: str, previous_response_text: Optional[str]
    ) -> InputHandlerResponse:
        """Inspect *input* and act on any recognised command.

        Returns an InputHandlerResponse whose ``process`` flag tells the
        caller whether the input still needs to be processed by the service.
        """
        lowered = input.lower()

        # "save" persists the previous response and consumes the input.
        if lowered == "save":
            save(previous_response_text)
            return InputHandlerResponse(process=False)

        # Short inputs mentioning "copy" / "copy code" copy code from the
        # previous response; the length caps keep longer sentences that
        # merely contain the word "copy" flowing through to the service.
        wants_copy = ("copy code" in lowered and len(input) < 12) or (
            "copy" in lowered and len(input) < 7
        )
        if previous_response_text and wants_copy:
            copy_code(self.renderer, previous_response_text)
            return InputHandlerResponse(process=False)

        # A trailing "ignore" discards the exchange without processing.
        if lowered.endswith("ignore"):
            return InputHandlerResponse(ignore=True, process=False)

        # Anything else is ordinary input for the service.
        return InputHandlerResponse()
2 changes: 1 addition & 1 deletion ask/tests/test_ask.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from unittest.mock import patch
from typing import Optional

from ask.service import BotService
from ask.bot_service import BotService

from .e2e_utils import parse_args, create_inputter
from ..ask import main
Expand Down
37 changes: 27 additions & 10 deletions ask/tests/test_ask_gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,17 +48,34 @@ def test_ask_gemini_copy_code(GenerativeModel, clipboard_copy):
```
"""

with patch("sys.stdout", new=StringIO()) as captured_output:
main(
inputter=create_inputter(inputs=["mock input 1", "copy code"]),
parse_args=parse_args,
)
lines = [line for line in captured_output.getvalue().split("\n") if line]
assert lines[0] == " -) ... ..."
assert lines[1].split("\n")[0] == "mock-response"
assert lines[-1] == "Bye ..."
assert len(lines) == 12
renderer = main(
inputter=create_inputter(inputs=["mock input 1", "copy code"]),
Renderer=CapturingRenderer,
parse_args=parse_args,
)
lines = [line for line in renderer.body.split("\n") if line]
assert lines[0] == "..."
assert lines[1].split("\n")[0] == "mock-response"
assert lines[-1] == "Bye ..."
assert len(lines) == 7

copies = clipboard_copy.call_args_list
assert len(copies) == 1
assert "const a = 1" == copies[0][0][0]


@patch("google.generativeai.GenerativeModel")
@patch.dict(os.environ, {"GEMINI_API_KEY": "mock-api-key"})
def test_ask_gemini_empty_inputs(GenerativeModel):
    """Empty input lines produce no extra responses: only the first real
    input reaches the mocked model before the session ends."""
    model = GenerativeModel()
    model.start_chat().send_message().text = "mock-response"

    renderer = main(
        inputter=create_inputter(inputs=["mock input 1", "", "", ""]),
        Renderer=CapturingRenderer,
        parse_args=parse_args,
    )

    # Drop blank lines from the captured output before asserting on it.
    lines = list(filter(None, renderer.body.split("\n")))
    assert len(lines) == 3
    assert lines[0] == "..."
    assert lines[1] == "mock-response"
    assert lines[-1] == "Bye ..."

0 comments on commit d1fc427

Please sign in to comment.