From 2c5da971449074009dc07fb81bf4df8136834a95 Mon Sep 17 00:00:00 2001
From: Andrew White
Date: Mon, 4 Mar 2024 17:46:14 -0800
Subject: [PATCH] Adjusted anthropic import (#242)

* Adjusted anthropic import

* Fixed sys prompts
---
 paperqa/__init__.py |  2 ++
 paperqa/llms.py     | 37 +++++++++++++++++++++++++++++++------
 paperqa/version.py  |  2 +-
 3 files changed, 34 insertions(+), 7 deletions(-)

diff --git a/paperqa/__init__.py b/paperqa/__init__.py
index f6a4a268..44a458ad 100644
--- a/paperqa/__init__.py
+++ b/paperqa/__init__.py
@@ -7,6 +7,7 @@
     OpenAIEmbeddingModel,
     LangchainLLMModel,
     OpenAILLMModel,
+    AnthropicLLMModel,
     LlamaEmbeddingModel,
     NumpyVectorStore,
     LangchainVectorStore,
@@ -26,6 +27,7 @@
     "EmbeddingModel",
     "OpenAIEmbeddingModel",
     "OpenAILLMModel",
+    "AnthropicLLMModel",
     "LangchainLLMModel",
     "LlamaEmbeddingModel",
     "SentenceTransformerEmbeddingModel",
diff --git a/paperqa/llms.py b/paperqa/llms.py
index db14da8b..ea31091d 100644
--- a/paperqa/llms.py
+++ b/paperqa/llms.py
@@ -373,18 +373,43 @@ def set_model_name(cls, data: Any) -> Any:
 
     async def achat(self, client: Any, messages: list[dict[str, str]]) -> str:
         aclient = self._check_client(client)
-        completion = await aclient.messages.create(
-            messages=messages, **process_llm_config(self.config, "max_tokens")
+        # filter out system
+        sys_message = next(
+            (m["content"] for m in messages if m["role"] == "system"), None
         )
+        # BECAUSE THEY DO NOT USE NONE TO INDICATE SENTINEL
+        # LIKE ANY SANE PERSON
+        if sys_message:
+            completion = await aclient.messages.create(
+                system=sys_message,
+                messages=[m for m in messages if m["role"] != "system"],
+                **process_llm_config(self.config, "max_tokens"),
+            )
+        else:
+            completion = await aclient.messages.create(
+                messages=[m for m in messages if m["role"] != "system"],
+                **process_llm_config(self.config, "max_tokens"),
+            )
         return completion.content or ""
 
     async def achat_iter(self, client: Any, messages: list[dict[str, str]]) -> Any:
         aclient = self._check_client(client)
-        completion = await aclient.messages.create(
-            messages=messages,
-            **process_llm_config(self.config, "max_tokens"),
-            stream=True,
+        sys_message = next(
+            (m["content"] for m in messages if m["role"] == "system"), None
         )
+        if sys_message:
+            completion = await aclient.messages.create(
+                stream=True,
+                system=sys_message,
+                messages=[m for m in messages if m["role"] != "system"],
+                **process_llm_config(self.config, "max_tokens"),
+            )
+        else:
+            completion = await aclient.messages.create(
+                stream=True,
+                messages=[m for m in messages if m["role"] != "system"],
+                **process_llm_config(self.config, "max_tokens"),
+            )
         async for event in completion:
             if isinstance(event, ContentBlockDeltaEvent):
                 yield event.delta.text
diff --git a/paperqa/version.py b/paperqa/version.py
index 216eb967..f5ee359b 100644
--- a/paperqa/version.py
+++ b/paperqa/version.py
@@ -1 +1 @@
-__version__ = "4.0.0-pre.10"
+__version__ = "4.0.0-pre.11"
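
Note on the llms.py hunks: Anthropic's Messages API takes the system prompt as
a separate system= keyword rather than as a "system" role inside messages, and
the SDK's default for that keyword is its own NOT_GIVEN sentinel rather than
None, which is what the all-caps comment is complaining about and why the
patch duplicates the messages.create() call. The sketch below isolates that
logic as a minimal standalone script against the anthropic SDK as of early
2024; the helper name split_system_prompt, the model id, and the example
messages are assumptions of this sketch, not part of paper-qa.

    # Minimal sketch of the system-prompt handling this patch introduces.
    # Assumed/illustrative: split_system_prompt (hypothetical helper),
    # the model id, and the example messages.
    import asyncio

    from anthropic import AsyncAnthropic


    def split_system_prompt(
        messages: list[dict[str, str]],
    ) -> tuple[str | None, list[dict[str, str]]]:
        # Pull the system prompt out of an OpenAI-style message list:
        # Anthropic takes it as a separate `system=` keyword and does not
        # accept a "system" role inside `messages`.
        system = next(
            (m["content"] for m in messages if m["role"] == "system"), None
        )
        chat = [m for m in messages if m["role"] != "system"]
        return system, chat


    async def main() -> None:
        client = AsyncAnthropic()  # reads ANTHROPIC_API_KEY from the environment
        messages = [
            {"role": "system", "content": "Answer in one short sentence."},
            {"role": "user", "content": "What does paper-qa do?"},
        ]
        system, chat = split_system_prompt(messages)
        # The SDK expresses "no system prompt" by omitting the argument,
        # not by passing None; hence the branch, mirroring the duplicated
        # create() calls in the patch.
        if system:
            completion = await client.messages.create(
                model="claude-3-opus-20240229",
                max_tokens=256,
                system=system,
                messages=chat,
            )
        else:
            completion = await client.messages.create(
                model="claude-3-opus-20240229",
                max_tokens=256,
                messages=chat,
            )
        # completion.content is a list of content blocks, not a plain string.
        print(completion.content[0].text)


    if __name__ == "__main__":
        asyncio.run(main())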
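
The streaming path (achat_iter) uses the same split: with stream=True,
messages.create() returns an async iterator of events, and the incremental
text arrives as ContentBlockDeltaEvent objects whose .delta.text the patch
yields. Note also that because completion.content is a list of content blocks
in this SDK, the sketch reads completion.content[0].text rather than treating
it as a string.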