Skip to content

Commit

Permalink
Refactor PR help message tool to use full documentation content for a…
Browse files Browse the repository at this point in the history
…nswering questions and update relevant section handling in prompts
  • Loading branch information
mrT23 committed Oct 24, 2024
1 parent 4f14742 commit 9786499
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 121 deletions.
Binary file removed docs/chroma_db.zip
Binary file not shown.
6 changes: 3 additions & 3 deletions pr_agent/settings/pr_help_prompts.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ You will receive a question, and the full documentation website content.
Your goal is to provide the best answer to the question using the documentation provided.
Additional instructions:
- Try to be short and concise in your answers. Give examples if needed.
- Try to be short and concise in your answers. Try to give examples if needed.
- The main tools of PR-Agent are 'describe', 'review', 'improve'. If there is ambiguity to which tool the user is referring to, prioritize snippets of these tools over others.
Expand All @@ -17,7 +17,7 @@ class relevant_section(BaseModel):
class DocHelper(BaseModel):
user_question: str = Field(description="The user's question")
response: str = Field(description="The response to the user's question")
relevant_sections: List[relevant_section] = Field(description="A list of the relevant markdown sections in the documentation that answer the user's question, ordered by importance")
relevant_sections: List[relevant_section] = Field(description="A list of the relevant markdown sections in the documentation that answer the user's question, ordered by importance (most relevant first)")
=====
Expand All @@ -41,7 +41,7 @@ User's Question:
=====
Relevant doc snippets retrieved:
Documentation website content:
=====
{{ snippets|trim }}
=====
Expand Down
136 changes: 24 additions & 112 deletions pr_agent/tools/pr_help_message.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,12 @@

from jinja2 import Environment, StrictUndefined

from pr_agent.algo import MAX_TOKENS
from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler
from pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAIHandler
from pr_agent.algo.pr_processing import retry_with_fallback_models
from pr_agent.algo.token_handler import TokenHandler
from pr_agent.algo.utils import ModelType, load_yaml
from pr_agent.algo.utils import ModelType, load_yaml, clip_tokens
from pr_agent.config_loader import get_settings
from pr_agent.git_providers import get_git_provider, GithubProvider, BitbucketServerProvider, \
get_git_provider_with_context
Expand Down Expand Up @@ -68,83 +69,6 @@ def parse_args(self, args):
question_str = ""
return question_str

def get_sim_results_from_s3_db(self, embeddings):
    """Fetch the zipped Chroma index from S3, unpack it, and run a similarity search.

    Returns a list of (document, score) pairs for ``self.question_str``;
    an empty list if the download, extraction, or search fails.
    """
    get_logger().info("Loading the S3 index...")
    sim_results = []
    try:
        from langchain_chroma import Chroma
        from urllib import request
        with tempfile.TemporaryDirectory() as work_dir:
            bucket = 'pr-agent'
            file_name = 'chroma_db.zip'
            # Download the archived index over plain HTTPS (no boto3 dependency).
            archive_path = os.path.join(work_dir, 'chroma_db.zip')
            s3_url = f'https://{bucket}.s3.amazonaws.com/{file_name}'
            request.urlretrieve(s3_url, archive_path)

            # Unpack the archive and point Chroma at the extracted directory.
            with zipfile.ZipFile(archive_path, 'r') as archive:
                archive.extractall(work_dir)

            vectorstore = Chroma(persist_directory=work_dir + "/chroma_db",
                                 embedding_function=embeddings)
            sim_results = vectorstore.similarity_search_with_score(
                self.question_str, k=self.num_retrieved_snippets)
    except Exception as e:
        get_logger().error(f"Error while getting sim from S3: {e}",
                           artifact={"traceback": traceback.format_exc()})
    return sim_results

def get_sim_results_from_local_db(self, embeddings):
    """Load the bundled Chroma index from disk and run a similarity search.

    Probes the repo checkout location first, then the container image
    path. Returns (document, score) pairs for ``self.question_str``, or
    an empty list when the db is missing or any step raises.
    """
    get_logger().info("Loading the local index...")
    sim_results = []
    try:
        from langchain_chroma import Chroma
        get_logger().info("Loading the Chroma index...")
        # Candidate locations, in order: repo checkout, then Docker image path.
        db_path = next((p for p in ("./docs/chroma_db.zip", "/app/docs/chroma_db.zip")
                        if os.path.exists(p)), None)
        if db_path is None:
            get_logger().error("Local db not found")
            return sim_results
        with tempfile.TemporaryDirectory() as extract_dir:
            # Unzip the index and hand the extracted directory to Chroma.
            with zipfile.ZipFile(db_path, 'r') as archive:
                archive.extractall(extract_dir)
            vectorstore = Chroma(persist_directory=extract_dir + "/chroma_db",
                                 embedding_function=embeddings)
            sim_results = vectorstore.similarity_search_with_score(
                self.question_str, k=self.num_retrieved_snippets)
    except Exception as e:
        get_logger().error(f"Error while getting sim from local db: {e}",
                           artifact={"traceback": traceback.format_exc()})
    return sim_results

