Merge pull request #713 from DandinPower/feature/no_redundant_subtopic
Feature: Reducing redundancy problems in Detailed Reports
assafelovic authored Aug 4, 2024
2 parents 24a4e8f + a8026d4 commit 56e4d5b
Showing 8 changed files with 383 additions and 67 deletions.
16 changes: 15 additions & 1 deletion backend/report_type/detailed_report/detailed_report.py
@@ -5,6 +5,7 @@
from gpt_researcher.master.actions import (
add_source_urls,
extract_headers,
extract_sections,
table_of_contents,
)
from gpt_researcher.master.agent import GPTResearcher
@@ -49,6 +50,9 @@ def __init__(
# This is a global variable to store the entire context accumulated at any point through searching and scraping
self.global_context = []

# This is a global variable to store all written sections. It will be used to retrieve relevant written content before any subtopic report to prevent redundant content writing.
self.global_written_sections = []

# This is a global variable to store the entire url list accumulated at any point through searching and scraping
if self.source_urls:
self.global_urls = set(self.source_urls)
@@ -134,10 +138,20 @@ async def _get_subtopic_report(self, subtopic: dict) -> str:
# Conduct research on the subtopic
await subtopic_assistant.conduct_research()

# Use research results to generate draft section titles
draft_section_titles = await subtopic_assistant.get_draft_section_titles()
parse_draft_section_titles = extract_headers(draft_section_titles)
parse_draft_section_titles_text = [header.get("text", "") for header in parse_draft_section_titles]

# Use the draft section titles to get previous relevant written contents
relevant_contents = await subtopic_assistant.get_similar_written_contents_by_draft_section_titles(current_subtopic_task, parse_draft_section_titles_text, self.global_written_sections)

# Here the headers gathered from previous subtopic reports are passed to the write report function
# The LLM is later instructed to avoid generating any information relating to these headers as they have already been generated
subtopic_report = await subtopic_assistant.write_report(self.existing_headers)
subtopic_report = await subtopic_assistant.write_report(self.existing_headers, relevant_contents)

# Update the global written sections list
self.global_written_sections.extend(extract_sections(subtopic_report))
# Update context of the global context variable
self.global_context = list(set(subtopic_assistant.context))
# Update url list of the global list variable
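Taken together, the new flow in _get_subtopic_report looks roughly like the sketch below. The helper and method names come from the diff; the standalone wrapper function and its parameters are hypothetical and only illustrate how the pieces fit together.

from gpt_researcher.master.actions import extract_headers, extract_sections

# Hypothetical wrapper illustrating the anti-redundancy flow; not part of the codebase.
async def build_subtopic_report(subtopic_assistant, current_subtopic_task,
                                existing_headers, global_written_sections):
    # Research the subtopic as before
    await subtopic_assistant.conduct_research()

    # Draft section titles from the fresh research context
    draft_titles_md = await subtopic_assistant.get_draft_section_titles()
    draft_titles = [h.get("text", "") for h in extract_headers(draft_titles_md)]

    # Retrieve previously written content that already covers those titles
    relevant = await subtopic_assistant.get_similar_written_contents_by_draft_section_titles(
        current_subtopic_task, draft_titles, global_written_sections
    )

    # Write the report, passing both the existing headers and the relevant prior content
    report = await subtopic_assistant.write_report(existing_headers, relevant)

    # Store what was just written so later subtopics can avoid repeating it
    global_written_sections.extend(extract_sections(report))
    return report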
1 change: 1 addition & 0 deletions frontend/styles.css
@@ -130,6 +130,7 @@ footer {
}

#reportContainer {
font-family: 'Georgia', 'Times New Roman', Times, "Courier New", serif;
background-color: rgba(255, 255, 255, 0.1);
font-family: 'Times New Roman', Times, "Courier New", serif;
border: none;
42 changes: 41 additions & 1 deletion gpt_researcher/context/compression.py
@@ -1,6 +1,6 @@
import os
import asyncio
from .retriever import SearchAPIRetriever
from .retriever import SearchAPIRetriever, SectionRetriever
from langchain.retrievers import (
ContextualCompressionRetriever,
)
@@ -55,3 +55,43 @@ async def async_get_context(self, query, max_results=5, cost_callback=None):
cost_callback(estimate_embedding_cost(model=OPENAI_EMBEDDING_MODEL, docs=self.documents))
relevant_docs = await asyncio.to_thread(compressed_docs.invoke, query)
return self.__pretty_print_docs(relevant_docs, max_results)


class WrittenContentCompressor:
def __init__(self, documents, embeddings, similarity_threshold, **kwargs):
self.documents = documents
self.kwargs = kwargs
self.embeddings = embeddings
self.similarity_threshold = similarity_threshold

def __get_contextual_retriever(self):
splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
relevance_filter = EmbeddingsFilter(embeddings=self.embeddings,
similarity_threshold=self.similarity_threshold)
pipeline_compressor = DocumentCompressorPipeline(
transformers=[splitter, relevance_filter]
)
base_retriever = SectionRetriever(
sections=self.documents
)
contextual_retriever = ContextualCompressionRetriever(
base_compressor=pipeline_compressor, base_retriever=base_retriever
)
return contextual_retriever

def __pretty_docs_list(self, docs, top_n):
return [f"Title: {d.metadata.get('section_title')}\nContent: {d.page_content}\n" for i, d in enumerate(docs) if i < top_n]

def get_context(self, query, max_results=5, cost_callback=None):
compressed_docs = self.__get_contextual_retriever()
if cost_callback:
cost_callback(estimate_embedding_cost(model=OPENAI_EMBEDDING_MODEL, docs=self.documents))
relevant_docs = compressed_docs.invoke(query)
return self.__pretty_docs_list(relevant_docs, max_results)

async def async_get_context(self, query, max_results=5, cost_callback=None):
compressed_docs = self.__get_contextual_retriever()
if cost_callback:
cost_callback(estimate_embedding_cost(model=OPENAI_EMBEDDING_MODEL, docs=self.documents))
relevant_docs = await asyncio.to_thread(compressed_docs.invoke, query)
return self.__pretty_docs_list(relevant_docs, max_results)
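For context, a minimal usage sketch of WrittenContentCompressor. The sections list, the embeddings argument, and the demo wrapper are assumptions for illustration; in this PR the compressor is driven from the agent with self.memory.get_embeddings().

import asyncio
from gpt_researcher.context.compression import WrittenContentCompressor

async def demo(embeddings):
    # Hypothetical previously written sections
    sections = [
        {"section_title": "Transformer Basics", "written_content": "Transformers process tokens with self-attention ..."},
        {"section_title": "Attention Mechanisms", "written_content": "Attention weights relate each token to every other ..."},
    ]
    compressor = WrittenContentCompressor(
        documents=sections,
        embeddings=embeddings,  # in this PR: self.memory.get_embeddings()
        similarity_threshold=0.5,
    )
    # Returns up to max_results "Title: ...\nContent: ..." strings whose
    # embedding similarity to the query clears the threshold
    return await compressor.async_get_context("attention in transformers", max_results=5)

# asyncio.run(demo(my_embeddings))  # my_embeddings: any LangChain embeddings object (assumed)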
33 changes: 33 additions & 0 deletions gpt_researcher/context/retriever.py
@@ -27,3 +27,36 @@ def _get_relevant_documents(
]

return docs

class SectionRetriever(BaseRetriever):
"""
SectionRetriever:
This class retrieves previously written report sections so that redundant content can be avoided when writing new subtopics.
"""
sections: List[Dict] = []
"""
sections example:
[
{
"section_title": "Example Title",
"written_content": "Example content"
},
...
]
"""

def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
) -> List[Document]:

docs = [
Document(
page_content=page.get("written_content", ""),
metadata={
"section_title": page.get("section_title", ""),
},
)
for page in self.sections # Changed 'self.pages' to 'self.sections'
]

return docs
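A small sketch of how the new retriever behaves on its own, assuming the LangChain Runnable .invoke interface; the sample sections are made up. In the pipeline above, SectionRetriever is wrapped by a ContextualCompressionRetriever that performs the actual similarity filtering.

from gpt_researcher.context.retriever import SectionRetriever

# Hypothetical sections; in practice these come from extract_sections()
sections = [
    {"section_title": "Introduction", "written_content": "This report surveys model compression techniques ..."},
    {"section_title": "Key Findings", "written_content": "Compression can cut inference cost with little accuracy loss ..."},
]

retriever = SectionRetriever(sections=sections)

# The retriever itself ignores the query and returns every stored section as a
# Document; relevance filtering happens downstream in the compression pipeline.
docs = retriever.invoke("anything")
for doc in docs:
    print(doc.metadata["section_title"], "->", doc.page_content)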
66 changes: 65 additions & 1 deletion gpt_researcher/master/actions.py
@@ -5,6 +5,8 @@
import json_repair
import markdown

from typing import List, Dict

from gpt_researcher.master.prompts import *
from gpt_researcher.scraper.scraper import Scraper
from gpt_researcher.utils.enum import Tone
@@ -325,6 +327,35 @@ async def summarize_url(
print(f"{Fore.RED}Error in summarize: {e}{Style.RESET_ALL}")
return summary

async def generate_draft_section_titles(
query: str,
context,
agent_role_prompt: str,
report_type: str,
websocket,
cfg,
main_topic: str = "",
cost_callback: callable = None,
headers=None
) -> str:
assert report_type == "subtopic_report", "This function is only for subtopic reports"
content = f"{generate_draft_titles_prompt(query, main_topic, context)}"
try:
draft_section_titles = await create_chat_completion(
model=cfg.fast_llm_model,
messages=[
{"role": "system", "content": f"{agent_role_prompt}"},
{"role": "user", "content": content},
],
temperature=0,
llm_provider=cfg.llm_provider,
llm_kwargs=cfg.llm_kwargs,
cost_callback=cost_callback,
)
except Exception as e:
print(f"{Fore.RED}Error in generate_draft_section_titles: {e}{Style.RESET_ALL}")

return draft_section_titles

async def generate_report(
query: str,
@@ -337,6 +368,7 @@ async def generate_report(
cfg,
main_topic: str = "",
existing_headers: list = [],
relevant_written_contents: list = [],
cost_callback: callable = None,
headers=None,
):
@@ -352,6 +384,7 @@ async def generate_report(
cfg:
main_topic:
existing_headers:
relevant_written_contents:
cost_callback:
Returns:
@@ -362,7 +395,7 @@ async def generate_report(
report = ""

if report_type == "subtopic_report":
content = f"{generate_prompt(query, existing_headers, main_topic, context, report_format=cfg.report_format, total_words=cfg.total_words)}"
content = f"{generate_prompt(query, existing_headers, relevant_written_contents, main_topic, context, report_format=cfg.report_format, total_words=cfg.total_words)}"
if tone:
content += f", tone={tone}"
summary = await create_chat_completion(
@@ -502,6 +535,37 @@ def extract_headers(markdown_text: str):

return headers # Return the list of headers

def extract_sections(markdown_text: str) -> List[Dict[str, str]]:
"""
Extract all written sections from a subtopic report
Args:
markdown_text: subtopic report text
Returns:
List of sections, where each section is a dictionary with the following structure:
[
{
"section_title": "Pruning",
"written_content": "Pruning involves removing redundant or less ..."
},
]
"""
sections = []
parsed_md = markdown.markdown(markdown_text)

# Use regex to find all headers and their content
pattern = r'<h\d>(.*?)</h\d>(.*?)(?=<h\d>|$)'
matches = re.findall(pattern, parsed_md, re.DOTALL)

for title, content in matches:
# Clean up the content
clean_content = re.sub(r'<.*?>', '', content).strip()
if clean_content:
sections.append({
"section_title": title.strip(),
"written_content": clean_content
})

return sections

def table_of_contents(markdown_text: str):
try:
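extract_sections turns a finished subtopic report back into the title/content pairs that feed self.global_written_sections. A short self-contained example of the expected output (the sample report text is invented):

from gpt_researcher.master.actions import extract_sections

report = """## Pruning

Pruning involves removing redundant or less important weights.

## Quantization

Quantization stores weights at lower numeric precision."""

print(extract_sections(report))
# Expected, roughly:
# [{'section_title': 'Pruning',
#   'written_content': 'Pruning involves removing redundant or less important weights.'},
#  {'section_title': 'Quantization',
#   'written_content': 'Quantization stores weights at lower numeric precision.'}]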
113 changes: 111 additions & 2 deletions gpt_researcher/master/agent.py
@@ -1,8 +1,10 @@
import asyncio
import time

from typing import Set

from gpt_researcher.config import Config
from gpt_researcher.context.compression import ContextCompressor
from gpt_researcher.context.compression import ContextCompressor, WrittenContentCompressor
from gpt_researcher.document import DocumentLoader, LangChainDocumentLoader
from gpt_researcher.master.actions import *
from gpt_researcher.memory import Memory
@@ -149,7 +151,7 @@ async def conduct_research(self):

return self.context

async def write_report(self, existing_headers: list = []):
async def write_report(self, existing_headers: list = [], relevant_written_contents: list = []):
"""
Writes the report based on research conducted
@@ -191,6 +193,7 @@ async def write_report(self, existing_headers: list = []):
cfg=self.cfg,
main_topic=self.parent_query,
existing_headers=existing_headers,
relevant_written_contents=relevant_written_contents,
cost_callback=self.add_costs,
headers=self.headers,
)
@@ -444,3 +447,109 @@ async def get_subtopics(self):
)

return subtopics

async def get_draft_section_titles(self):
"""
Writes the draft section titles based on research conducted. The draft section titles are used to retrieve the previous relevant written contents.
Returns:
str: The headers markdown text
"""
if self.verbose:
await stream_output(
"logs",
"task_summary_coming_up",
f"✍️ Writing draft section titles for research task: {self.query}...",
self.websocket,
)

draft_section_titles = await generate_draft_section_titles(
query=self.query,
context=self.context,
agent_role_prompt=self.role,
report_type=self.report_type,
websocket=self.websocket,
cfg=self.cfg,
main_topic=self.parent_query,
cost_callback=self.add_costs,
headers=self.headers,
)

return draft_section_titles

async def __get_similar_written_contents_by_query(self,
query: str,
written_contents: List[Dict],
similarity_threshold: float = 0.5,
max_results: int = 10
) -> List[str]:
"""
Asynchronously retrieves similar written contents based on a given query.
Args:
query (str): The query to search for similar written contents.
written_contents (List[Dict]): List of written contents to search through.
similarity_threshold (float, optional): The minimum similarity score for content to be considered relevant.
Defaults to 0.5.
max_results (int, optional): The maximum number of similar contents to return. Defaults to 10.
Returns:
List[str]: A list of similar written contents, limited by max_results.
"""
if self.verbose:
await stream_output(
"logs",
"fetching_relevant_written_content",
f"🔎 Getting relevant written content based on query: {query}...",
self.websocket,
)

# Retrieve similar written contents based on the query
# Use a higher similarity threshold to ensure more relevant results and reduce irrelevant matches
written_content_compressor = WrittenContentCompressor(
documents=written_contents, embeddings=self.memory.get_embeddings(), similarity_threshold=similarity_threshold
)
return await written_content_compressor.async_get_context(
query=query, max_results=max_results, cost_callback=self.add_costs
)

async def get_similar_written_contents_by_draft_section_titles(
self,
current_subtopic: str,
draft_section_titles: List[str],
written_contents: List[Dict],
max_results: int = 10
) -> List[str]:
"""
Retrieve similar written contents based on current subtopic and draft section titles.
Args:
current_subtopic (str): The current subtopic.
draft_section_titles (List[str]): List of draft section titles.
written_contents (List[Dict]): List of written contents to search through.
max_results (int): Maximum number of results to return. Defaults to 10.
Returns:
List[str]: List of relevant written contents.
"""
all_queries = [current_subtopic] + draft_section_titles

async def process_query(query: str) -> Set[str]:
return set(await self.__get_similar_written_contents_by_query(query, written_contents))

# Run all queries in parallel
results = await asyncio.gather(*[process_query(query) for query in all_queries])

# Combine all results
relevant_contents = set().union(*results)

# Limit the number of results
relevant_contents = list(relevant_contents)[:max_results]

if relevant_contents and self.verbose:
prettier_contents = "\n".join(relevant_contents)
await stream_output(
"logs", "relevant_contents_context", f"📃 {prettier_contents}", self.websocket
)

return relevant_contents
