Skip to content

Commit

Permalink
Merge pull request #190 from unoplat/169-feat-implement-ogm-neomodel-…
Browse files Browse the repository at this point in the history
…based-on-unoplat-commons-lib-in-ingestion-utility

feat: moved from native cypher to neomodel for ingestion and it feels at home
  • Loading branch information
JayGhiya authored Oct 29, 2024
2 parents 8226336 + 5e69eb1 commit 12bb1f3
Show file tree
Hide file tree
Showing 13 changed files with 269 additions and 8,790 deletions.
2 changes: 1 addition & 1 deletion scripts/export_data_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def export_graph_to_json():

# Query for all relationships and their properties
relationships_query = """
MATCH (start)-[r:CONTAINS]->(end)
MATCH (start)-[r]->(end)
RETURN id(r) as id, type(r) as type, properties(r) as properties,
elementId(start) as startNode, elementId(end) as endNode
"""
Expand Down
3,795 changes: 0 additions & 3,795 deletions unoplat-code-confluence-query-engine/poetry.lock

This file was deleted.

Original file line number Diff line number Diff line change
@@ -1,14 +1,20 @@
from sentence_transformers import SentenceTransformer
from typing import List

#TODO: this code is duplicated across ingestion and query engine. We will not refactor
# as we move the embedding part to infrastructure such as vespa/marqo
class UnoplatEmbeddingGenerator:
    """Thin wrapper around a SentenceTransformer model for generating embeddings."""

    def __init__(self, model_name: str):
        # trust_remote_code is required for models that ship custom modeling code
        # (security note: only load model names from trusted sources).
        self.model = SentenceTransformer(model_name, trust_remote_code=True)
        # Cache the embedding dimensionality so index creation can query it cheaply.
        self.dimensions = self.model.get_sentence_embedding_dimension()

    def generate_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Embed a batch of texts; returns one vector (list of floats) per input."""
        # 'retrieval.query' selects the query-side encoding task on models that
        # support task-specific prompts -- TODO confirm the loaded model honors it.
        task = 'retrieval.query'
        return self.model.encode(texts, task=task).tolist()

    def generate_embeddings_for_single_text(self, text: str) -> List[float]:
        """Embed a single text, returning a plain list of floats."""
        task = 'retrieval.query'
        # Fix: the original had a duplicated return statement and returned a numpy
        # array despite the List[float] annotation; .tolist() makes the return type
        # match both the annotation and the batch method above.
        return self.model.encode(text, task=task, convert_to_tensor=False).tolist()

    def get_dimensions(self) -> int:
        """Return the embedding dimensionality of the loaded model."""
        return self.dimensions
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def search_similar_nodes(self, vector_index, query_embedding, top_k=5):

def get_existing_codebases(self):
query = """
MATCH (cb:Codebase)
MATCH (cb:ConfluenceCodebase)
RETURN cb.qualified_name AS codebase_name
"""
result = self.run_query(query)
Expand All @@ -52,7 +52,7 @@ def get_existing_codebases(self):

def get_package_details(self,package_name):
query = """
MATCH (p:Package {qualified_name: $package_name})
MATCH (p:ConfluencePackage {qualified_name: $package_name})
RETURN
p.qualified_name AS package_name,
p.objective AS package_objective,
Expand All @@ -62,7 +62,7 @@ def get_package_details(self,package_name):

def get_class_details(self,class_name):
query = """
MATCH (c:Class {qualified_name: $class_name})
MATCH (c:ConfluenceClass {qualified_name: $class_name})
RETURN
c.qualified_name AS class_name,
c.objective AS class_objective,
Expand All @@ -72,7 +72,7 @@ def get_class_details(self,class_name):

def get_codebase_details(self,codebase_name):
query = """
MATCH (cb:Codebase {qualified_name: $codebase_name})
MATCH (cb:ConfluenceCodebase {qualified_name: $codebase_name})
RETURN
cb.qualified_name AS codebase_name,
cb.objective AS codebase_objective,
Expand All @@ -82,8 +82,8 @@ def get_codebase_details(self,codebase_name):

