Skip to content

Commit

Permalink
fix: fix models according to new schema
Browse files Browse the repository at this point in the history
  • Loading branch information
JayGhiya committed Oct 29, 2024
1 parent 4648a89 commit 5e69eb1
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 13 deletions.
Original file line number Diff line number Diff line change
@@ -1,14 +1,20 @@
from sentence_transformers import SentenceTransformer
from typing import List

# TODO: this code is duplicated across the ingestion and query engines. We will not
# refactor it, because the embedding step will move to infrastructure such as Vespa/Marqo.
class UnoplatEmbeddingGenerator:
    """Thin wrapper around a ``SentenceTransformer`` model for query embeddings.

    The model is loaded once in ``__init__`` and its embedding dimension is
    cached in ``self.dimensions`` so callers can size vector indexes from it.
    Every encode call in this class uses the ``'retrieval.query'`` task.
    """

    def __init__(self, model_name: str):
        """Load the model and cache its embedding dimension.

        Args:
            model_name: Model identifier handed to ``SentenceTransformer``.
        """
        # NOTE(review): trust_remote_code=True presumably allows the model
        # repository's own code to run at load time - confirm this is intended.
        self.model = SentenceTransformer(model_name, trust_remote_code=True)

        self.dimensions = self.model.get_sentence_embedding_dimension()

    def generate_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Encode several texts and return the result converted via ``.tolist()``.

        Args:
            texts: The texts to embed.

        Returns:
            One embedding per input text, as nested Python lists.
        """
        task = 'retrieval.query'
        return self.model.encode(texts, task=task).tolist()

    def generate_embeddings_for_single_text(self, text: str) -> List[float]:
        """Encode a single text with ``convert_to_tensor=False``.

        Args:
            text: The text to embed.

        Returns:
            Whatever ``encode`` returns for a single string.
        """
        task = 'retrieval.query'
        # NOTE(review): unlike generate_embeddings, the result is returned
        # without .tolist(), so it is likely an array rather than a list as the
        # annotation states - confirm against callers before changing.
        return self.model.encode(text, task=task, convert_to_tensor=False)

    def get_dimensions(self) -> int:
        """Return the embedding dimension cached at construction time."""
        return self.dimensions
Original file line number Diff line number Diff line change
Expand Up @@ -143,19 +143,19 @@ def _create_relationships(self, tx: Transaction, relationships):

def create_vector_index(self, label: str, property: str, dimension: int = None, similarity_function: str = 'cosine') -> None:
    """Create a vector index on ``label.property``, tolerating an existing one.

    The index is named ``{property}_vector_index``; the label is not part of
    the name, so property names must be unique across labels to avoid clashes.

    Args:
        label: Node label the index applies to.
        property: Node property holding the embedding.
        dimension: Vector dimension. When ``None`` no ``OPTIONS`` clause is
            appended to the statement.
        similarity_function: Similarity function written into the index
            config; only used when ``dimension`` is given.

    Raises:
        Exception: Any error from running the statement other than the
            "equivalent index already exists" case is re-raised unchanged.
    """
    with self.driver.session() as session:
        # label/property are interpolated directly into the statement, so
        # they must come from trusted code, never from user input.
        query = f"CREATE VECTOR INDEX {property}_vector_index FOR (n:{label}) ON (n.{property})"
        if dimension is not None:
            query += f" OPTIONS {{indexConfig: {{`vector.dimensions`: {dimension}, `vector.similarity_function`: '{similarity_function}'}}}}"
        try:
            session.run(query)
        except Exception as e:
            # An already-existing equivalent index is treated as success.
            if "equivalent index already exists" in str(e):
                print(f"Vector index for {property} already exists. Skipping creation.")
            else:
                raise  # Re-raise the exception if it's not about existing index

def create_text_index(self, label: str, property: str) -> None:
    """Create a text index on ``label.property``.

    The index is named ``{property.lower()}_text_index``; the label is not
    part of the name. Errors from running the statement propagate to the
    caller.

    Args:
        label: Node label the index applies to.
        property: Node property to index.
    """
    with self.driver.session() as session:
        index_name = f"{property.lower()}_text_index"
        # label/property are interpolated directly into the statement, so
        # they must come from trusted code, never from user input.
        query = f"CREATE TEXT INDEX {index_name} FOR (n:{label}) ON (n.{property})"
        session.run(query)
Original file line number Diff line number Diff line change
Expand Up @@ -148,15 +148,19 @@ async def load_codebase_graph(self, file_path: str) -> None:
async def _create_vector_index_on_all_nodes(self):
    """Create a vector index for every embedding property of every node label.

    Each label maps to its own embedding property names, and every index is
    created with the dimension reported by ``self.embedding_generator``.
    """
    # Create vector indexes for all node types
    node_embedding_properties = {
        "ConfluenceCodebase": ["codebase_objective_embedding", "codebase_implementation_summary_embedding"],
        "ConfluencePackage": ["package_objective_embedding", "package_implementation_summary_embedding"],
        "ConfluenceClass": ["class_objective_embedding", "class_implementation_summary_embedding"],
        "ConfluenceMethod": ["function_objective_embedding", "function_implementation_summary_embedding"]
    }

    for node_type, embedding_properties in node_embedding_properties.items():
        for embedding_property in embedding_properties:
            await self._create_vector_index(node_label=node_type, embedding_property=embedding_property, dimensions=self.embedding_generator.get_dimensions())

async def _create_vector_index(self, node_label: str, embedding_property: str, dimensions: int):
    """Create one vector index by delegating to ``self.graph_helper``.

    Args:
        node_label: Node label the index applies to.
        embedding_property: Node property holding the embedding.
        dimensions: Vector dimension forwarded to the graph helper.
    """
    self.graph_helper.create_vector_index(node_label, embedding_property, dimensions)

async def load_existing_codebases(self):
    """Return whatever ``self.graph_helper.get_existing_codebases()`` yields."""
    existing_codebases = self.graph_helper.get_existing_codebases()
    return existing_codebases
Expand Down

0 comments on commit 5e69eb1

Please sign in to comment.