From a125b0ce804c5b3a869fbf487930ec3e97d39c05 Mon Sep 17 00:00:00 2001 From: Roque Lopez Date: Wed, 20 Nov 2024 15:57:17 -0500 Subject: [PATCH] feat: Make caching function extensible for any standard --- bdikit/models/contrastive_learning/cl_api.py | 5 +- bdikit/utils.py | 58 +++++++------------- 2 files changed, 23 insertions(+), 40 deletions(-) diff --git a/bdikit/models/contrastive_learning/cl_api.py b/bdikit/models/contrastive_learning/cl_api.py index d37ee23..0f053f7 100644 --- a/bdikit/models/contrastive_learning/cl_api.py +++ b/bdikit/models/contrastive_learning/cl_api.py @@ -1,4 +1,3 @@ -import os from typing import List, Dict, Tuple, Optional from bdikit.config import get_device import numpy as np @@ -13,7 +12,7 @@ from sklearn.metrics.pairwise import cosine_similarity from tqdm.auto import tqdm from bdikit.download import get_cached_model_or_download -from bdikit.utils import check_gdc_cache, write_embeddings_to_cache +from bdikit.utils import check_embedding_cache, write_embeddings_to_cache from bdikit.models import ColumnEmbedder @@ -108,7 +107,7 @@ def _sample_to_15_rows(self, table: pd.DataFrame): def _load_table_tokens(self, table: pd.DataFrame) -> List[np.ndarray]: - embedding_file, embeddings = check_gdc_cache(table, self.model_path) + embedding_file, embeddings = check_embedding_cache(table, self.model_path) if embeddings != None: print(f"Table features loaded for {len(table.columns)} columns") diff --git a/bdikit/utils.py b/bdikit/utils.py index 91b644d..bcdfd77 100644 --- a/bdikit/utils.py +++ b/bdikit/utils.py @@ -1,18 +1,11 @@ import os import hashlib import pandas as pd -from os.path import join, dirname +from os.path import join, dirname, isfile from bdikit.download import BDIKIT_EMBEDDINGS_CACHE_DIR -from bdikit.standards.standard_factory import Standards - -GDC_TABLE_PATH = join(dirname(__file__), "./resource/gdc_table.csv") - -__gdc_df = None -__gdc_hash = None def hash_dataframe(df: pd.DataFrame) -> str: - hash_object = hashlib.sha256() columns_string = ",".join(df.columns) + "\n" @@ -27,50 +20,41 @@ def hash_dataframe(df: pd.DataFrame) -> str: def write_embeddings_to_cache(embedding_file: str, embeddings: list): - os.makedirs(os.path.dirname(embedding_file), exist_ok=True) + os.makedirs(dirname(embedding_file), exist_ok=True) with open(embedding_file, "w") as file: for vec in embeddings: file.write(",".join([str(val) for val in vec]) + "\n") -def load_gdc_data(): - global __gdc_df, __gdc_hash - if __gdc_df is None or __gdc_hash is None: - standard = Standards.get_standard("gdc") - __gdc_df = standard.get_dataframe_rep() - __gdc_hash = hash_dataframe(__gdc_df) - - -def check_gdc_cache(table: pd.DataFrame, model_path: str): - global __gdc_df, __gdc_hash - load_gdc_data() - +def check_embedding_cache(table: pd.DataFrame, model_path: str): + embedding_file = None + embeddings = None table_hash = hash_dataframe(table) + model_name = model_path.split("/")[-1] + cache_model_path = join(BDIKIT_EMBEDDINGS_CACHE_DIR, model_name) + os.makedirs(cache_model_path, exist_ok=True) - df_hash_file = None - features = None + hash_list = { + f for f in os.listdir(cache_model_path) if isfile(join(cache_model_path, f)) + } - # check if table for computing embedding is the same as the GDC table we have in resources - if table_hash == __gdc_hash: - model_name = model_path.split("/")[-1] - cache_model_path = join(BDIKIT_EMBEDDINGS_CACHE_DIR, model_name) - df_hash_file = join(cache_model_path, __gdc_hash) + embedding_file = join(cache_model_path, table_hash) - # Found file in cache - if os.path.isfile(df_hash_file): + # Check if table for computing embedding is the same as the tables we have in resources + if table_hash in hash_list: + if isfile(embedding_file): try: # Load embeddings from disk - with open(df_hash_file, "r") as file: - features = [ + with open(embedding_file, "r") as file: + embeddings = [ [float(val) for val in vec.split(",")] for vec in file.read().split("\n") if vec.strip() ] - if len(features) != len(__gdc_df.columns): - features = None - raise ValueError("Mismatch in the number of features") + except Exception as e: print(f"Error loading features from cache: {e}") - features = None - return df_hash_file, features + embeddings = None + + return embedding_file, embeddings