def get_function_hierarchy_and_details(self, function_name):
query = """
MATCH (f:Method {qualified_name: $function_name})
OPTIONAL MATCH (f)<-[:CONTAINS]-(c:Class)<-[:CONTAINS]-(p:Package)<-[:CONTAINS]-(cb:Codebase)
MATCH (f:ConfluenceMethod {qualified_name: $function_name})
OPTIONAL MATCH (f)<-[:CONTAINS]-(c:ConfluenceClass)<-[:CONTAINS]-(p:ConfluencePackage)<-[:CONTAINS]-(cb:ConfluenceCodebase)
RETURN
cb.qualified_name AS codebase_name,
cb.objective AS codebase_objective,
Expand Down Expand Up @@ -143,19 +143,19 @@ def _create_relationships(self, tx: Transaction, relationships):

def create_vector_index(self, label: str, property: str, dimension: int = None, similarity_function: str = 'cosine') -> None:
    """Create a Neo4j vector index named ``<property>_vector_index`` on ``label.property``.

    Args:
        label: Node label the index applies to.
        property: Node property holding the embedding; also names the index.
        dimension: Vector dimensionality; when None, no index OPTIONS are emitted.
        similarity_function: Similarity metric for the index (default 'cosine').

    Raises:
        Exception: re-raised driver errors, except "equivalent index already
            exists", which is treated as a benign no-op.
    """
    with self.driver.session() as session:
        query = f"CREATE VECTOR INDEX {property}_vector_index FOR (n:{label}) ON (n.{property})"
        if dimension is not None:
            query += f" OPTIONS {{indexConfig: {{`vector.dimensions`: {dimension}, `vector.similarity_function`: '{similarity_function}'}}}}"
        try:
            session.run(query)
        except Exception as e:
            # Neo4j raises when an equivalent index already exists; swallow only
            # that case so repeated ingestion runs are idempotent.
            if "equivalent index already exists" in str(e):
                print(f"Vector index for {property} already exists. Skipping creation.")
            else:
                raise  # Re-raise the exception if it's not about existing index

def create_text_index(self, label: str, property: str) -> None:
    """Create a Neo4j TEXT index named ``<property>_text_index`` on ``label.property``.

    Unlike create_vector_index, no already-exists handling is done here; driver
    errors (including duplicate-index errors) propagate to the caller.
    """
    with self.driver.session() as session:
        index_name = f"{property.lower()}_text_index"
        query = f"CREATE TEXT INDEX {index_name} FOR (n:{label}) ON (n.{property})"
        session.run(query)
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,9 @@ async def process_query(self, user_query: str) -> str:

if int(ConfluenceUserIntent.FUNCTIONAL_IMPLEMENTATION.value) in user_intent_list:
# Search similar functions
results = self.graph_helper.search_similar_nodes(vector_index="Method_implementation_embedding_vector_index", query_embedding=user_query_embedding, top_k=5)
results = self.graph_helper.search_similar_nodes(vector_index="function_implementation_summary_embedding_vector_index", query_embedding=user_query_embedding, top_k=5)
context = {result["name"]: result["summary"] for result in results}
log.debug(f"context for function: {context}")

if len(context) > 1:
rerank_results = self.rerank_module(user_query=user_query, possible_answers=context).answer.relevant_answers
Expand All @@ -69,9 +70,9 @@ async def process_query(self, user_query: str) -> str:

elif int(ConfluenceUserIntent.CODE_SUMMARIZATION.value) in user_intent_list:

results = self.graph_helper.search_similar_nodes(vector_index="Codebase_implementation_embedding_vector_index", query_embedding=user_query_embedding,top_k=5)
results = self.graph_helper.search_similar_nodes(vector_index="codebase_implementation_summary_embedding_vector_index", query_embedding=user_query_embedding,top_k=5)
context = {result["name"]: result["summary"] for result in results}

log.debug(f"context for codebase: {context}")
if len(context) > 1:
rerank_results = self.rerank_module(user_query=user_query, possible_answers=context).answer.relevant_answers
filtered_rerank_results = {k: v for k, v in rerank_results.items() if v > 7}
Expand All @@ -90,9 +91,9 @@ async def process_query(self, user_query: str) -> str:
final_response = final_response + self.user_query_response_module(user_query=user_query, code_metadata=context).answer

elif int(ConfluenceUserIntent.PACKAGE_OVERVIEW.value) in user_intent_list:
results = self.graph_helper.search_similar_nodes(vector_index="Package_implementation_embedding_vector_index", query_embedding=user_query_embedding,top_k=5)
results = self.graph_helper.search_similar_nodes(vector_index="package_implementation_summary_embedding_vector_index", query_embedding=user_query_embedding,top_k=5)
context = {result["name"]: result["summary"] for result in results}

log.debug(f"context for package: {context}")
if len(context) > 1:
rerank_results = self.rerank_module(user_query=user_query, possible_answers=context).answer.relevant_answers
filtered_rerank_results = {k: v for k, v in rerank_results.items() if v > 7}
Expand All @@ -111,9 +112,9 @@ async def process_query(self, user_query: str) -> str:
final_response = final_response + self.user_query_response_module(user_query=user_query, code_metadata=context).answer

elif int(ConfluenceUserIntent.CLASS_DETAILS.value) in user_intent_list:
results = self.graph_helper.search_similar_nodes(vector_index="Class_implementation_embedding_vector_index", query_embedding=user_query_embedding, top_k=5)
results = self.graph_helper.search_similar_nodes(vector_index="class_implementation_summary_embedding_vector_index", query_embedding=user_query_embedding, top_k=5)
context = {result["name"]: result["summary"] for result in results}

log.debug(f"context for class: {context}")
if len(context) > 1:
rerank_results = self.rerank_module(user_query=user_query, possible_answers=context).answer.relevant_answers
filtered_rerank_results = {k: v for k, v in rerank_results.items() if v > 7}
Expand Down Expand Up @@ -147,15 +148,19 @@ async def load_codebase_graph(self, file_path: str) -> None:
async def _create_vector_index_on_all_nodes(self):
    """Create vector indexes for every embedding property of every Confluence node label."""
    # Map of node label -> embedding properties stored on that label.
    # These names must stay in sync with the neomodel property names used at
    # ingestion time and with the index names queried in process_query.
    node_embedding_properties = {
        "ConfluenceCodebase": ["codebase_objective_embedding", "codebase_implementation_summary_embedding"],
        "ConfluencePackage": ["package_objective_embedding", "package_implementation_summary_embedding"],
        "ConfluenceClass": ["class_objective_embedding", "class_implementation_summary_embedding"],
        "ConfluenceMethod": ["function_objective_embedding", "function_implementation_summary_embedding"],
    }

    # Hoisted out of the loop: the model's dimensionality is invariant across indexes.
    dimensions = self.embedding_generator.get_dimensions()

    for node_type, embedding_properties in node_embedding_properties.items():
        for embedding_property in embedding_properties:
            await self._create_vector_index(node_label=node_type, embedding_property=embedding_property, dimensions=dimensions)

async def _create_vector_index(self, node_label: str, embedding_property: str, dimensions: int):
    """Delegate creation of a single vector index to the graph helper.

    NOTE(review): graph_helper.create_vector_index appears synchronous, so this
    coroutine never yields -- confirm whether async is intentional here.
    """
    self.graph_helper.create_vector_index(node_label, embedding_property, dimensions)

async def load_existing_codebases(self):
    # Return the codebase names already present in the graph, as provided by
    # the graph helper (see get_existing_codebases: qualified_name per node).
    # NOTE(review): get_existing_codebases looks synchronous -- this coroutine
    # never awaits; confirm the async signature is intentional.
    return self.graph_helper.get_existing_codebases()
Expand Down
Loading

0 comments on commit 12bb1f3

Please sign in to comment.