def get_sim_results_from_pinecone_db(self, embeddings):
    """Query the hosted Pinecone index for snippets similar to the question.

    Returns (document, score) pairs for ``self.question_str``; an empty
    list if the Pinecone client cannot be created or the search raises.
    """
    get_logger().info("Loading the Pinecone index...")
    sim_results = []
    try:
        from langchain_pinecone import PineconeVectorStore
        INDEX_NAME = "pr-agent-docs"
        # API key comes from the app settings; the index itself is remote.
        store = PineconeVectorStore(
            index_name=INDEX_NAME,
            embedding=embeddings,
            pinecone_api_key=get_settings().pinecone.api_key,
        )
        sim_results = store.similarity_search_with_score(
            self.question_str, k=self.num_retrieved_snippets)
    except Exception as e:
        get_logger().error(f"Error while getting sim from Pinecone db: {e}",
                           artifact={"traceback": traceback.format_exc()})
    return sim_results

async def run(self):
try:
if self.question_str:
Expand All @@ -159,48 +83,36 @@ async def run(self):
return

# current path
docs_path= Path(__file__).parent.parent.parent/'docs'/'docs'
docs_path= Path(__file__).parent.parent.parent / 'docs' / 'docs'
# get all the 'md' files inside docs_path and its subdirectories
md_files = list(docs_path.glob('**/*.md'))
folders_to_exclude =['/finetuning_benchmark/']
files_to_exclude = ['EXAMPLE_BEST_PRACTICE.md','compression_strategy.md']
folders_to_exclude = ['/finetuning_benchmark/']
files_to_exclude = ['EXAMPLE_BEST_PRACTICE.md', 'compression_strategy.md', '/docs/overview/index.md']
md_files = [file for file in md_files if not any(folder in str(file) for folder in folders_to_exclude) and not any(file.name == file_to_exclude for file_to_exclude in files_to_exclude)]
# # calculate the token count of all the md files
# token_count = 0
# for file in md_files:
# with open(file, 'r') as f:
# token_count += self.token_handler.count_tokens(f.read())

docs_prompt =""
# sort the 'md_files' so that 'priority_files' will be at the top
priority_files_strings = ['/docs/index.md', '/usage-guide', 'tools/describe.md', 'tools/review.md',
'tools/improve.md', '/faq']
md_files_priority = [file for file in md_files if
any(priority_string in str(file) for priority_string in priority_files_strings)]
md_files_not_priority = [file for file in md_files if file not in md_files_priority]
md_files = md_files_priority + md_files_not_priority

docs_prompt = ""
for file in md_files:
with open(file, 'r') as f:
file_path = str(file).replace(str(docs_path), '')
docs_prompt += f"==file name:==\n\n{file_path}\n\n==file content:==\n\n{f.read()}\n=========\n\n"

docs_prompt += f"==file name:==\n\n{file_path}\n\n==file content:==\n\n{f.read().strip()}\n=========\n\n"
token_count = self.token_handler.count_tokens(docs_prompt)
get_logger().debug(f"Token count of full documentation website: {token_count}")

model = get_settings().config.model
max_tokens_full = MAX_TOKENS[model] # note - here we take the actual max tokens, without any reductions. we do aim to get the full documentation website in the prompt
delta_output = 2000
if token_count > max_tokens_full - delta_output:
get_logger().info(f"Token count {token_count} exceeds the limit {max_tokens_full - delta_output}. Skipping the PR Help message.")
docs_prompt = clip_tokens(docs_prompt, max_tokens_full - delta_output)
self.vars['snippets'] = docs_prompt.strip()
# # Initialize embeddings
# from langchain_openai import OpenAIEmbeddings
# embeddings = OpenAIEmbeddings(model="text-embedding-3-small",
# api_key=get_settings().openai.key)
#
# # Get similar snippets via similarity search
# if get_settings().pr_help.force_local_db:
# sim_results = self.get_sim_results_from_local_db(embeddings)
# elif get_settings().get('pinecone.api_key'):
# sim_results = self.get_sim_results_from_pinecone_db(embeddings)
# else:
# sim_results = self.get_sim_results_from_s3_db(embeddings)
# if not sim_results:
# get_logger().info("Failed to load the S3 index. Loading the local index...")
# sim_results = self.get_sim_results_from_local_db(embeddings)
# if not sim_results:
# get_logger().error("Failed to retrieve similar snippets. Exiting...")
# return

# # Prepare relevant snippets
# relevant_pages_full, relevant_snippets_full_header, relevant_snippets_str =\
# await self.prepare_relevant_snippets(sim_results)
# self.vars['snippets'] = relevant_snippets_str.strip()

# run the AI model
response = await retry_with_fallback_models(self._prepare_prediction, model_type=ModelType.REGULAR)
Expand Down
6 changes: 0 additions & 6 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,6 @@ gunicorn==22.0.0
pytest-cov==5.0.0
pydantic==2.8.2
html2text==2024.2.26
# help bot
langchain==0.3.0
langchain-openai==0.2.0
langchain-pinecone==0.2.0
langchain-chroma==0.1.4
chromadb==0.5.7
# Uncomment the following lines to enable the 'similar issue' tool
# pinecone-client
# pinecone-datasets @ git+https://github.com/mrT23/pinecone-datasets.git@main
Expand Down

0 comments on commit 9786499

Please sign in to comment.