Adjusted anthropic import (#242)
* Adjusted anthropic import

* Fixed sys prompts
whitead committed Mar 5, 2024
1 parent ae633f4 commit 2c5da97
Showing 3 changed files with 34 additions and 7 deletions.
2 changes: 2 additions & 0 deletions paperqa/__init__.py
@@ -7,6 +7,7 @@
     OpenAIEmbeddingModel,
     LangchainLLMModel,
     OpenAILLMModel,
+    AnthropicLLMModel,
     LlamaEmbeddingModel,
     NumpyVectorStore,
     LangchainVectorStore,
@@ -26,6 +27,7 @@
     "EmbeddingModel",
     "OpenAIEmbeddingModel",
     "OpenAILLMModel",
+    "AnthropicLLMModel",
     "LangchainLLMModel",
     "LlamaEmbeddingModel",
     "SentenceTransformerEmbeddingModel",
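With this change the Anthropic model class is re-exported from the package root, so downstream code can import it alongside the existing models. A minimal sketch of the intended import (assumes only that the package at this commit is installed):

    from paperqa import AnthropicLLMModel, OpenAILLMModel

Nothing else about the public API changes here; the class itself lives in paperqa/llms.py.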
37 changes: 31 additions & 6 deletions paperqa/llms.py
@@ -373,18 +373,43 @@ def set_model_name(cls, data: Any) -> Any:
 
     async def achat(self, client: Any, messages: list[dict[str, str]]) -> str:
         aclient = self._check_client(client)
-        completion = await aclient.messages.create(
-            messages=messages, **process_llm_config(self.config, "max_tokens")
+        # filter out the system prompt
+        sys_message = next(
+            (m["content"] for m in messages if m["role"] == "system"), None
         )
+        # the Anthropic client does not accept None as a "no system prompt"
+        # sentinel, so only pass the system kwarg when one is present
+        if sys_message:
+            completion = await aclient.messages.create(
+                system=sys_message,
+                messages=[m for m in messages if m["role"] != "system"],
+                **process_llm_config(self.config, "max_tokens"),
+            )
+        else:
+            completion = await aclient.messages.create(
+                messages=[m for m in messages if m["role"] != "system"],
+                **process_llm_config(self.config, "max_tokens"),
+            )
         return completion.content or ""
 
     async def achat_iter(self, client: Any, messages: list[dict[str, str]]) -> Any:
         aclient = self._check_client(client)
-        completion = await aclient.messages.create(
-            messages=messages,
-            **process_llm_config(self.config, "max_tokens"),
-            stream=True,
+        sys_message = next(
+            (m["content"] for m in messages if m["role"] == "system"), None
         )
+        if sys_message:
+            completion = await aclient.messages.create(
+                stream=True,
+                system=sys_message,
+                messages=[m for m in messages if m["role"] != "system"],
+                **process_llm_config(self.config, "max_tokens"),
+            )
+        else:
+            completion = await aclient.messages.create(
+                stream=True,
+                messages=[m for m in messages if m["role"] != "system"],
+                **process_llm_config(self.config, "max_tokens"),
+            )
         async for event in completion:
             if isinstance(event, ContentBlockDeltaEvent):
                 yield event.delta.text
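For context, Anthropic's Messages API takes the system prompt as a dedicated system parameter rather than as a {"role": "system"} message, and the client rejects an explicit system=None, which is why both methods branch on whether a system message exists. A minimal standalone sketch of the same pattern against the official anthropic SDK (the model name and prompts here are illustrative assumptions, not taken from this commit):

    import asyncio

    from anthropic import AsyncAnthropic

    async def main() -> None:
        client = AsyncAnthropic()  # reads ANTHROPIC_API_KEY from the environment
        messages = [
            {"role": "system", "content": "Answer in one sentence."},
            {"role": "user", "content": "What does paper-qa do?"},
        ]
        # Split the system prompt out of the chat history, as the commit does.
        sys_message = next(
            (m["content"] for m in messages if m["role"] == "system"), None
        )
        chat = [m for m in messages if m["role"] != "system"]
        kwargs: dict = {
            "model": "claude-3-opus-20240229",  # assumed model name, for illustration
            "max_tokens": 256,
            "messages": chat,
        }
        if sys_message:  # only pass `system` when a system prompt exists
            kwargs["system"] = sys_message
        completion = await client.messages.create(**kwargs)
        print(completion.content)

    asyncio.run(main())

The streaming path in achat_iter follows the same split, additionally passing stream=True and yielding text from ContentBlockDeltaEvent deltas as they arrive.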
2 changes: 1 addition & 1 deletion paperqa/version.py
@@ -1 +1 @@
-__version__ = "4.0.0-pre.10"
+__version__ = "4.0.0-pre.11"
