diff --git a/FlagEmbedding/abc/evaluation/searcher.py b/FlagEmbedding/abc/evaluation/searcher.py
index f75b9182..931b54be 100644
--- a/FlagEmbedding/abc/evaluation/searcher.py
+++ b/FlagEmbedding/abc/evaluation/searcher.py
@@ -138,6 +138,9 @@ def __call__(
             (not os.path.exists(os.path.join(corpus_embd_save_dir, "doc.npy")) or self.overwrite):
             os.makedirs(corpus_embd_save_dir, exist_ok=True)
             np.save(os.path.join(corpus_embd_save_dir, "doc.npy"), corpus_emb)
+
+        gc.collect()
+        torch.cuda.empty_cache()
 
         faiss_index = index(corpus_embeddings=corpus_emb)
         all_scores, all_indices = search(query_embeddings=queries_emb, faiss_index=faiss_index, k=self.search_top_k)
diff --git a/FlagEmbedding/abc/inference/AbsEmbedder.py b/FlagEmbedding/abc/inference/AbsEmbedder.py
index fdf3247d..8d7a8651 100644
--- a/FlagEmbedding/abc/inference/AbsEmbedder.py
+++ b/FlagEmbedding/abc/inference/AbsEmbedder.py
@@ -264,8 +264,7 @@ def encode(
         return embeddings
 
     def __del__(self):
-        if self.pool is not None:
-            self.stop_multi_process_pool(self.pool)
+        self.stop_self_pool()
 
     @abstractmethod
     def encode_single_device(
diff --git a/FlagEmbedding/abc/inference/AbsReranker.py b/FlagEmbedding/abc/inference/AbsReranker.py
index d3bb7391..d0c775ec 100644
--- a/FlagEmbedding/abc/inference/AbsReranker.py
+++ b/FlagEmbedding/abc/inference/AbsReranker.py
@@ -210,8 +210,7 @@ def compute_score(
         return scores
 
     def __del__(self):
-        if self.pool is not None:
-            self.stop_multi_process_pool(self.pool)
+        self.stop_self_pool()
 
     @abstractmethod
     def compute_score_single_gpu(
diff --git a/FlagEmbedding/evaluation/mteb/runner.py b/FlagEmbedding/evaluation/mteb/runner.py
index cb1d23ae..34c050ea 100644
--- a/FlagEmbedding/evaluation/mteb/runner.py
+++ b/FlagEmbedding/evaluation/mteb/runner.py
@@ -65,14 +65,15 @@ def read_results(self, output_folder, tasks):
                     print('ERROR')
                     break
 
-                temp_data = data['scores'][split][0]
-
-                if metric == 'ap':
-                    tasks_results[t_type][task_name] = round(temp_data['cos_sim']['ap'] * 100, 2)
-                elif metric == 'cosine_spearman':
-                    tasks_results[t_type][task_name] = round(temp_data['cos_sim']['spearman'] * 100, 2)
-                else:
-                    tasks_results[t_type][task_name] = round(temp_data[metric] * 100, 2)
+                temp_datas = data['scores'][split][0]
+                temp_data = None
+                for td in temp_datas:
+                    if td['hf_subset'] == 'default':
+                        temp_data = td
+                if temp_data is None:
+                    temp_data = temp_datas[0]
+                tasks_results[t_type][task_name] = round(temp_data['main_score'] * 100, 2)
+
         print(f"tasks_results: {tasks_results}")
 
         return tasks_results
@@ -119,16 +120,13 @@ def run(self):
                 task_types=task_types
             )
             output_folder = self.eval_args.output_dir
 
-            new_tasks = []
-            for task in tasks:
-                if task.languages is not None:
-                    if len(task.languages) == len([e for e in languages if e in task.languages]):
-                        new_tasks.append(task)
-            for task in new_tasks:
+            for task in tasks:
                 task_name = task.metadata.name
                 task_type = task.metadata.type
 
+                self.retriever.stop_pool()
+
                 if self.eval_args.use_special_instructions:
                     try:
                         instruction = get_task_def_by_task_name_and_type(task_name, task_type)
diff --git a/FlagEmbedding/evaluation/mteb/searcher.py b/FlagEmbedding/evaluation/mteb/searcher.py
index 04233188..91323934 100644
--- a/FlagEmbedding/evaluation/mteb/searcher.py
+++ b/FlagEmbedding/evaluation/mteb/searcher.py
@@ -1,3 +1,5 @@
+import numpy as np
+
 from typing import List, Dict, Optional
 
 from FlagEmbedding.abc.evaluation import EvalDenseRetriever, EvalReranker
@@ -18,11 +20,18 @@ def get_instruction(self):
     def set_normalize_embeddings(self, normalize_embeddings: bool = True):
         self.embedder.normalize_embeddings = normalize_embeddings
 
+    def stop_pool(self):
+        self.embedder.stop_self_pool()
+        try:
+            self.embedder.stop_self_query_pool()
+        except:
+            pass
+
     def encode_queries(self, queries: List[str], **kwargs):
         emb = self.embedder.encode_queries(queries)
         if isinstance(emb, dict):
             emb = emb["dense_vecs"]
-        return emb
+        return emb.astype(np.float32)
 
     def encode_corpus(self, corpus: List[Dict[str, str]], **kwargs):
         if isinstance(corpus[0], dict):
@@ -32,7 +41,7 @@ def encode_corpus(self, corpus: List[Dict[str, str]], **kwargs):
         emb = self.embedder.encode_corpus(input_texts)
         if isinstance(emb, dict):
             emb = emb["dense_vecs"]
-        return emb
+        return emb.astype(np.float32)
 
     def encode(self, corpus: List[Dict[str, str]], **kwargs):
         if isinstance(corpus[0], dict):
@@ -42,7 +51,7 @@ def encode(self, corpus: List[Dict[str, str]], **kwargs):
         emb = self.embedder.encode_queries(input_texts)
         if isinstance(emb, dict):
             emb = emb["dense_vecs"]
-        return emb
+        return emb.astype(np.float32)
 
 class MTEBEvalReranker(EvalReranker):
     def __init__(self, reranker, **kwargs):
diff --git a/FlagEmbedding/inference/embedder/decoder_only/icl.py b/FlagEmbedding/inference/embedder/decoder_only/icl.py
index 0923254a..57e6d2cf 100644
--- a/FlagEmbedding/inference/embedder/decoder_only/icl.py
+++ b/FlagEmbedding/inference/embedder/decoder_only/icl.py
@@ -4,6 +4,7 @@
 import queue
 from multiprocessing import Queue
 
+import gc
 import torch
 import numpy as np
 from transformers import AutoModel, AutoTokenizer
@@ -121,10 +122,8 @@ def __init__(
         self.query_pool = None
 
     def __del__(self):
-        if self.pool is not None:
-            self.stop_multi_process_pool(self.pool)
-        if self.query_pool is not None:
-            self.stop_multi_process_pool(self.query_pool)
+        self.stop_self_pool()
+        self.stop_self_query_pool()
 
     def set_examples(self, examples_for_task: Optional[List[dict]] = None):
         """Set the prefix to the provided examples.
@@ -175,6 +174,14 @@ def get_detailed_example(instruction_format: str, instruction: str, query: str,
         """
         return instruction_format.format(instruction, query, response)
 
+    def stop_self_query_pool(self):
+        if self.query_pool is not None:
+            self.stop_multi_process_pool(self.query_pool)
+            self.query_pool = None
+        self.model.to('cpu')
+        gc.collect()
+        torch.cuda.empty_cache()
+
     def encode_queries(
         self,
         queries: Union[List[str], str],
@@ -209,9 +216,7 @@ def encode_queries(
                 **kwargs
            )
 
-        if self.pool is not None:
-            self.stop_multi_process_pool(self.pool)
-            self.pool = None
+        self.stop_self_pool()
         if self.query_pool is None:
             self.query_pool = self.start_multi_process_pool(ICLLLMEmbedder._encode_queries_multi_process_worker)
         embeddings = self.encode_multi_process(
@@ -244,9 +249,7 @@ def encode_corpus(
         Returns:
             Union[torch.Tensor, np.ndarray]: Return the embedding vectors in a numpy array or tensor.
         """
-        if self.query_pool is not None:
-            self.stop_multi_process_pool(self.query_pool)
-            self.query_pool = None
+        self.stop_self_query_pool()
         return super().encode_corpus(
             corpus,
             batch_size=batch_size,
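
Note: the hunks above route cleanup through `stop_self_pool()` on `AbsEmbedder`/`AbsReranker`, but that helper's body is not shown in this diff. Below is a minimal sketch of the pattern, written on the assumption that it mirrors the `stop_self_query_pool` method added to `icl.py`; the class name and exact body here are illustrative assumptions, not the library's code.

```python
import gc

import torch


class PoolCleanupSketch:
    """Illustrative only: assumes the host class defines `self.pool` (a
    multi-process pool or None), `self.model` (a torch module), and
    `stop_multi_process_pool()`, as AbsEmbedder/AbsReranker do."""

    def stop_self_pool(self):
        # Shut the worker pool down once and drop the reference so __del__
        # and later encode calls do not try to stop it a second time.
        if self.pool is not None:
            self.stop_multi_process_pool(self.pool)
            self.pool = None
        # Release GPU memory, matching the gc.collect()/empty_cache() calls
        # added elsewhere in this diff.
        self.model.to("cpu")
        gc.collect()
        torch.cuda.empty_cache()
```

Resetting the pool attribute to `None` after stopping it is what lets the new `__del__` implementations and the MTEB `stop_pool()` wrapper call the helper unconditionally.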