Added anthropic LLMs (#241)
* Added anthropic LLM back

* Added back anthropic

* Added anthropic to dev requirements for testing

* Checked anthropic API key
whitead committed Mar 5, 2024
1 parent 72aa100 commit ae633f4
Showing 6 changed files with 92 additions and 5 deletions.
1 change: 1 addition & 0 deletions .github/workflows/build.yml
@@ -28,6 +28,7 @@ jobs:
          pip install .
        env:
          OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
+         ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
      - name: Build a binary wheel and a source tarball
        run: |
          python -m build --sdist --wheel --outdir dist/ .
1 change: 1 addition & 0 deletions .github/workflows/tests.yml
@@ -33,5 +33,6 @@ jobs:
      - name: Run Test
        env:
          OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
+         ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
        run: |
          pytest tests
1 change: 1 addition & 0 deletions dev-requirements.txt
@@ -11,3 +11,4 @@ langchain_openai
langchain_community
faiss-cpu
sentence_transformers
+anthropic
66 changes: 62 additions & 4 deletions paperqa/llms.py
@@ -69,19 +69,19 @@ def is_openai_model(model_name) -> bool:
    )


-def process_llm_config(llm_config: dict) -> dict:
+def process_llm_config(llm_config: dict, max_token_name: str = "max_tokens") -> dict:
    """Remove model_type and try to set max_tokens"""
    result = {k: v for k, v in llm_config.items() if k != "model_type"}
-    if "max_tokens" not in result or result["max_tokens"] == -1:
+    if max_token_name not in result or result[max_token_name] == -1:
        model = llm_config["model"]
        # now we guess - we could use tiktoken to count,
        # but don't have the initiative right now
        if model.startswith("gpt-4") or (
            model.startswith("gpt-3.5") and "1106" in model
        ):
-            result["max_tokens"] = 3000
+            result[max_token_name] = 3000
        else:
-            result["max_tokens"] = 1500
+            result[max_token_name] = 1500
    return result
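
The new max_token_name argument generalizes the helper so a caller can name the token-limit field whatever its API expects; OpenAI and Anthropic both happen to spell it max_tokens, but the hook is now explicit. A quick illustration of the logic above, assuming process_llm_config is imported from paperqa.llms (the sample config is hypothetical):

```python
from paperqa.llms import process_llm_config

# model_type is stripped; the model name doesn't start with "gpt-4",
# so the guessed limit falls through to 1500 under the requested key.
config = {"model": "claude-3-sonnet-20240229", "temperature": 0.1, "model_type": "chat"}
print(process_llm_config(config, "max_tokens"))
# {'model': 'claude-3-sonnet-20240229', 'temperature': 0.1, 'max_tokens': 1500}
```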


@@ -336,6 +336,64 @@ async def achat_iter(self, client: Any, messages: list[dict[str, str]]) -> Any:
            yield chunk.choices[0].delta.content


+try:
+    from anthropic import AsyncAnthropic
+    from anthropic.types import ContentBlockDeltaEvent
+
+    class AnthropicLLMModel(LLMModel):
+        config: dict = Field(
+            default=dict(model="claude-3-sonnet-20240229", temperature=0.1)
+        )
+        name: str = "claude-3-sonnet-20240229"
+
+        def _check_client(self, client: Any) -> AsyncAnthropic:
+            if client is None:
+                raise ValueError(
+                    "Your client is None - did you forget to set it after pickling?"
+                )
+            if not isinstance(client, AsyncAnthropic):
+                raise ValueError(
+                    f"Your client is not a required AsyncAnthropic client. It is a {type(client)}"
+                )
+            return client
+
+        @model_validator(mode="after")
+        @classmethod
+        def set_llm_type(cls, data: Any) -> Any:
+            m = cast(AnthropicLLMModel, data)
+            m.llm_type = "chat"
+            return m
+
+        @model_validator(mode="after")
+        @classmethod
+        def set_model_name(cls, data: Any) -> Any:
+            m = cast(AnthropicLLMModel, data)
+            m.name = m.config["model"]
+            return m
+
+        async def achat(self, client: Any, messages: list[dict[str, str]]) -> str:
+            aclient = self._check_client(client)
+            completion = await aclient.messages.create(
+                messages=messages, **process_llm_config(self.config, "max_tokens")
+            )
+            return completion.content or ""
+
+        async def achat_iter(self, client: Any, messages: list[dict[str, str]]) -> Any:
+            aclient = self._check_client(client)
+            completion = await aclient.messages.create(
+                messages=messages,
+                **process_llm_config(self.config, "max_tokens"),
+                stream=True,
+            )
+            async for event in completion:
+                if isinstance(event, ContentBlockDeltaEvent):
+                    yield event.delta.text
+                # yield event.message.content
+
+except ImportError:
+    pass


class LlamaEmbeddingModel(EmbeddingModel):
    embedding_model: str = Field(default="llama")

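The try/except ImportError guard means AnthropicLLMModel only exists when the anthropic package is installed, which is why it was added to dev-requirements.txt above. A minimal streaming sketch based on the class as committed; the prompt and the asyncio wiring here are illustrative, not part of the commit:

```python
import asyncio

from anthropic import AsyncAnthropic
from paperqa.llms import AnthropicLLMModel


async def main() -> None:
    client = AsyncAnthropic()  # picks up ANTHROPIC_API_KEY from the environment
    llm = AnthropicLLMModel()  # config defaults to claude-3-sonnet-20240229
    # achat_iter yields the text deltas of ContentBlockDeltaEvent objects
    async for chunk in llm.achat_iter(
        client, [{"role": "user", "content": "The duck says"}]
    ):
        print(chunk, end="", flush=True)


asyncio.run(main())
```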
2 changes: 1 addition & 1 deletion paperqa/version.py
@@ -1 +1 @@
-__version__ = "4.0.0-pre.9"
+__version__ = "4.0.0-pre.10"
26 changes: 26 additions & 0 deletions tests/test_paperqa.py
@@ -17,6 +17,7 @@
    print_callback,
)
from paperqa.llms import (
+    AnthropicLLMModel,
    EmbeddingModel,
    LangchainEmbeddingModel,
    LangchainLLMModel,
@@ -454,6 +455,31 @@ async def ac(x):

        completion = await call(dict(animal="duck"), callbacks=[accum, ac])  # type: ignore[call-arg]

+    async def test_anthropic_chain(self):
+        from anthropic import AsyncAnthropic
+
+        client = AsyncAnthropic()
+        llm = AnthropicLLMModel()
+        call = llm.make_chain(
+            client,
+            "The {animal} says",
+            skip_system=True,
+        )
+
+        def accum(x):
+            outputs.append(x)
+
+        outputs: list[str] = []
+        completion = await call(dict(animal="duck"), callbacks=[accum])  # type: ignore[call-arg]
+        assert completion.seconds_to_first_token > 0
+        assert completion.prompt_count > 0
+        assert completion.completion_count > 0
+        assert str(completion) == "".join(outputs)
+
+        completion = await call(dict(animal="duck"))  # type: ignore[call-arg]
+        assert completion.seconds_to_first_token == 0
+        assert completion.seconds_to_last_token > 0

def test_docs():
    docs = Docs(llm="babbage-002")
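Note that test_anthropic_chain constructs AsyncAnthropic() with no arguments, so it authenticates through the ANTHROPIC_API_KEY environment variable; that is why both workflow files above now export the secret for CI runs.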
