
Commit 0ef2f71
refactor: adjust abstraction of EmbeddingIndexer
doomspec committed Sep 1, 2023
1 parent 5b0d3c4 commit 0ef2f71
Showing 1 changed file with 68 additions and 36 deletions.
evonote/indexing/core.py — 104 changes: 68 additions & 36 deletions
@@ -86,11 +86,12 @@ def get_top_k_notes(self, query: List[str], weights: List[float] = None, k: int
         "show_src_similarity_gui": False,
     }
 
+
 class AbsEmbeddingIndexer(Indexer):
     @classmethod
     def prepare_src_weight_list(cls, new_notes: List[Note], indexing: Indexing,
                                 use_cache: bool) -> (
-        List[List[str]], List[List[float]], List[Note]):
+            List[List[str]], List[List[float]], List[Note]):
         raise NotImplementedError
 
     @classmethod
@@ -177,6 +178,17 @@ def get_similarities(cls, query: List[str], indexing: Indexing,
 
         return similarity, indexing.data["index_to_note"]
 
+    @classmethod
+    def process_note_with_content(cls, notes: List[Note], indexing: Indexing,
+                                  use_cache: bool):
+        raise NotImplementedError
+
+    @classmethod
+    def process_note_without_content(cls, notes: List[Note], indexing: Indexing,
+                                     use_cache: bool):
+        raise NotImplementedError
+
+
 def show_src_similarity_gui(similarity, data, query, weights, top_k=10):
     from evonote.gui.similarity_search import draw_similarity_gui
     top_note_index = np.argsort(similarity)[::-1][:top_k]
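Taken together with the later hunks, this commit turns prepare_src_weight_list into a dispatcher over two overridable hooks: it splits the incoming notes by whether they carry content, then delegates each group. A self-contained toy mirroring that shape (the ToyIndexer class and its trivial weighting are illustrative, not from the repository):

    from typing import List, Tuple

    class ToyIndexer:
        """Toy mirror of the refactored shape; not the evonote API."""

        @classmethod
        def process_note_with_content(cls, notes: List[str]) -> Tuple[list, list]:
            # Hook 1: notes that carry text index by their own content.
            return [[n] for n in notes], [[1.0] for _ in notes]

        @classmethod
        def process_note_without_content(cls, notes: List[str]) -> Tuple[list, list]:
            # Hook 2: empty notes fall back to a placeholder source.
            return [["(untitled)"] for _ in notes], [[1.0] for _ in notes]

        @classmethod
        def prepare_src_weight_list(cls, notes: List[str]):
            # Dispatcher: split the batch, delegate each group to its hook,
            # then concatenate the per-group results.
            with_content = [n for n in notes if n]
            without_content = [n for n in notes if not n]
            src1, w1 = cls.process_note_without_content(without_content)
            src2, w2 = cls.process_note_with_content(with_content)
            return src1 + src2, w1 + w2

    print(ToyIndexer.prepare_src_weight_list(["alpha", "", "beta"]))
    # ([['(untitled)'], ['alpha'], ['beta']], [[1.0], [1.0], [1.0]])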
@@ -187,49 +199,21 @@ def show_src_similarity_gui(similarity, data, query, weights, top_k=10):
     src_list = [src_list[i] for i in top_note_index]
     draw_similarity_gui(src_list, weights, query, contents)
 
 
 class FragmentedEmbeddingIndexer(AbsEmbeddingIndexer):
     @classmethod
-    def prepare_src_weight_list(cls, new_notes: List[Note], indexing: Indexing,
-                                use_cache: bool):
-
-        notebook = indexing.notebook
-
-        index_to_note = []
-
-        notes_with_content = []
-        notes_content = []
-        notes_without_content = []
-        for note in new_notes:
-            if len(note.content) == 0:
-                keywords_on_path = note.get_note_path(notebook)
-                if len(keywords_on_path) != 0:
-                    notes_without_content.append(note)
-                continue
-            notes_with_content.append(note)
-            notes_content.append(note.content)
-
-        new_src_list_1 = []
-        new_weights_1 = []
-
-        for note in notes_without_content:
-            keywords_on_path = note.get_note_path(notebook)
-            # keep last 1/3 of the keywords
-            n_keywords = min(max(math.ceil(len(keywords_on_path) / 3), 3),
-                             len(keywords_on_path))
-            new_src = keywords_on_path[-n_keywords:]
-            new_src_list_1.append(new_src)
-            weight = np.array([i + 1 for i in range(len(new_src))])
-            weight = weight / np.sum(weight)
-            new_weights_1.append(weight)
-
+    def process_note_with_content(cls, notes: List[Note], indexing: Indexing,
+                                  use_cache: bool):
         break_sent_use_cache = lambda sent: process_sent_into_frags(sent, use_cache,
                                                                     get_main_path())
+        notes_content = [note.content for note in notes]
+        notebook = indexing.notebook
 
         new_src_list_2 = []
         new_weights_2 = []
         n_finished = 0
         with concurrent.futures.ThreadPoolExecutor(max_workers=8) as executor:
-            for note, frags in zip(notes_with_content,
+            for note, frags in zip(notes,
                                    executor.map(break_sent_use_cache, notes_content)):
                 new_src = []
                 new_src.extend(frags)
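One detail this loop relies on: concurrent.futures.Executor.map yields results in the order of its inputs, not in completion order, which is what lets zip(notes, executor.map(...)) pair each note with the fragments of its own content. A minimal self-contained illustration (the toy split_into_frags stands in for process_sent_into_frags):

    import concurrent.futures

    def split_into_frags(sentence: str) -> list:
        # Toy stand-in for process_sent_into_frags: split on commas.
        return [frag.strip() for frag in sentence.split(",")]

    contents = ["a, b", "c", "d, e, f"]
    with concurrent.futures.ThreadPoolExecutor(max_workers=8) as executor:
        # map() preserves input order, so results stay aligned with `contents`.
        for content, frags in zip(contents, executor.map(split_into_frags, contents)):
            print(content, "->", frags)
    # a, b -> ['a', 'b']
    # c -> ['c']
    # d, e, f -> ['d', 'e', 'f']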
@@ -239,14 +223,62 @@ def prepare_src_weight_list(cls, new_notes: List[Note], indexing: Indexing,
                 new_src.append(note.content)
 
                 new_src_list_2.append(new_src)
-                weight = np.ones(len(new_src)) / (len(new_src) ** 0.9)
+                # TODO Handle when there are too many fragments. Maybe we should group
+                #  them by clustering
+                weight = np.ones(len(new_src)) / (len(new_src) ** 0.95)
                 new_weights_2.append(weight)
 
                 n_finished += 1
                 if n_finished % 20 == 19:
                     save_cache()
 
         save_cache()
+        return new_src_list_2, new_weights_2
+
+    @classmethod
+    def process_note_without_content(cls, notes: List[Note], indexing: Indexing,
+                                     use_cache: bool):
+        new_src_list_1 = []
+        new_weights_1 = []
+
+        for note in notes:
+            keywords_on_path = note.get_note_path(indexing.notebook)
+            # keep last 1/3 of the keywords
+            n_keywords = min(max(math.ceil(len(keywords_on_path) / 3), 3),
+                             len(keywords_on_path))
+            new_src = keywords_on_path[-n_keywords:]
+            new_src_list_1.append(new_src)
+            weight = np.array([i + 1 for i in range(len(new_src))])
+            weight = weight / np.sum(weight)
+            new_weights_1.append(weight)
+
+        return new_src_list_1, new_weights_1
+
+    @classmethod
+    def prepare_src_weight_list(cls, new_notes: List[Note], indexing: Indexing,
+                                use_cache: bool):
+
+        notebook = indexing.notebook
+
+        index_to_note = []
+
+        notes_with_content = []
+        notes_content = []
+        notes_without_content = []
+        for note in new_notes:
+            if len(note.content) == 0:
+                keywords_on_path = note.get_note_path(notebook)
+                if len(keywords_on_path) != 0:
+                    notes_without_content.append(note)
+                continue
+            notes_with_content.append(note)
+            notes_content.append(note.content)
+
+        new_src_list_1, new_weights_1 = cls.process_note_without_content(
+            notes_without_content, indexing, use_cache)
+
+        new_src_list_2, new_weights_2 = cls.process_note_with_content(
+            notes_with_content, indexing, use_cache)
+
         new_src_list = new_src_list_1 + new_src_list_2
         new_weights = new_weights_1 + new_weights_2
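As a sanity check on the two weighting schemes (numbers worked out by hand, not taken from the repository): for a content-less note whose path has 9 keywords, n_keywords = min(max(ceil(9/3), 3), 9) = 3, so the last three keywords become sources with normalized ramp weights [1, 2, 3] / 6, biasing toward the deepest keyword. For a note whose content splits into n fragments, each source gets 1 / n**0.95, so the note's total weight grows only as n**0.05 and long notes barely dominate the index:

    import math
    import numpy as np

    # Path-keyword weights for a content-less note with 9 keywords on its path.
    keywords_on_path = [f"kw{i}" for i in range(9)]
    n_keywords = min(max(math.ceil(len(keywords_on_path) / 3), 3),
                     len(keywords_on_path))
    new_src = keywords_on_path[-n_keywords:]
    weight = np.array([i + 1 for i in range(len(new_src))])
    weight = weight / np.sum(weight)
    print(new_src, weight)  # ['kw6', 'kw7', 'kw8'] [0.167 0.333 0.5]

    # Fragment weights: per-source weight shrinks ~1/n, total grows as n**0.05.
    for n in (1, 4, 16):
        w = np.ones(n) / (n ** 0.95)
        print(n, round(w[0], 4), round(w.sum(), 4))
    # 1 1.0 1.0
    # 4 0.2679 1.0718
    # 16 0.0718 1.1487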
