Skip to content

Commit

Permalink
Fixed typing and formatting problems
Browse files Browse the repository at this point in the history
  • Loading branch information
whitead committed Sep 1, 2023
1 parent e5c1721 commit 4310374
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 5 deletions.
4 changes: 2 additions & 2 deletions paperqa/chains.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from langchain.chains import LLMChain
from langchain.chat_models import ChatOpenAI
from langchain.memory.chat_memory import BaseChatMemory
from langchain.prompts import BasePromptTemplate, PromptTemplate, StringPromptTemplate
from langchain.prompts import PromptTemplate, StringPromptTemplate
from langchain.prompts.chat import ChatPromptTemplate, HumanMessagePromptTemplate
from langchain.schema import LLMResult, SystemMessage

Expand Down Expand Up @@ -45,7 +45,7 @@ async def agenerate(


class ExtendedHumanMessagePromptTemplate(HumanMessagePromptTemplate):
prompt: BasePromptTemplate
prompt: StringPromptTemplate


def make_chain(
Expand Down
5 changes: 3 additions & 2 deletions paperqa/docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from .types import Answer, CallbackFactory, Context, Doc, DocKey, PromptCollection, Text
from .utils import (
gather_with_concurrency,
get_llm_name,
guess_is_4xx,
maybe_is_html,
maybe_is_pdf,
Expand Down Expand Up @@ -316,7 +317,7 @@ async def adoc_match(
try:
if (
rerank is None
and cast(BaseLanguageModel, self.llm).model_name.startswith("gpt-4")
and get_llm_name(cast(BaseLanguageModel, self.llm)).startswith("gpt-4")
or rerank is True
):
chain = make_chain(
Expand Down Expand Up @@ -538,7 +539,7 @@ async def process(match):
context_str = "\n\n".join(
[
f"{c.text.name}: {c.context}"
+ (f" Based on {c.text.doc.citation}" if detailed_citations else "")
+ (f"\n\n Based on {c.text.doc.citation}" if detailed_citations else "")
for c in answer.contexts
]
)
Expand Down
8 changes: 8 additions & 0 deletions paperqa/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import BinaryIO, List

import pypdf
from langchain.base_language import BaseLanguageModel

from .types import StrPath

Expand Down Expand Up @@ -89,3 +90,10 @@ def guess_is_4xx(msg: str) -> bool:
if re.search(r"4\d\d", msg):
return True
return False


def get_llm_name(llm: BaseLanguageModel) -> str:
try:
return llm.model_name # type: ignore
except AttributeError:
return llm.model # type: ignore
2 changes: 1 addition & 1 deletion paperqa/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "3.9.0"
__version__ = "3.9.1"

0 comments on commit 4310374

Please sign in to comment.