Skip to content

Commit

Permalink
Improve compute_num_correct_matchings
Browse files Browse the repository at this point in the history
  • Loading branch information
ulupo committed Jan 24, 2024
1 parent 37fc352 commit aefa67b
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 38 deletions.
22 changes: 14 additions & 8 deletions diffpass/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,27 +208,33 @@ def get_single_and_paired_seqs(
"""Single and paired sequences from two MSAs. The paired sequences are returned as a list of
dictionaries, where the keys are the concatenated sequences and the values are the number of
times that pair appears in the concatenated MSA."""
x_seqs = []
y_seqs = []
x_seqs_by_group = []
y_seqs_by_group = []

idx = 0
xy_seqs_to_counts = []
xy_seqs_to_counts_by_group = []
for s in group_sizes:
x_seqs_this_group = list(zip(*msa_x[idx : s + idx]))[1]
x_seqs.append(x_seqs_this_group)
x_seqs_by_group.append(x_seqs_this_group)
y_seqs_this_group = list(zip(*msa_y[idx : s + idx]))[1]
y_seqs.append(y_seqs_this_group)
y_seqs_by_group.append(y_seqs_this_group)
xy_seqs_this_group = [
f"{x_seq}:{y_seq}"
for x_seq, y_seq in zip(x_seqs_this_group, y_seqs_this_group)
]
unique_xy, counts_xy = np.unique(
unique_xy_this_group, counts_xy_this_group = np.unique(
np.array(xy_seqs_this_group), return_counts=True
)
xy_seqs_to_counts.append(dict(zip(unique_xy, counts_xy)))
xy_seqs_to_counts_by_group.append(
dict(zip(unique_xy_this_group, counts_xy_this_group))
)
idx += s

return {"x_seqs": x_seqs, "y_seqs": y_seqs, "xy_seqs_to_counts": xy_seqs_to_counts}
return {
"x_seqs_by_group": x_seqs_by_group,
"y_seqs_by_group": y_seqs_by_group,
"xy_seqs_to_counts_by_group": xy_seqs_to_counts_by_group,
}


def msa_tokenizer(
Expand Down
26 changes: 15 additions & 11 deletions diffpass/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -1110,7 +1110,7 @@ def fit(

# %% ../nbs/train.ipynb 9
def compute_num_correct_matchings(
results: DiffPASSResults,
hard_perms: list[list[np.ndarray]],
*,
index_based: bool,
single_and_paired_seqs: Optional[dict[str, list]] = None,
Expand All @@ -1124,7 +1124,7 @@ def compute_num_correct_matchings(
"""
correct = []
if index_based:
for perms in results.hard_perms:
for perms in hard_perms:
correct_this_step = 0
for perm_this_group in perms:
n_seqs_this_group = perm_this_group.shape[-1]
Expand All @@ -1134,25 +1134,29 @@ def compute_num_correct_matchings(
correct_this_step += correct_this_group
correct.append(correct_this_step)
else:
x_seqs = single_and_paired_seqs["x_seqs"]
y_seqs = single_and_paired_seqs["y_seqs"]
xy_seqs_to_counts = single_and_paired_seqs["xy_seqs_to_counts"]
for perms in results.hard_perms:
x_seqs_by_group = single_and_paired_seqs["x_seqs_by_group"]
y_seqs_by_group = single_and_paired_seqs["y_seqs_by_group"]
xy_seqs_to_counts_by_group = single_and_paired_seqs[
"xy_seqs_to_counts_by_group"
]
for perms in hard_perms:
correct_this_perm = 0
for (
perm_this_group,
x_seqs_this_group,
y_seqs_this_group,
xy_seqs_this_group,
) in zip(perms, x_seqs, y_seqs, xy_seqs_to_counts):
_xy_seqs_this_group = xy_seqs_this_group.copy()
xy_seqs_to_counts_this_group,
) in zip(
perms, x_seqs_by_group, y_seqs_by_group, xy_seqs_to_counts_by_group
):
_xy_seqs_to_counts_this_group = xy_seqs_to_counts_this_group.copy()
x_seqs_this_group_perm = [
x_seqs_this_group[idx] for idx in perm_this_group
]
for x_seq, y_seq in zip(x_seqs_this_group_perm, y_seqs_this_group):
xy_key = f"{x_seq}:{y_seq}"
if _xy_seqs_this_group.get(xy_key, 0) > 0:
_xy_seqs_this_group[xy_key] -= 1
if _xy_seqs_to_counts_this_group.get(xy_key, 0) > 0:
_xy_seqs_to_counts_this_group[xy_key] -= 1
correct_this_perm += 1

correct.append(correct_this_perm)
Expand Down
22 changes: 14 additions & 8 deletions nbs/data_utils.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -232,27 +232,33 @@
" \"\"\"Single and paired sequences from two MSAs. The paired sequences are returned as a list of\n",
" dictionaries, where the keys are the concatenated sequences and the values are the number of\n",
" times that pair appears in the concatenated MSA.\"\"\"\n",
" x_seqs = []\n",
" y_seqs = []\n",
" x_seqs_by_group = []\n",
" y_seqs_by_group = []\n",
"\n",
" idx = 0\n",
" xy_seqs_to_counts = []\n",
" xy_seqs_to_counts_by_group = []\n",
" for s in group_sizes:\n",
" x_seqs_this_group = list(zip(*msa_x[idx : s + idx]))[1]\n",
" x_seqs.append(x_seqs_this_group)\n",
" x_seqs_by_group.append(x_seqs_this_group)\n",
" y_seqs_this_group = list(zip(*msa_y[idx : s + idx]))[1]\n",
" y_seqs.append(y_seqs_this_group)\n",
" y_seqs_by_group.append(y_seqs_this_group)\n",
" xy_seqs_this_group = [\n",
" f\"{x_seq}:{y_seq}\"\n",
" for x_seq, y_seq in zip(x_seqs_this_group, y_seqs_this_group)\n",
" ]\n",
" unique_xy, counts_xy = np.unique(\n",
" unique_xy_this_group, counts_xy_this_group = np.unique(\n",
" np.array(xy_seqs_this_group), return_counts=True\n",
" )\n",
" xy_seqs_to_counts.append(dict(zip(unique_xy, counts_xy)))\n",
" xy_seqs_to_counts_by_group.append(\n",
" dict(zip(unique_xy_this_group, counts_xy_this_group))\n",
" )\n",
" idx += s\n",
"\n",
" return {\"x_seqs\": x_seqs, \"y_seqs\": y_seqs, \"xy_seqs_to_counts\": xy_seqs_to_counts}\n",
" return {\n",
" \"x_seqs_by_group\": x_seqs_by_group,\n",
" \"y_seqs_by_group\": y_seqs_by_group,\n",
" \"xy_seqs_to_counts_by_group\": xy_seqs_to_counts_by_group,\n",
" }\n",
"\n",
"\n",
"def msa_tokenizer(\n",
Expand Down
26 changes: 15 additions & 11 deletions nbs/train.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -1206,7 +1206,7 @@
"#| export\n",
"\n",
"def compute_num_correct_matchings(\n",
" results: DiffPASSResults,\n",
" hard_perms: list[list[np.ndarray]],\n",
" *,\n",
" index_based: bool,\n",
" single_and_paired_seqs: Optional[dict[str, list]] = None,\n",
Expand All @@ -1220,7 +1220,7 @@
" \"\"\"\n",
" correct = []\n",
" if index_based:\n",
" for perms in results.hard_perms:\n",
" for perms in hard_perms:\n",
" correct_this_step = 0\n",
" for perm_this_group in perms:\n",
" n_seqs_this_group = perm_this_group.shape[-1]\n",
Expand All @@ -1230,25 +1230,29 @@
" correct_this_step += correct_this_group\n",
" correct.append(correct_this_step)\n",
" else:\n",
" x_seqs = single_and_paired_seqs[\"x_seqs\"]\n",
" y_seqs = single_and_paired_seqs[\"y_seqs\"]\n",
" xy_seqs_to_counts = single_and_paired_seqs[\"xy_seqs_to_counts\"]\n",
" for perms in results.hard_perms:\n",
" x_seqs_by_group = single_and_paired_seqs[\"x_seqs_by_group\"]\n",
" y_seqs_by_group = single_and_paired_seqs[\"y_seqs_by_group\"]\n",
" xy_seqs_to_counts_by_group = single_and_paired_seqs[\n",
" \"xy_seqs_to_counts_by_group\"\n",
" ]\n",
" for perms in hard_perms:\n",
" correct_this_perm = 0\n",
" for (\n",
" perm_this_group,\n",
" x_seqs_this_group,\n",
" y_seqs_this_group,\n",
" xy_seqs_this_group,\n",
" ) in zip(perms, x_seqs, y_seqs, xy_seqs_to_counts):\n",
" _xy_seqs_this_group = xy_seqs_this_group.copy()\n",
" xy_seqs_to_counts_this_group,\n",
" ) in zip(\n",
" perms, x_seqs_by_group, y_seqs_by_group, xy_seqs_to_counts_by_group\n",
" ):\n",
" _xy_seqs_to_counts_this_group = xy_seqs_to_counts_this_group.copy()\n",
" x_seqs_this_group_perm = [\n",
" x_seqs_this_group[idx] for idx in perm_this_group\n",
" ]\n",
" for x_seq, y_seq in zip(x_seqs_this_group_perm, y_seqs_this_group):\n",
" xy_key = f\"{x_seq}:{y_seq}\"\n",
" if _xy_seqs_this_group.get(xy_key, 0) > 0:\n",
" _xy_seqs_this_group[xy_key] -= 1\n",
" if _xy_seqs_to_counts_this_group.get(xy_key, 0) > 0:\n",
" _xy_seqs_to_counts_this_group[xy_key] -= 1\n",
" correct_this_perm += 1\n",
"\n",
" correct.append(correct_this_perm)\n",
Expand Down

0 comments on commit aefa67b

Please sign in to comment.