Skip to content

Commit

Permalink
Update Literature Search tool (#159)
Browse files Browse the repository at this point in the history
* update literature search

* fixed typo

* also accept absolute paths

* removed paperscraper

* fixed rgy bug with saving path in path registry

* updated uniprot unit test
  • Loading branch information
qcampbel authored Oct 17, 2024
1 parent f4cecf4 commit 2b1926d
Show file tree
Hide file tree
Showing 7 changed files with 56 additions and 84 deletions.
3 changes: 2 additions & 1 deletion mdagent/agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,14 +46,15 @@ def __init__(
uploaded_files=[], # user input files to add to path registry
run_id="",
use_memory=False,
paper_dir=None, # papers for pqa, relative path within repo
):
self.llm = _make_llm(model, temp, streaming)
if tools_model is None:
tools_model = model
self.tools_llm = _make_llm(tools_model, temp, streaming)

self.use_memory = use_memory
self.path_registry = PathRegistry.get_instance(ckpt_dir=ckpt_dir)
self.path_registry = PathRegistry.get_instance(ckpt_dir, paper_dir)
self.ckpt_dir = self.path_registry.ckpt_dir
self.memory = MemoryManager(self.path_registry, self.tools_llm, run_id=run_id)
self.run_id = self.memory.run_id
Expand Down
2 changes: 1 addition & 1 deletion mdagent/agent/prompt.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from langchain.prompts import PromptTemplate

structured_prompt = PromptTemplate(
input_variables=["input, context"],
input_variables=["input", "context"],
template="""
You are an expert molecular dynamics scientist, and
your task is to respond to the question or
Expand Down
3 changes: 2 additions & 1 deletion mdagent/tools/base_tools/analysis_tools/rgy.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ def plot_rgy(self) -> str:
if plot_name.endswith(".png"):
plot_name = plot_name.split(".png")[0]
plot_path = f"{self.path_registry.ckpt_figures}/{plot_name}"
print("plot_path", plot_path)
plt.plot(rg_per_frame)
plt.xlabel("Frame")
plt.ylabel("Radius of Gyration (nm)")
Expand All @@ -77,7 +78,7 @@ def plot_rgy(self) -> str:
plt.savefig(f"{plot_path}")
self.path_registry.map_path(
plot_id,
plot_path,
plot_path + ".png",
description=f"Plot of radii of gyration over time for {self.traj_file}",
)
plt.close()
Expand Down
105 changes: 33 additions & 72 deletions mdagent/tools/base_tools/util_tools/search_tools.py
Original file line number Diff line number Diff line change
@@ -1,90 +1,47 @@
import logging
import os
import re
from typing import Optional

import langchain
import nest_asyncio
import paperqa
import paperscraper
from langchain.base_language import BaseLanguageModel
from langchain.tools import BaseTool
from langchain_core.output_parsers import StrOutputParser
from pypdf.errors import PdfReadError

from mdagent.utils import PathRegistry


def configure_logging(path):
    """Set up file logging of errors under *path*.

    Calls ``logging.basicConfig`` so that records at ERROR level and above
    are written to ``scraping_errors.log`` inside *path*, each formatted as
    ``time:level:message``.
    """
    # paperscraper raises a lot of runtime errors and is VERY noisy, so they
    # are collected in a dedicated file.
    logging.basicConfig(
        level=logging.ERROR,
        format="%(asctime)s:%(levelname)s:%(message)s",
        filename=os.path.join(path, "scraping_errors.log"),
    )


def paper_scraper(search: str, pdir: str = "query") -> dict:
    """Run ``paperscraper.search_papers`` for *search*, saving under *pdir*.

    Returns whatever ``search_papers`` returns, or an empty dict when the
    call raises ``KeyError``.
    """
    try:
        found = paperscraper.search_papers(search, pdir=pdir)
    except KeyError:
        # Treated as "no papers found" rather than as a failure.
        found = {}
    return found


def paper_search(llm, query, path_registry):
    """Turn *query* into a short search string with *llm* and scrape papers.

    The LLM is prompted for a search query of at most 10 words. Results are
    scraped into ``<path_registry.ckpt_files>/query/<search without spaces>``
    and scraping errors are logged inside ``<ckpt_files>/query``.

    Returns the value produced by ``paper_scraper`` for that search string.
    """
    prompt = langchain.prompts.PromptTemplate(
        input_variables=["question"],
        template="""
I would like to find scholarly papers to answer
this question: {question}. Your response must be at
most 10 words long.
'A search query that would bring up papers that can answer
this question would be: '""",
    )

    path = f"{path_registry.ckpt_files}/query"
    query_chain = prompt | llm | StrOutputParser()
    # makedirs(exist_ok=True) replaces the isdir()/mkdir() pair: it creates
    # missing parent directories and cannot race with another process that
    # creates the same directory between the check and the call.
    os.makedirs(path, exist_ok=True)
    configure_logging(path)
    search = query_chain.invoke(query)
    print("\nSearch:", search)
    # Only literal spaces are removed to build the per-search directory name.
    search_dir = search.replace(" ", "")
    return paper_scraper(search, pdir=f"{path}/{search_dir}")


def scholar2result_llm(llm, query, path_registry, k=5, max_sources=2):
"""Useful to answer questions that require
technical knowledge. Ask a specific question."""
if llm.model_name.startswith("gpt"):
docs = paperqa.Docs(llm=llm.model_name)
def scholar2result_llm(llm, query, path_registry):
paper_directory = path_registry.ckpt_papers
if paper_directory is None:
raise ValueError(
"'paper_dir' is None. To use this tool, the user "
"must provide a directory with PDFs at the start."
)
print("Paper Directory", paper_directory)
llm_name = llm.model_name
if llm_name.startswith("gpt") or llm_name.startswith("claude"):
settings = paperqa.Settings(
llm=llm_name,
summary_llm=llm_name,
temperature=llm.temperature,
paper_directory=paper_directory,
)
else:
docs = paperqa.Docs() # uses default gpt model in paperqa

papers = paper_search(llm, query, path_registry)
if len(papers) == 0:
return "Failed. Not enough papers found"
not_loaded = 0
for path, data in papers.items():
try:
docs.add(path, data["citation"])
except (ValueError, FileNotFoundError, PdfReadError):
not_loaded += 1

print(
f"\nFound {len(papers)} papers"
+ (f" but couldn't load {not_loaded}" if not_loaded > 0 else "")
)
answer = docs.query(query, k=k, max_sources=max_sources).formatted_answer
return "Succeeded. " + answer
settings = paperqa.Settings(
temperature=llm.temperature, # uses default gpt model in paperqa
paper_directory=paper_directory,
)
response = paperqa.ask(query, settings=settings)
answer = response.answer.formatted_answer
if "I cannot answer." in answer:
answer += f" Check to ensure there's papers in {paper_directory}"
print(answer)
return answer


class Scholar2ResultLLM(BaseTool):
name = "LiteratureSearch"
description = (
"Useful to answer questions that require technical "
"knowledge. Ask a specific question."
"Useful to answer questions that may be found in literature. "
"Ask a specific question as the input."
)
llm: BaseLanguageModel = None
path_registry: Optional[PathRegistry]
Expand All @@ -96,7 +53,11 @@ def __init__(self, llm, path_registry):

def _run(self, query) -> str:
nest_asyncio.apply()
return scholar2result_llm(self.llm, query, self.path_registry)
try:
return scholar2result_llm(self.llm, query, self.path_registry)
except Exception as e:
print(e)
return f"Failed. {type(e).__name__}: {e}"

async def _arun(self, query) -> str:
"""Use the tool asynchronously."""
Expand Down
20 changes: 17 additions & 3 deletions mdagent/utils/path_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import os
from datetime import datetime
from enum import Enum
from typing import Optional

from mdagent.utils.set_ckpt import SetCheckpoint

Expand All @@ -22,20 +23,33 @@ class PathRegistry:

@classmethod
# set ckpt_dir to None by default
def get_instance(cls, ckpt_dir=None):
def get_instance(cls, ckpt_dir=None, paper_dir=None):
# todo: use same ckpt if run_id is given
if not cls.instance or ckpt_dir is not None:
cls.instance = cls(ckpt_dir)
cls.instance = cls(ckpt_dir, paper_dir)
return cls.instance

def __init__(self, ckpt_dir: str = "ckpt"):
def __init__(self, ckpt_dir: str = "ckpt", paper_dir=None):
self._set_ckpt(ckpt_dir)
self._set_paper_dir(paper_dir)
self._make_all_dirs()
self._init_path_registry()

def _set_ckpt(self, ckpt: str):
    """Store in ``self.ckpt_dir`` the result of ``set_ckpt_subdir`` for *ckpt*."""
    resolved_dir = self.set_ckpt.set_ckpt_subdir(ckpt_dir=ckpt)
    self.ckpt_dir = resolved_dir

def _set_paper_dir(self, paper_dir: Optional[str]):
    """Validate *paper_dir* and store it as ``self.ckpt_papers``.

    Args:
        paper_dir: Directory containing papers, given as an absolute path,
            a path relative to the current working directory, or a path
            starting with ``~``. ``None`` or an empty string means no
            directory was supplied.

    Raises:
        ValueError: If the resolved path does not exist or is not a
            directory.
    """
    # An empty string must not fall through: os.path.abspath("") returns the
    # current working directory, which would silently be used for papers.
    if not paper_dir:
        self.ckpt_papers = None
        return
    # expanduser makes "~/papers" work; abspath anchors relative paths.
    absolute_path = os.path.abspath(os.path.expanduser(paper_dir))
    # isdir() is already False for a missing path, so a separate exists()
    # check adds nothing.
    if not os.path.isdir(absolute_path):
        raise ValueError(
            f"Invalid paper directory: '{absolute_path}' either doesn't exist "
            "or isn't a directory."
        )
    self.ckpt_papers = absolute_path

def _make_all_dirs(self):
self.json_file_path = os.path.join(self.ckpt_dir, "paths_registry.json")
self.ckpt_files = os.path.join(self.ckpt_dir, "files")
Expand Down
3 changes: 1 addition & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,7 @@
"matplotlib",
"nbformat",
"openai",
"paper-qa==4.0.0rc8 ",
"paper-scraper @ git+https://github.com/blackadad/paper-scraper.git",
"paper-qa==5.0.6",
"pandas",
"pydantic>=2.6",
"python-dotenv",
Expand Down
4 changes: 0 additions & 4 deletions tests/test_preprocess/test_uniprot.py
Original file line number Diff line number Diff line change
Expand Up @@ -487,14 +487,10 @@ def test_get_ids(query_uniprot):
"P68871",
"P02089",
"P02070",
"O13163",
"P02008",
"B3EWR7",
"P04244",
"P02094",
"P83479",
"P01966",
"O93349",
"P68872",
"P69905",
"P02088",
Expand Down

0 comments on commit 2b1926d

Please sign in to comment.