Merge pull request #249 from assafelovic/feature/context_compressor
Feature/context compressor
assafelovic authored Nov 17, 2023
2 parents f3654df + c1edcb5 commit d3c62cf
Showing 12 changed files with 116 additions and 20 deletions.
3 changes: 2 additions & 1 deletion gpt_researcher/config/config.py
@@ -16,9 +16,10 @@ def __init__(self, config_file: str = None):
         self.smart_token_limit = 4000
         self.browse_chunk_max_length = 8192
         self.summary_token_limit = 700
-        self.temperature = 0.6
+        self.temperature = 0.55
         self.user_agent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko)" \
                           " Chrome/119.0.0.0 Safari/537.36 Edg/119.0.0.0"
+        self.max_search_results_per_query = 5
         self.memory_backend = "local"
         self.total_words = 1000
         self.report_format = "apa"
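For orientation, a minimal sketch of checking the new defaults (attribute names as committed above; the snippet itself is illustrative):

from gpt_researcher.config import Config

cfg = Config()
print(cfg.max_search_results_per_query)  # 5 — the new per-sub-query search cap
print(cfg.temperature)                   # 0.55, lowered from 0.6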
4 changes: 4 additions & 0 deletions gpt_researcher/context/__init__.py
@@ -0,0 +1,4 @@
from .compression import ContextCompressor
from .retriever import SearchAPIRetriever

__all__ = ['ContextCompressor', 'SearchAPIRetriever']
42 changes: 42 additions & 0 deletions gpt_researcher/context/compression.py
@@ -0,0 +1,42 @@
from .retriever import SearchAPIRetriever
from langchain.retrievers import (
    ContextualCompressionRetriever,
)
from langchain.retrievers.document_compressors import (
    DocumentCompressorPipeline,
    EmbeddingsFilter,
)
from langchain.text_splitter import RecursiveCharacterTextSplitter


class ContextCompressor:
    def __init__(self, documents, embeddings, max_results=5, **kwargs):
        self.max_results = max_results
        self.documents = documents
        self.kwargs = kwargs
        self.embeddings = embeddings

    def _get_contextual_retriever(self):
        splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
        relevance_filter = EmbeddingsFilter(embeddings=self.embeddings, similarity_threshold=0.78)
        pipeline_compressor = DocumentCompressorPipeline(
            transformers=[splitter, relevance_filter]
        )
        base_retriever = SearchAPIRetriever(
            pages=self.documents
        )
        contextual_retriever = ContextualCompressionRetriever(
            base_compressor=pipeline_compressor, base_retriever=base_retriever
        )
        return contextual_retriever

    def _pretty_print_docs(self, docs, top_n):
        return f"\n".join(f"Source: {d.metadata.get('source')}\n"
                          f"Title: {d.metadata.get('title')}\n"
                          f"Content: {d.page_content}\n"
                          for i, d in enumerate(docs) if i < top_n)

    def get_context(self, query, max_results=5):
        compressed_docs = self._get_contextual_retriever()
        relevant_docs = compressed_docs.get_relevant_documents(query)
        return self._pretty_print_docs(relevant_docs, max_results)
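Taken together, ContextCompressor chunks the scraped pages (1000 characters with 100 overlap), drops chunks whose embedding similarity to the query falls below 0.78, and pretty-prints the survivors. A minimal usage sketch, assuming OPENAI_API_KEY is set — the sample page dict mirrors the shape SearchAPIRetriever expects:

from gpt_researcher.context.compression import ContextCompressor
from langchain.embeddings import OpenAIEmbeddings

pages = [
    {"url": "https://example.com", "title": "Example", "raw_content": "Some scraped text about the topic..."},
]
compressor = ContextCompressor(documents=pages, embeddings=OpenAIEmbeddings())
print(compressor.get_context("what does the page say?", max_results=5))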
29 changes: 29 additions & 0 deletions gpt_researcher/context/retriever.py
@@ -0,0 +1,29 @@
import os
from enum import Enum
from typing import Any, Dict, List, Optional

from langchain.callbacks.manager import CallbackManagerForRetrieverRun
from langchain.schema import Document
from langchain.schema.retriever import BaseRetriever


class SearchAPIRetriever(BaseRetriever):
    """Search API retriever."""
    pages: List[Dict] = []

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:

        docs = [
            Document(
                page_content=page.get("raw_content", ""),
                metadata={
                    "title": page.get("title", ""),
                    "source": page.get("url", ""),
                },
            )
            for page in self.pages
        ]

        return docs
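SearchAPIRetriever is a pass-through adapter: it ignores the query and wraps already-scraped pages as LangChain Documents so the compression pipeline above can split and filter them. A quick sketch (the page dict is illustrative):

from gpt_researcher.context import SearchAPIRetriever

retriever = SearchAPIRetriever(pages=[
    {"url": "https://example.com", "title": "Example", "raw_content": "Some scraped text..."},
])
docs = retriever.get_relevant_documents("ignored")  # every page comes back as a Document
print(docs[0].metadata["source"])  # https://example.com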
26 changes: 17 additions & 9 deletions gpt_researcher/master/agent.py
@@ -1,6 +1,8 @@
 import time
 from gpt_researcher.config import Config
 from gpt_researcher.master.functions import *
+from gpt_researcher.context.compression import ContextCompressor
+from gpt_researcher.memory import Memory


 class GPTResearcher:
@@ -24,6 +26,7 @@ def __init__(self, query, report_type, config_path=None, websocket=None):
         self.cfg = Config(config_path)
         self.retriever = get_retriever(self.cfg.retriever)
         self.context = []
+        self.memory = Memory()
         self.visited_urls = set()

     async def run(self):
@@ -40,14 +43,16 @@ async def run(self):
         # Generate Sub-Queries including original query
         sub_queries = await get_sub_queries(self.query, self.role, self.cfg) + [self.query]
         await stream_output("logs",
-                            f"🧠 I will conduct my research based on the following queries: {sub_queries}...", self.websocket)
+                            f"🧠 I will conduct my research based on the following queries: {sub_queries}...",
+                            self.websocket)

         # Run Sub-Queries
         for sub_query in sub_queries:
             await stream_output("logs", f"\n🔎 Running research for '{sub_query}'...", self.websocket)
-            context = await self.run_sub_query(sub_query)
+            scraped_sites = await self.scrape_sites_by_query(sub_query)
+            context = await self.get_similar_content_by_query(sub_query, scraped_sites)
+            await stream_output("logs", f"📃 {context}", self.websocket)
             self.context.append(context)

         # Conduct Research
         await stream_output("logs", f"✍️ Writing {self.report_type} for research task: {self.query}...", self.websocket)
         report = await generate_report(query=self.query, context=self.context,
@@ -72,7 +77,7 @@ async def get_new_urls(self, url_set_input):

         return new_urls

-    async def run_sub_query(self, sub_query):
+    async def scrape_sites_by_query(self, sub_query):
         """
         Runs a sub-query
         Args:
@@ -83,16 +88,19 @@ async def run_sub_query(self, sub_query):
"""
# Get Urls
retriever = self.retriever(sub_query)
search_results = retriever.search()
search_results = retriever.search(max_results=self.cfg.max_search_results_per_query)
new_search_urls = await self.get_new_urls([url.get("href") for url in search_results])

# Scrape Urls
# await stream_output("logs", f"📝Scraping urls {new_search_urls}...\n", self.websocket)
content = scrape_urls(new_search_urls, self.cfg)
await stream_output("logs", f"🤔Researching for relevant information...\n", self.websocket)
# Summarize Raw Data
summary = await summarize(query=sub_query, content=content, agent_role_prompt=self.role, cfg=self.cfg, websocket=self.websocket)
scraped_content_results = scrape_urls(new_search_urls, self.cfg)
return scraped_content_results

async def get_similar_content_by_query(self, query, pages):
await stream_output("logs", f"🌐 Summarizing url: {query}", self.websocket)
# Summarize Raw Data
context_compressor = ContextCompressor(documents=pages, embeddings=self.memory.get_embeddings())
# Run Tasks
return summary
return context_compressor.get_context(query, max_results=8)
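Net effect: the old run_sub_query (search, scrape, LLM-summarize) is split into scrape_sites_by_query (search and scrape only) and get_similar_content_by_query (embedding-based compression), so per-sub-query relevance filtering no longer costs an LLM call. For reference, a minimal end-to-end invocation sketch (the report type string is an assumption; run() streams progress to the optional websocket):

import asyncio
from gpt_researcher.master.agent import GPTResearcher

async def main():
    researcher = GPTResearcher(query="How does context compression speed up research agents?",
                               report_type="research_report")
    report = await researcher.run()
    print(report)

asyncio.run(main())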

6 changes: 3 additions & 3 deletions gpt_researcher/master/functions.py
@@ -16,9 +16,6 @@ def get_retriever(retriever):
"""
match retriever:
case "duckduckgo":
from gpt_researcher.retrievers import Duckduckgo
retriever = Duckduckgo
case "tavily":
from gpt_researcher.retrievers import TavilySearch
retriever = TavilySearch
@@ -31,6 +28,9 @@ def get_retriever(retriever):
case "serp":
from gpt_researcher.retrievers import SerpSearch
retriever = SerpSearch
case "duckduckgo":
from gpt_researcher.retrievers import Duckduckgo
retriever = Duckduckgo

case _:
raise Exception("Retriever not found.")
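Note the duckduckgo arm only moved down the match statement; string cases are order-independent, so behavior is unchanged. Callers are unaffected — get_retriever returns a class that the agent instantiates per sub-query, roughly (query text illustrative):

retriever_class = get_retriever("tavily")
retriever = retriever_class("context compression for research agents")
results = retriever.search(max_results=5)  # [{"href": ..., "body": ...}, ...]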
1 change: 1 addition & 0 deletions gpt_researcher/memory/__init__.py
@@ -0,0 +1 @@
from .embeddings import Memory
11 changes: 11 additions & 0 deletions gpt_researcher/memory/embeddings.py
@@ -0,0 +1,11 @@
from langchain.vectorstores import FAISS
from langchain.embeddings import OpenAIEmbeddings


class Memory:
    def __init__(self, **kwargs):
        self._embeddings = OpenAIEmbeddings()

    def get_embeddings(self):
        return self._embeddings
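Memory is a thin holder for the embeddings object for now; the FAISS import is not yet used here, presumably groundwork for the memory_backend setting in the config. A usage sketch (OpenAIEmbeddings requires OPENAI_API_KEY):

from gpt_researcher.memory import Memory

embeddings = Memory().get_embeddings()
vector = embeddings.embed_query("context compression")  # a list of floats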

4 changes: 2 additions & 2 deletions gpt_researcher/retrievers/google/google.py
@@ -7,7 +7,7 @@
 from tavily import TavilyClient


-class GoogleSearch():
+class GoogleSearch:
     """
     Tavily API Retriever
     """
@@ -50,7 +50,7 @@ def get_cx_key(self):
"You can get a key at https://developers.google.com/custom-search/v1/overview")
return api_key

def search(self):
def search(self, max_results=7):
"""
Searches the query
Returns:
4 changes: 2 additions & 2 deletions gpt_researcher/retrievers/searx/searx.py
@@ -34,14 +34,14 @@ def get_api_key(self):
"You can get your key from https://searx.space/")
return api_key

def search(self):
def search(self, max_results=7):
"""
Searches the query
Returns:
"""
searx = SearxSearchWrapper(searx_host=os.environ["SEARX_URL"])
results = searx.results(self.query, 5)
results = searx.results(self.query, max_results)
# Normalizing results to match the format of the other search APIs
search_response = [{"href": obj["link"], "body": obj["snippet"]} for obj in results]
return search_response
2 changes: 1 addition & 1 deletion gpt_researcher/retrievers/serper/serper.py
@@ -35,7 +35,7 @@ def get_api_key(self):
"You can get a key at https://serper.dev/")
return api_key

def search(self):
def search(self, max_results=7):
"""
Searches the query
Returns:
4 changes: 2 additions & 2 deletions gpt_researcher/retrievers/tavily_search/tavily_search.py
@@ -33,14 +33,14 @@ def get_api_key(self):
"You can get a key at https://app.tavily.com")
return api_key

def search(self):
def search(self, max_results=7):
"""
Searches the query
Returns:
"""
# Search the query
results = self.client.search(self.query, search_depth="basic", max_results=5)
results = self.client.search(self.query, search_depth="basic", max_results=max_results)
# Return the results
search_response = [{"href": obj["url"], "body": obj["content"]} for obj in results.get("results", [])]
return search_response
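With this, all four touched retrievers (Google, Searx, Serper, Tavily) share a search(max_results=7) signature, and the agent threads cfg.max_search_results_per_query (5) through, making the per-query result count config-driven. A sketch of the Tavily path (query text illustrative; TAVILY_API_KEY must be set):

from gpt_researcher.retrievers import TavilySearch

retriever = TavilySearch("embedding-based context compression")
results = retriever.search(max_results=5)           # overrides the default of 7
print(results[0]["href"], results[0]["body"][:80])  # normalized {"href", "body"} shape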
