diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml
index ff680cbc..096ca9ea 100644
--- a/.github/workflows/build.yml
+++ b/.github/workflows/build.yml
@@ -28,6 +28,7 @@ jobs:
         pip install .
       env:
         OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
+        ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
     - name: Build a binary wheel and a source tarball
       run: |
         python -m build --sdist --wheel --outdir dist/ .
diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index 9034d40c..0dacf855 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -33,5 +33,6 @@ jobs:
     - name: Run Test
       env:
         OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
+        ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
       run: |
         pytest tests
diff --git a/dev-requirements.txt b/dev-requirements.txt
index 11052e63..dd2be414 100644
--- a/dev-requirements.txt
+++ b/dev-requirements.txt
@@ -11,3 +11,4 @@ langchain_openai
 langchain_community
 faiss-cpu
 sentence_transformers
+anthropic
diff --git a/paperqa/llms.py b/paperqa/llms.py
index 1d70c4c5..db14da8b 100644
--- a/paperqa/llms.py
+++ b/paperqa/llms.py
@@ -69,19 +69,19 @@ def is_openai_model(model_name) -> bool:
     )


-def process_llm_config(llm_config: dict) -> dict:
+def process_llm_config(llm_config: dict, max_token_name: str = "max_tokens") -> dict:
     """Remove model_type and try to set max_tokens"""
     result = {k: v for k, v in llm_config.items() if k != "model_type"}
-    if "max_tokens" not in result or result["max_tokens"] == -1:
+    if max_token_name not in result or result[max_token_name] == -1:
         model = llm_config["model"]
         # now we guess - we could use tiktoken to count,
         # but don't have the initiative right now
         if model.startswith("gpt-4") or (
             model.startswith("gpt-3.5") and "1106" in model
         ):
-            result["max_tokens"] = 3000
+            result[max_token_name] = 3000
         else:
-            result["max_tokens"] = 1500
+            result[max_token_name] = 1500
     return result
@@ -336,6 +336,64 @@ async def achat_iter(self, client: Any, messages: list[dict[str, str]]) -> Any:
             yield chunk.choices[0].delta.content


+try:
+    from anthropic import AsyncAnthropic
+    from anthropic.types import ContentBlockDeltaEvent
+
+    class AnthropicLLMModel(LLMModel):
+        config: dict = Field(
+            default=dict(model="claude-3-sonnet-20240229", temperature=0.1)
+        )
+        name: str = "claude-3-sonnet-20240229"
+
+        def _check_client(self, client: Any) -> AsyncAnthropic:
+            if client is None:
+                raise ValueError(
+                    "Your client is None - did you forget to set it after pickling?"
+                )
+            if not isinstance(client, AsyncAnthropic):
+                raise ValueError(
+                    f"Your client is not a required AsyncAnthropic client. It is a {type(client)}"
+                )
+            return client
+
+        @model_validator(mode="after")
+        @classmethod
+        def set_llm_type(cls, data: Any) -> Any:
+            m = cast(AnthropicLLMModel, data)
+            m.llm_type = "chat"
+            return m
+
+        @model_validator(mode="after")
+        @classmethod
+        def set_model_name(cls, data: Any) -> Any:
+            m = cast(AnthropicLLMModel, data)
+            m.name = m.config["model"]
+            return m
+
+        async def achat(self, client: Any, messages: list[dict[str, str]]) -> str:
+            aclient = self._check_client(client)
+            completion = await aclient.messages.create(
+                messages=messages, **process_llm_config(self.config, "max_tokens")
+            )
+            # completion.content is a list of content blocks; return the text
+            # of the first block
+            return completion.content[0].text or ""
+
+        async def achat_iter(self, client: Any, messages: list[dict[str, str]]) -> Any:
+            aclient = self._check_client(client)
+            completion = await aclient.messages.create(
+                messages=messages,
+                **process_llm_config(self.config, "max_tokens"),
+                stream=True,
+            )
+            async for event in completion:
+                if isinstance(event, ContentBlockDeltaEvent):
+                    yield event.delta.text
+
+except ImportError:
+    pass
+
+
 class LlamaEmbeddingModel(EmbeddingModel):
     embedding_model: str = Field(default="llama")
diff --git a/paperqa/version.py b/paperqa/version.py
index 10e958d2..216eb967 100644
--- a/paperqa/version.py
+++ b/paperqa/version.py
@@ -1 +1 @@
-__version__ = "4.0.0-pre.9"
+__version__ = "4.0.0-pre.10"
diff --git a/tests/test_paperqa.py b/tests/test_paperqa.py
index cf2c2a7d..47186fb9 100644
--- a/tests/test_paperqa.py
+++ b/tests/test_paperqa.py
@@ -17,6 +17,7 @@
     print_callback,
 )
 from paperqa.llms import (
+    AnthropicLLMModel,
     EmbeddingModel,
     LangchainEmbeddingModel,
     LangchainLLMModel,
@@ -454,6 +455,31 @@ async def ac(x):
         completion = await call(dict(animal="duck"), callbacks=[accum, ac])  # type: ignore[call-arg]

+    async def test_anthropic_chain(self):
+        from anthropic import AsyncAnthropic
+
+        client = AsyncAnthropic()
+        llm = AnthropicLLMModel()
+        call = llm.make_chain(
+            client,
+            "The {animal} says",
+            skip_system=True,
+        )
+
+        def accum(x):
+            outputs.append(x)
+
+        outputs: list[str] = []
+        completion = await call(dict(animal="duck"), callbacks=[accum])  # type: ignore[call-arg]
+        assert completion.seconds_to_first_token > 0
+        assert completion.prompt_count > 0
+        assert completion.completion_count > 0
+        assert str(completion) == "".join(outputs)
+
+        completion = await call(dict(animal="duck"))  # type: ignore[call-arg]
+        assert completion.seconds_to_first_token == 0
+        assert completion.seconds_to_last_token > 0
+

 def test_docs():
     docs = Docs(llm="babbage-002")
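
For reviewers: a minimal usage sketch of the new `AnthropicLLMModel`, mirroring `test_anthropic_chain` above. It assumes the `anthropic` package is installed and `ANTHROPIC_API_KEY` is set in the environment; the snippet is illustrative only and not part of this diff.

```python
import asyncio

from anthropic import AsyncAnthropic

from paperqa.llms import AnthropicLLMModel


async def main() -> None:
    # AsyncAnthropic() reads ANTHROPIC_API_KEY from the environment.
    client = AsyncAnthropic()

    # Defaults to claude-3-sonnet-20240229 at temperature 0.1 (see config above).
    llm = AnthropicLLMModel()

    # skip_system=True because the Anthropic Messages API does not accept
    # "system"-role entries in the messages list.
    call = llm.make_chain(client, "The {animal} says", skip_system=True)

    # Callbacks receive streamed text chunks via achat_iter.
    completion = await call(dict(animal="duck"), callbacks=[print])
    print(str(completion))  # full generated text


asyncio.run(main())
```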