-
Notifications
You must be signed in to change notification settings - Fork 0
/
run_predict_all_molecules.py
49 lines (43 loc) · 1.45 KB
/
run_predict_all_molecules.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
from ensmallen import HyperSketchingPy
from src.models import XGBoost
import pandas as pd
from grape import Graph
from tqdm import tqdm
from src.predict import predict_all_molecules_for_one_species
# Load the previously trained XGBoost wrapper model from disk.
# NOTE(review): the ".pkl" extension suggests a pickle file; confirm how
# XGBoost.load_model deserialises it and only load files from a trusted source.
model = XGBoost.load_model("xgboost_model.pkl")

# Build the "LOTUS_with_NCBITaxon" graph from tab-separated node and edge
# lists. Both files carry a header row and are read without parallel loading.
graph = Graph.from_csv(
    name="LOTUS_with_NCBITaxon",
    directed=False,  # undirected; the directed=True alternative is disabled
    # Node list: column 0 holds the node name, column 1 the node type.
    node_path="./data/lotus_with_ncbi_clean_nodes.csv",
    node_list_separator="\t",
    node_list_header=True,
    nodes_column_number=0,
    node_list_node_types_column_number=1,
    load_node_list_in_parallel=False,
    # Edge list: columns 0 and 1 hold source and destination, column 2 the
    # edge type.
    edge_path="./data/lotus_with_ncbi_clean_edges.csv",
    edge_list_separator="\t",
    edge_list_header=True,
    sources_column_number=0,
    destinations_column_number=1,
    edge_list_edge_types_column_number=2,
    load_edge_list_in_parallel=False,
)

# Report the hash of the graph that was just loaded.
print("The graph hash is : ", graph.hash())
# Load the frozen LOTUS metadata table. low_memory=False turns off pandas'
# chunked parsing, so column dtypes are inferred from the whole file at once.
lotus = pd.read_csv(
    "data/molecules/230106_frozen_metadata.csv.gz",
    low_memory=False,
)

# Add "wd:"-prefixed Wikidata identifier columns by pulling the first "Q<digits>"
# token out of the organism and structure Wikidata fields.
# expand=False makes str.extract return a Series for the single capture group
# (the default returns a one-column DataFrame), so each new column is assigned
# from a plain Series instead of relying on DataFrame-to-column coercion.
# Rows without a Q-id stay NaN.
lotus["wd_species"] = "wd:" + lotus["organism_wikidata"].str.extract(
    r"(Q\d+)", expand=False
)
lotus["wd_molecule"] = "wd:" + lotus["structure_wikidata"].str.extract(
    r"(Q\d+)", expand=False
)
# Set up the HyperSketching feature extractor on the loaded graph (2 hops,
# normalisation off) and fit it before it is handed to the prediction routine.
sketching_features = HyperSketchingPy(
    graph=graph,
    hops=2,
    normalize=False,
)
sketching_features.fit()
# "wd:"-prefixed identifiers of the species to run predictions for.
target_species = ["wd:Q311176", "wd:Q15550965", "wd:Q21319402"]

# Run the molecule prediction once per species, wrapped in a tqdm progress bar.
for species in tqdm(target_species):
    predict_all_molecules_for_one_species(
        species=species,
        graph=graph,
        lotus=lotus,
        model=model,
        sketching_features=sketching_features,
    )