Skip to content

Commit

Permalink
♻️ Use capturing rendererer in one e2e test
Browse files Browse the repository at this point in the history
To trial as alternative to sys.stdout, since that can have side effect
of capturing other output
  • Loading branch information
ianhomer committed Sep 24, 2024
1 parent 0f84457 commit 0266ada
Show file tree
Hide file tree
Showing 5 changed files with 51 additions and 27 deletions.
10 changes: 6 additions & 4 deletions ask/ask.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def main(
Service: type[BotService] = Gemini,
Renderer: type[AbstractRenderer] = RichRenderer,
parse_args=parse_args,
) -> None:
) -> AbstractRenderer:
global transcribe_thread
args = parse_args()

Expand All @@ -75,9 +75,9 @@ def main(
prompt, file_input = get_prompt(args.inputs, args.template)

if args.dry:
print("Prompt : ")
print(prompt)
return
renderer.print("Prompt : ")
renderer.print(prompt)
return renderer

service = Service(renderer=renderer, prompt=prompt, line_target=args.line_target)

Expand Down Expand Up @@ -106,6 +106,8 @@ def process(user_input, response_text: Optional[str] = None) -> Optional[str]:

response_text = process(user_input, response_text)

return renderer


if __name__ == "__main__":
main()
4 changes: 2 additions & 2 deletions ask/gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,9 @@

class Gemini(BotService):
def __init__(self, prompt, renderer: AbstractRenderer, line_target=0) -> None:
self.renderer = renderer
if API_KEY_NAME not in os.environ:
print(
self.renderer.print(
f"""
Please get a Gemini API key from https://aistudio.google.com/
Expand All @@ -28,7 +29,6 @@ def __init__(self, prompt, renderer: AbstractRenderer, line_target=0) -> None:
return
self._available = True
api_key = os.environ[API_KEY_NAME]
self.renderer = renderer

genai.configure(api_key=api_key)
model = genai.GenerativeModel("gemini-1.5-flash")
Expand Down
39 changes: 27 additions & 12 deletions ask/renderer.py
Original file line number Diff line number Diff line change
@@ -1,43 +1,58 @@
from abc import abstractmethod
from rich import print
from typing import Optional
from typing import Optional, List
from rich.markdown import Markdown


class AbstractRenderer:
def __init__(self, pretty_markdown=True) -> None:
self.pretty_markdown = pretty_markdown
self._messages: List[str] = []

def record(self, message):
self._messages.append(message)

@abstractmethod
def __init__(self, pretty_markdown: bool) -> None:
def print(self, message):
pass

@abstractmethod
def print_processing(self):
pass
self.print("...\n")

@abstractmethod
def print_response(self, response_text: Optional[str]):
pass
if response_text:
self.print(response_text)

@abstractmethod
def print_message(self, message: str):
pass
self.print(message)

@property
def messages(self) -> List[str]:
return self._messages

@property
def body(self) -> str:
return "\n".join(self._messages)


class RichRenderer(AbstractRenderer):
def __init__(self, pretty_markdown=True) -> None:
self.pretty_markdown = pretty_markdown
def print(self, message):
print(message)

def print_processing(self):
print(
self.print(
"[bold bright_yellow] -) ... ...[/bold bright_yellow]\n"
)

def print_message(self, message: str):
print(f"[bold bright_yellow] -) {message} [/bold bright_yellow]\n")
self.print(f"[bold bright_yellow] -) {message} [/bold bright_yellow]\n")

def print_response(self, response_text: Optional[str]):
if response_text:
if self.pretty_markdown:
markdown = Markdown(response_text)
print(markdown)
self.print(markdown)
else:
print(response_text)
self.print(response_text)
9 changes: 9 additions & 0 deletions ask/tests/e2e_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import Callable
from collections import deque
from ..input import InputInterrupt
from ..renderer import AbstractRenderer


def create_inputter(inputs=["mock input 1"]) -> Callable[[], str]:
Expand All @@ -25,3 +26,11 @@ def parse_args():
no_transcribe=True,
template=None,
)


class CapturingRenderer(AbstractRenderer):
def __init__(self, pretty_markdown: bool) -> None:
AbstractRenderer.__init__(self, pretty_markdown=pretty_markdown)

def print(self, message):
self.record(message)
16 changes: 7 additions & 9 deletions ask/tests/test_ask_gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,22 +4,20 @@


from ..ask import main
from .e2e_utils import parse_args, create_inputter
from .e2e_utils import parse_args, create_inputter, CapturingRenderer


@patch("google.generativeai.GenerativeModel")
def test_ask_gemini_key_required(GenerativeModel):
mock = GenerativeModel()
mock.start_chat().send_message().text = "mock-response"

with patch("sys.stdout", new=StringIO()) as captured_output:
main(inputter=create_inputter(), parse_args=parse_args)
assert (
"set in the environment variable GEMINI_API_KEY"
in captured_output.getvalue()
)
lines = [line for line in captured_output.getvalue().split("\n") if line]
assert len(lines) == 3
renderer = main(
inputter=create_inputter(), Renderer=CapturingRenderer, parse_args=parse_args
)
assert "set in the environment variable GEMINI_API_KEY" in renderer.messages[0]
lines = [line for line in renderer.body.split("\n") if line]
assert len(lines) == 3


@patch("google.generativeai.GenerativeModel")
Expand Down

0 comments on commit 0266ada

Please sign in to comment.