From aefa67b3b1125d4246053b564fab9481d1562071 Mon Sep 17 00:00:00 2001 From: Umberto Lupo Date: Wed, 24 Jan 2024 15:26:30 +0100 Subject: [PATCH] Improve compute_num_correct_matchings --- diffpass/data_utils.py | 22 ++++++++++++++-------- diffpass/train.py | 26 +++++++++++++++----------- nbs/data_utils.ipynb | 22 ++++++++++++++-------- nbs/train.ipynb | 26 +++++++++++++++----------- 4 files changed, 58 insertions(+), 38 deletions(-) diff --git a/diffpass/data_utils.py b/diffpass/data_utils.py index 1dadbb8..fd97498 100644 --- a/diffpass/data_utils.py +++ b/diffpass/data_utils.py @@ -208,27 +208,33 @@ def get_single_and_paired_seqs( """Single and paired sequences from two MSAs. The paired sequences are returned as a list of dictionaries, where the keys are the concatenated sequences and the values are the number of times that pair appears in the concatenated MSA.""" - x_seqs = [] - y_seqs = [] + x_seqs_by_group = [] + y_seqs_by_group = [] idx = 0 - xy_seqs_to_counts = [] + xy_seqs_to_counts_by_group = [] for s in group_sizes: x_seqs_this_group = list(zip(*msa_x[idx : s + idx]))[1] - x_seqs.append(x_seqs_this_group) + x_seqs_by_group.append(x_seqs_this_group) y_seqs_this_group = list(zip(*msa_y[idx : s + idx]))[1] - y_seqs.append(y_seqs_this_group) + y_seqs_by_group.append(y_seqs_this_group) xy_seqs_this_group = [ f"{x_seq}:{y_seq}" for x_seq, y_seq in zip(x_seqs_this_group, y_seqs_this_group) ] - unique_xy, counts_xy = np.unique( + unique_xy_this_group, counts_xy_this_group = np.unique( np.array(xy_seqs_this_group), return_counts=True ) - xy_seqs_to_counts.append(dict(zip(unique_xy, counts_xy))) + xy_seqs_to_counts_by_group.append( + dict(zip(unique_xy_this_group, counts_xy_this_group)) + ) idx += s - return {"x_seqs": x_seqs, "y_seqs": y_seqs, "xy_seqs_to_counts": xy_seqs_to_counts} + return { + "x_seqs_by_group": x_seqs_by_group, + "y_seqs_by_group": y_seqs_by_group, + "xy_seqs_to_counts_by_group": xy_seqs_to_counts_by_group, + } def msa_tokenizer( diff --git a/diffpass/train.py b/diffpass/train.py index 0bda174..bc2205a 100644 --- a/diffpass/train.py +++ b/diffpass/train.py @@ -1110,7 +1110,7 @@ def fit( # %% ../nbs/train.ipynb 9 def compute_num_correct_matchings( - results: DiffPASSResults, + hard_perms: list[list[np.ndarray]], *, index_based: bool, single_and_paired_seqs: Optional[dict[str, list]] = None, @@ -1124,7 +1124,7 @@ def compute_num_correct_matchings( """ correct = [] if index_based: - for perms in results.hard_perms: + for perms in hard_perms: correct_this_step = 0 for perm_this_group in perms: n_seqs_this_group = perm_this_group.shape[-1] @@ -1134,25 +1134,29 @@ def compute_num_correct_matchings( correct_this_step += correct_this_group correct.append(correct_this_step) else: - x_seqs = single_and_paired_seqs["x_seqs"] - y_seqs = single_and_paired_seqs["y_seqs"] - xy_seqs_to_counts = single_and_paired_seqs["xy_seqs_to_counts"] - for perms in results.hard_perms: + x_seqs_by_group = single_and_paired_seqs["x_seqs_by_group"] + y_seqs_by_group = single_and_paired_seqs["y_seqs_by_group"] + xy_seqs_to_counts_by_group = single_and_paired_seqs[ + "xy_seqs_to_counts_by_group" + ] + for perms in hard_perms: correct_this_perm = 0 for ( perm_this_group, x_seqs_this_group, y_seqs_this_group, - xy_seqs_this_group, - ) in zip(perms, x_seqs, y_seqs, xy_seqs_to_counts): - _xy_seqs_this_group = xy_seqs_this_group.copy() + xy_seqs_to_counts_this_group, + ) in zip( + perms, x_seqs_by_group, y_seqs_by_group, xy_seqs_to_counts_by_group + ): + _xy_seqs_to_counts_this_group = xy_seqs_to_counts_this_group.copy() x_seqs_this_group_perm = [ x_seqs_this_group[idx] for idx in perm_this_group ] for x_seq, y_seq in zip(x_seqs_this_group_perm, y_seqs_this_group): xy_key = f"{x_seq}:{y_seq}" - if _xy_seqs_this_group.get(xy_key, 0) > 0: - _xy_seqs_this_group[xy_key] -= 1 + if _xy_seqs_to_counts_this_group.get(xy_key, 0) > 0: + _xy_seqs_to_counts_this_group[xy_key] -= 1 correct_this_perm += 1 correct.append(correct_this_perm) diff --git a/nbs/data_utils.ipynb b/nbs/data_utils.ipynb index ccf4d87..1285a09 100644 --- a/nbs/data_utils.ipynb +++ b/nbs/data_utils.ipynb @@ -232,27 +232,33 @@ " \"\"\"Single and paired sequences from two MSAs. The paired sequences are returned as a list of\n", " dictionaries, where the keys are the concatenated sequences and the values are the number of\n", " times that pair appears in the concatenated MSA.\"\"\"\n", - " x_seqs = []\n", - " y_seqs = []\n", + " x_seqs_by_group = []\n", + " y_seqs_by_group = []\n", "\n", " idx = 0\n", - " xy_seqs_to_counts = []\n", + " xy_seqs_to_counts_by_group = []\n", " for s in group_sizes:\n", " x_seqs_this_group = list(zip(*msa_x[idx : s + idx]))[1]\n", - " x_seqs.append(x_seqs_this_group)\n", + " x_seqs_by_group.append(x_seqs_this_group)\n", " y_seqs_this_group = list(zip(*msa_y[idx : s + idx]))[1]\n", - " y_seqs.append(y_seqs_this_group)\n", + " y_seqs_by_group.append(y_seqs_this_group)\n", " xy_seqs_this_group = [\n", " f\"{x_seq}:{y_seq}\"\n", " for x_seq, y_seq in zip(x_seqs_this_group, y_seqs_this_group)\n", " ]\n", - " unique_xy, counts_xy = np.unique(\n", + " unique_xy_this_group, counts_xy_this_group = np.unique(\n", " np.array(xy_seqs_this_group), return_counts=True\n", " )\n", - " xy_seqs_to_counts.append(dict(zip(unique_xy, counts_xy)))\n", + " xy_seqs_to_counts_by_group.append(\n", + " dict(zip(unique_xy_this_group, counts_xy_this_group))\n", + " )\n", " idx += s\n", "\n", - " return {\"x_seqs\": x_seqs, \"y_seqs\": y_seqs, \"xy_seqs_to_counts\": xy_seqs_to_counts}\n", + " return {\n", + " \"x_seqs_by_group\": x_seqs_by_group,\n", + " \"y_seqs_by_group\": y_seqs_by_group,\n", + " \"xy_seqs_to_counts_by_group\": xy_seqs_to_counts_by_group,\n", + " }\n", "\n", "\n", "def msa_tokenizer(\n", diff --git a/nbs/train.ipynb b/nbs/train.ipynb index 6c100bb..d5f8d6d 100644 --- a/nbs/train.ipynb +++ b/nbs/train.ipynb @@ -1206,7 +1206,7 @@ "#| export\n", "\n", "def compute_num_correct_matchings(\n", - " results: DiffPASSResults,\n", + " hard_perms: list[list[np.ndarray]],\n", " *,\n", " index_based: bool,\n", " single_and_paired_seqs: Optional[dict[str, list]] = None,\n", @@ -1220,7 +1220,7 @@ " \"\"\"\n", " correct = []\n", " if index_based:\n", - " for perms in results.hard_perms:\n", + " for perms in hard_perms:\n", " correct_this_step = 0\n", " for perm_this_group in perms:\n", " n_seqs_this_group = perm_this_group.shape[-1]\n", @@ -1230,25 +1230,29 @@ " correct_this_step += correct_this_group\n", " correct.append(correct_this_step)\n", " else:\n", - " x_seqs = single_and_paired_seqs[\"x_seqs\"]\n", - " y_seqs = single_and_paired_seqs[\"y_seqs\"]\n", - " xy_seqs_to_counts = single_and_paired_seqs[\"xy_seqs_to_counts\"]\n", - " for perms in results.hard_perms:\n", + " x_seqs_by_group = single_and_paired_seqs[\"x_seqs_by_group\"]\n", + " y_seqs_by_group = single_and_paired_seqs[\"y_seqs_by_group\"]\n", + " xy_seqs_to_counts_by_group = single_and_paired_seqs[\n", + " \"xy_seqs_to_counts_by_group\"\n", + " ]\n", + " for perms in hard_perms:\n", " correct_this_perm = 0\n", " for (\n", " perm_this_group,\n", " x_seqs_this_group,\n", " y_seqs_this_group,\n", - " xy_seqs_this_group,\n", - " ) in zip(perms, x_seqs, y_seqs, xy_seqs_to_counts):\n", - " _xy_seqs_this_group = xy_seqs_this_group.copy()\n", + " xy_seqs_to_counts_this_group,\n", + " ) in zip(\n", + " perms, x_seqs_by_group, y_seqs_by_group, xy_seqs_to_counts_by_group\n", + " ):\n", + " _xy_seqs_to_counts_this_group = xy_seqs_to_counts_this_group.copy()\n", " x_seqs_this_group_perm = [\n", " x_seqs_this_group[idx] for idx in perm_this_group\n", " ]\n", " for x_seq, y_seq in zip(x_seqs_this_group_perm, y_seqs_this_group):\n", " xy_key = f\"{x_seq}:{y_seq}\"\n", - " if _xy_seqs_this_group.get(xy_key, 0) > 0:\n", - " _xy_seqs_this_group[xy_key] -= 1\n", + " if _xy_seqs_to_counts_this_group.get(xy_key, 0) > 0:\n", + " _xy_seqs_to_counts_this_group[xy_key] -= 1\n", " correct_this_perm += 1\n", "\n", " correct.append(correct_this_perm)\n",