diff --git a/ask/anthropic.py b/ask/anthropic.py new file mode 100644 index 0000000..68079c0 --- /dev/null +++ b/ask/anthropic.py @@ -0,0 +1,55 @@ +import os +from anthropic import Anthropic +from typing import Optional + +from anthropic.types import TextBlock +from .bot_service import BotService +from .renderer import AbstractRenderer + +ANTHROPIC_API_KEY_NAME = "ANTHROPIC_API_KEY" + + +class AnthropicService(BotService): + def __init__(self, prompt, renderer: AbstractRenderer, line_target=0) -> None: + self.renderer = renderer + if ANTHROPIC_API_KEY_NAME not in os.environ: + self.renderer.print( + f""" + + Please get a Anthropic API key from https://console.anthropic.com/ + and set in the environment variable {ANTHROPIC_API_KEY_NAME} + + """ + ) + self._available = False + return + self._available = True + + self.client = Anthropic( + api_key=os.environ.get(ANTHROPIC_API_KEY_NAME), + ) + self._available = True + + @property + def available(self) -> bool: + return self._available + + def process(self, user_input: str) -> Optional[str]: + try: + message = self.client.messages.create( + max_tokens=4906, + messages=[ + { + "role": "user", + "content": user_input, + } + ], + model="claude-3-5-sonnet-20240620", + ) + content = message.content[0] + if type(content) is TextBlock: + return content.text + except Exception as e: + print(f"\nCannot process prompt \n{user_input}\n", e) + + return None diff --git a/ask/ask.py b/ask/ask.py index f46054c..4f726b4 100644 --- a/ask/ask.py +++ b/ask/ask.py @@ -13,6 +13,7 @@ from .transcribe import register_transcribed_text, stop_transcribe from .config import load_config, default_parse_args from .gemini import Gemini +from .anthropic import AnthropicService from .ollama import Ollama from .renderer import RichRenderer, AbstractRenderer from .bot_service import BotService @@ -62,6 +63,8 @@ def run( match config.service.provider.lower(): case "ollama": Service = Ollama + case "anthropic": + Service = AnthropicService case _: Service = Gemini service = Service(renderer=renderer, prompt=prompt, line_target=config.line_target) diff --git a/ask/bot_service.py b/ask/bot_service.py index 1a7a7a6..f0cb12d 100644 --- a/ask/bot_service.py +++ b/ask/bot_service.py @@ -12,7 +12,7 @@ def __init__( pass @abstractmethod - def process(self, user_input: Optional[str]) -> Optional[str]: + def process(self, user_input: str) -> Optional[str]: pass @property diff --git a/ask/ollama.py b/ask/ollama.py index 3f6cf6e..863e8c0 100644 --- a/ask/ollama.py +++ b/ask/ollama.py @@ -8,7 +8,7 @@ class Ollama(BotService): def available(self) -> bool: return True - def process(self, user_input: Optional[str]) -> Optional[str]: + def process(self, user_input: str) -> Optional[str]: response = requests.post( "http://localhost:11434/api/generate", json={"model": "qwen2.5:1.5b", "prompt": user_input, "stream": False}, diff --git a/ask/tests/conftest.py b/ask/tests/conftest.py index 66a9db3..234d176 100644 --- a/ask/tests/conftest.py +++ b/ask/tests/conftest.py @@ -1,9 +1,13 @@ import os from ..gemini import API_KEY_NAME +from ..anthropic import ANTHROPIC_API_KEY_NAME # Safety check to ensure that API_KEY is not passed into unit tests since this # could have unintended side effect from environment leaking in if API_KEY_NAME in os.environ: os.environ.pop(API_KEY_NAME) +if ANTHROPIC_API_KEY_NAME in os.environ: + os.environ.pop(ANTHROPIC_API_KEY_NAME) assert API_KEY_NAME not in os.environ +assert ANTHROPIC_API_KEY_NAME not in os.environ diff --git a/requirements.txt b/requirements.txt index e8b8fe0..32a4225 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,4 @@ +anthropic configparser google-generativeai prompt_toolkit