diff --git a/.editorconfig b/.editorconfig new file mode 100644 index 0000000..efb1e94 --- /dev/null +++ b/.editorconfig @@ -0,0 +1,12 @@ +# top-most EditorConfig file +root = true + +# Unix-style newlines with a newline ending every file +[*] +end_of_line = lf +insert_final_newline = true + +[*.py] +charset = utf-8 +indent_style = space +indent_size = 4 diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..bb313d0 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,34 @@ +name: CI for GEB + +on: + push: + branches: ["**"] + pull_request: + branches: ["**"] + +permissions: + id-token: write + contents: read + actions: write + pull-requests: read + +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} + cancel-in-progress: true + +jobs: + ruff: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - uses: actions/setup-python@v4 + with: + python-version: "3.11" + - uses: yezz123/setup-uv@v4 + with: + uv-venv: ".geb_venv" + - run: uv pip install ruff + - run: ruff format . + - run: ruff check . + # TODO: pytest + # TODO: pyright diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml new file mode 100644 index 0000000..f349ef9 --- /dev/null +++ b/.github/workflows/release.yml @@ -0,0 +1,51 @@ +# This workflow will +# - Find the latest version tag based on the commit history +# - Create a git tag for the new version +# - Update the version number in pyproject.toml based on the commit history +# - Upload the package to PyPI +# - Create a release on GitHub + +# This workflow required the following secrets to be set: +# - a GitHub personal access token with the `repo` scope called `RELEASE` +# - and that you setup trusted publishing using PyPI as described here: https://blog.pypi.org/posts/2023-04-20-introducing-trusted-publishers/ + +name: Release +on: + push: + branches: + - main + +jobs: + release: + runs-on: ubuntu-latest + concurrency: release + permissions: + id-token: write # IMPORTANT: this permission is mandatory for trusted publishing using PyPI + contents: write + + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - name: Python Semantic Release + id: release + uses: python-semantic-release/python-semantic-release@v9.8.3 + with: + github_token: ${{ secrets.GITHUB_TOKEN }} + + - name: Publish package distributions to PyPI + uses: pypa/gh-action-pypi-publish@v1.9.0 + if: steps.release.outputs.released == 'true' + with: + repository-url: https://test.pypi.org/legacy/ + # This action supports PyPI's trusted publishing implementation, which allows authentication to PyPI without a manually + # configured API token or username/password combination. To perform trusted publishing with this action, your project's + # publisher must already be configured on PyPI. 
+ + - name: Publish package distributions to GitHub Releases + uses: python-semantic-release/upload-to-gh-release@v9.8.3 + if: steps.release.outputs.released == 'true' + with: + github_token: ${{ secrets.GITHUB_TOKEN }} + tag: ${{ steps.release.outputs.tag }} diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..bb077e1 --- /dev/null +++ b/.gitignore @@ -0,0 +1,6 @@ +.venv/ +__pycache__/ +.vscode/ +build/ +dist/ +*egg-info/ diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..f49a4e1 --- /dev/null +++ b/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. 
For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. 
The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. 
+ + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. \ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 0000000..a77f005 --- /dev/null +++ b/README.md @@ -0,0 +1,120 @@ +

# Genomic Embedding Benchmark

<!-- Badges: GitHub release · arXiv · License · Downloads -->

[Installation](#installation) | [Usage](#usage) | [Leaderboard](https://huggingface.co/spaces/tattabio/GEB/leaderboard) | Documentation | [Citing](#citing)

## Installation

TODO(joshua):
```bash
pip install geb
```

## Usage

- Using the Python script (see [run_geb.py](https://github.com/tattabio/geb/blob/main/run_geb.py)):

```bash
python run_geb.py --model facebook/esm2_t6_8M_UR50D
```

- Using the Python API:

```py
import geb

model = geb.get_model("facebook/esm2_t6_8M_UR50D")
tasks = geb.get_tasks_by_modality(geb.Modality.PROTEIN)
evaluation = geb.GEB(tasks=tasks)
evaluation.run(model, output_folder="results")
```

### Using a custom model

Custom models should subclass the `geb.models.BioSeqTransformer` abstract class and specify the modality, number of layers, and embedding dimension. See [models.py](https://github.com/tattabio/geb/blob/main/geb/models.py) for additional examples of custom model loading and inference.

```python
import geb
from geb.models import BioSeqTransformer
from geb.modality import Modality

class MyModel(BioSeqTransformer):

    @property
    def modality(self) -> Modality:
        return Modality.PROTEIN

    @property
    def num_layers(self) -> int:
        return self.config.num_hidden_layers

    @property
    def embed_dim(self) -> int:
        return self.config.hidden_size


model = MyModel("facebook/esm2_t6_8M_UR50D")
tasks = geb.get_tasks_by_modality(model.modality)
evaluation = geb.GEB(tasks=tasks)
evaluation.run(model)
```

### Evaluating on a custom dataset

TODO(andre): Update this section

To evaluate on a custom dataset, define a task by subclassing `AbsTask` and pass it to `geb.GEB`:

```python
from typing import List, Optional

import geb
from geb.models import BioSeqTransformer
from geb.tasks import AbsTask
from geb.tasks.tasks import TaskResult

class MyCustomTask(AbsTask):
    def run(
        self, model: BioSeqTransformer, layers: Optional[List[int]] = None
    ) -> TaskResult:
        pass

model = geb.models.ESM("facebook/esm2_t6_8M_UR50D")
evaluation = geb.GEB(tasks=[MyCustomTask()])
evaluation.run(model)
```

## Citing

GEB was introduced in "[GEB: Genomic Embedding Benchmark]()". If you use GEB in your work, please cite:

TODO(andre): bibtex

Works that use GEB for benchmarking are listed on the [leaderboard](https://huggingface.co/spaces/tattabio/GEB/leaderboard).
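As a quick reference, the snippet below sketches how the task- and model-discovery helpers exported from `geb/__init__.py` (added later in this diff) compose with the API shown above; the commented-out task id is a placeholder, not a real task name.

```python
import geb

# Discover what is registered.
print(geb.get_all_task_names())   # ids of all available tasks
print(geb.get_all_model_names())  # names accepted by geb.get_model()

# Load a model wrapper and keep only the tasks that match its modality.
model = geb.get_model("facebook/esm2_t6_8M_UR50D")
tasks = geb.get_tasks_by_modality(model.modality)

# Or select tasks explicitly by id:
# tasks = geb.get_tasks_by_name(["<task_id>"])  # placeholder id

results = geb.GEB(tasks=tasks).run(model, output_folder="results")
```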
diff --git a/benchmarks.png b/benchmarks.png new file mode 100644 index 0000000..e7c6c76 Binary files /dev/null and b/benchmarks.png differ diff --git a/docs/images/tatta_logo.png b/docs/images/tatta_logo.png new file mode 100644 index 0000000..76220bd Binary files /dev/null and b/docs/images/tatta_logo.png differ diff --git a/geb/__init__.py b/geb/__init__.py new file mode 100644 index 0000000..a3d4a78 --- /dev/null +++ b/geb/__init__.py @@ -0,0 +1,28 @@ +from geb.geb import ( + GEB, + get_all_tasks, + get_output_folder, + get_all_task_names, + get_tasks_by_name, + get_tasks_by_modality, + get_all_model_names, + get_model, +) +from geb.tasks.tasks import TaskResult +from geb.modality import Modality + +# importing without setting `__all__` produces a Ruff error: +# "imported but unused; consider removing, adding to __all__, or using a redundant alias RuffF401" +# See https://docs.astral.sh/ruff/rules/unused-import/#why-is-this-bad +__all__ = [ + "GEB", + "get_all_tasks", + "get_all_task_names", + "get_tasks_by_name", + "get_tasks_by_modality", + "get_all_model_names", + "get_model", + "get_output_folder", + "TaskResult", + "Modality", +] diff --git a/geb/eval_utils.py b/geb/eval_utils.py new file mode 100644 index 0000000..7b5f630 --- /dev/null +++ b/geb/eval_utils.py @@ -0,0 +1,394 @@ +"""Utility functions for evaluation.""" + +from typing import Any, Dict, List, Tuple +import json +import torch +import random +import numpy as np +from sklearn.metrics import auc + + +class ForwardHook: + """Pytorch forward hook class to store outputs of intermediate layers.""" + + def __init__(self, module: torch.nn.Module): + self.hook = module.register_forward_hook(self.hook_fn) + self.output = None + + def hook_fn(self, module, input, output): + self.output = output + + def close(self): + self.hook.remove() + + +def pool( + last_hidden_states: torch.Tensor, attention_mask: torch.Tensor, pool_type: str +) -> torch.Tensor: + """Pool embeddings across the sequence length dimension.""" + assert ( + last_hidden_states.ndim == 3 + ), f"Expected hidden_states to have shape [batch, seq_len, D], got shape: {last_hidden_states.shape}" + assert ( + attention_mask.ndim == 2 + ), f"Expected attention_mask to have shape [batch, seq_len], got shape: {attention_mask.shape}" + last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0) + if pool_type == "mean": + emb = last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None] + elif pool_type == "max": + emb = last_hidden.max(dim=1)[0] + elif pool_type == "cls": + emb = last_hidden[:, 0] + elif pool_type == "last": + emb = last_hidden[torch.arange(last_hidden.size(0)), attention_mask.sum(1) - 1] + else: + raise ValueError(f"pool_type {pool_type} not supported") + return emb + + +def set_all_seeds(seed): + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.backends.cudnn.deterministic = True + + +def write_results_to_json(results: Dict[str, Any], results_path: str): + """Write results dict to a json file.""" + with open(results_path, "w") as f: + json.dump(results, f, indent=4) + + +def merge_split_elem_embeds(ids, embeds, preserve_order: bool = False): + """Merge embeddings with the same id by mean-pooling and optionally preserve order in which they appear. + + Args: + ids: Array of string ids, [batch]. + embeds: Array of embeddings, [batch, ...]. + + Returns: + ids: Unique ids, [unique_batch]. + embeds: Array of embeddings, [unique_batch, ...]. 
+ """ + unique_ids, indices = np.unique(ids, return_inverse=True) + shape_no_batch = embeds.shape[1:] + sums = np.zeros([unique_ids.size, *shape_no_batch], dtype=embeds.dtype) + counts = np.bincount(indices, minlength=unique_ids.size) + np.add.at(sums, indices, embeds) + # Add trailing dimensions to counts. + counts = counts[(...,) + (None,) * len(shape_no_batch)] + mean_pooled = sums / counts + # Preserve the order of the input ids. + if preserve_order: + order = [] + for id in unique_ids: + idx = np.where(ids == id)[0][0] + order.append(idx) + re_order = np.argsort(order) + unique_ids = unique_ids[re_order] + mean_pooled = mean_pooled[re_order] + return unique_ids, mean_pooled + + +def paired_dataset(labels, embeds): + """Creates a paired dataset for consecutive operonic gene pairs.""" + embeds1 = embeds[:-1] + embeds2 = embeds[1:] + labels = labels[:-1] + return embeds1, embeds2, labels + + +def cos_sim(a, b): + """Computes the cosine similarity cos_sim(a[i], b[j]) for all i and j. + + Return: + Matrix with res[i][j] = cos_sim(a[i], b[j]) + """ # noqa: D402 + if not isinstance(a, torch.Tensor): + a = torch.tensor(a) + + if not isinstance(b, torch.Tensor): + b = torch.tensor(b) + + if len(a.shape) == 1: + a = a.unsqueeze(0) + + if len(b.shape) == 1: + b = b.unsqueeze(0) + + a_norm = torch.nn.functional.normalize(a, p=2, dim=1) + b_norm = torch.nn.functional.normalize(b, p=2, dim=1) + return torch.mm(a_norm, b_norm.transpose(0, 1)) + + +def dot_score(a: torch.Tensor, b: torch.Tensor): + """Computes the dot-product dot_prod(a[i], b[j]) for all i and j. + :return: Matrix with res[i][j] = dot_prod(a[i], b[j]) + """ + if not isinstance(a, torch.Tensor): + a = torch.tensor(a) + + if not isinstance(b, torch.Tensor): + b = torch.tensor(b) + + if len(a.shape) == 1: + a = a.unsqueeze(0) + + if len(b.shape) == 1: + b = b.unsqueeze(0) + + return torch.mm(a, b.transpose(0, 1)) + + +# From https://github.com/beir-cellar/beir/blob/f062f038c4bfd19a8ca942a9910b1e0d218759d4/beir/retrieval/custom_metrics.py#L4 +def mrr( + qrels: dict[str, dict[str, int]], + results: dict[str, dict[str, float]], + k_values: List[int], + output_type: str = "mean", +) -> Tuple[Dict[str, float]]: + MRR = {} + + for k in k_values: + MRR[f"MRR@{k}"] = [] + + k_max, top_hits = max(k_values), {} + + for query_id, doc_scores in results.items(): + top_hits[query_id] = sorted( + doc_scores.items(), key=lambda item: item[1], reverse=True + )[0:k_max] + + for query_id in top_hits: + query_relevant_docs = set( + [doc_id for doc_id in qrels[query_id] if qrels[query_id][doc_id] > 0] + ) + for k in k_values: + rr = 0 + for rank, hit in enumerate(top_hits[query_id][0:k]): + if hit[0] in query_relevant_docs: + rr = 1.0 / (rank + 1) + break + MRR[f"MRR@{k}"].append(rr) + + if output_type == "mean": + for k in k_values: + MRR[f"MRR@{k}"] = round(sum(MRR[f"MRR@{k}"]) / len(qrels), 5) + + elif output_type == "all": + pass + + return MRR + + +# From https://github.com/embeddings-benchmark/mteb/blob/8178981fd8fcd546d7031afe61a083d13c41520f/mteb/evaluation/evaluators/utils.py +def recall_cap( + qrels: dict[str, dict[str, int]], + results: dict[str, dict[str, float]], + k_values: List[int], + output_type: str = "mean", +) -> Tuple[Dict[str, float]]: + capped_recall = {} + + for k in k_values: + capped_recall[f"R_cap@{k}"] = [] + + k_max = max(k_values) + + for query_id, doc_scores in results.items(): + top_hits = sorted(doc_scores.items(), key=lambda item: item[1], reverse=True)[ + 0:k_max + ] + query_relevant_docs = [ + doc_id for doc_id in 
qrels[query_id] if qrels[query_id][doc_id] > 0 + ] + for k in k_values: + retrieved_docs = [ + row[0] for row in top_hits[0:k] if qrels[query_id].get(row[0], 0) > 0 + ] + denominator = min(len(query_relevant_docs), k) + capped_recall[f"R_cap@{k}"].append(len(retrieved_docs) / denominator) + + if output_type == "mean": + for k in k_values: + capped_recall[f"R_cap@{k}"] = round( + sum(capped_recall[f"R_cap@{k}"]) / len(qrels), 5 + ) + + elif output_type == "all": + pass + + return capped_recall + + +# From https://github.com/embeddings-benchmark/mteb/blob/8178981fd8fcd546d7031afe61a083d13c41520f/mteb/evaluation/evaluators/utils.py +def hole( + qrels: dict[str, dict[str, int]], + results: dict[str, dict[str, float]], + k_values: List[int], + output_type: str = "mean", +) -> Tuple[Dict[str, float]]: + Hole = {} + + for k in k_values: + Hole[f"Hole@{k}"] = [] + + annotated_corpus = set() + for _, docs in qrels.items(): + for doc_id, score in docs.items(): + annotated_corpus.add(doc_id) + + k_max = max(k_values) + + for _, scores in results.items(): + top_hits = sorted(scores.items(), key=lambda item: item[1], reverse=True)[ + 0:k_max + ] + for k in k_values: + hole_docs = [ + row[0] for row in top_hits[0:k] if row[0] not in annotated_corpus + ] + Hole[f"Hole@{k}"].append(len(hole_docs) / k) + + if output_type == "mean": + for k in k_values: + Hole[f"Hole@{k}"] = round(Hole[f"Hole@{k}"] / len(qrels), 5) + + elif output_type == "all": + pass + + return Hole + + +# From https://github.com/embeddings-benchmark/mteb/blob/8178981fd8fcd546d7031afe61a083d13c41520f/mteb/evaluation/evaluators/utils.py +def top_k_accuracy( + qrels: dict[str, dict[str, int]], + results: dict[str, dict[str, float]], + k_values: List[int], + output_type: str = "mean", +) -> Tuple[Dict[str, float]]: + top_k_acc = {} + + for k in k_values: + top_k_acc[f"Accuracy@{k}"] = [] + + k_max, top_hits = max(k_values), {} + + for query_id, doc_scores in results.items(): + top_hits[query_id] = [ + item[0] + for item in sorted( + doc_scores.items(), key=lambda item: item[1], reverse=True + )[0:k_max] + ] + + for query_id in top_hits: + query_relevant_docs = set( + [doc_id for doc_id in qrels[query_id] if qrels[query_id][doc_id] > 0] + ) + for k in k_values: + for relevant_doc_id in query_relevant_docs: + if relevant_doc_id in top_hits[query_id][0:k]: + top_k_acc[f"Accuracy@{k}"].append(1.0) + break + + if output_type == "mean": + for k in k_values: + top_k_acc[f"Accuracy@{k}"] = round( + top_k_acc[f"Accuracy@{k}"] / len(qrels), 5 + ) + + elif output_type == "all": + pass + + return top_k_acc + + +# From https://github.com/embeddings-benchmark/mteb/blob/8178981fd8fcd546d7031afe61a083d13c41520f/mteb/evaluation/evaluators/utils.py +def confidence_scores(sim_scores: List[float]) -> Dict[str, float]: + """Computes confidence scores for a single instance = (query, positives, negatives) + + Args: + sim_scores: Query-documents similarity scores with length `num_pos+num_neg` + + Returns: + conf_scores: + - `max`: Maximum similarity score + - `std`: Standard deviation of similarity scores + - `diff1`: Difference between highest and second highest similarity scores + """ + sim_scores_sorted = sorted(sim_scores)[::-1] + + cs_max = sim_scores_sorted[0] + cs_std = np.std(sim_scores) + if len(sim_scores) > 1: + cs_diff1 = sim_scores_sorted[0] - sim_scores_sorted[1] + elif len(sim_scores) == 1: + cs_diff1 = 0.0 + + conf_scores = {"max": cs_max, "std": cs_std, "diff1": cs_diff1} + + return conf_scores + + +# From 
https://github.com/embeddings-benchmark/mteb/blob/8178981fd8fcd546d7031afe61a083d13c41520f/mteb/evaluation/evaluators/utils.py +def nAUC( + conf_scores: np.ndarray, + metrics: np.ndarray, + abstention_rates: np.ndarray = np.linspace(0, 1, 11)[:-1], +) -> float: + """Computes normalized Area Under the Curve on a set of evaluated instances as presented in the paper https://arxiv.org/abs/2402.12997 + 1/ Computes the raw abstention curve, i.e., the average evaluation metric at different abstention rates determined by the confidence scores + 2/ Computes the oracle abstention curve, i.e., the best theoretical abstention curve (e.g.: at a 10% abstention rate, the oracle abstains on the bottom-10% instances with regard to the evaluation metric) + 3/ Computes the flat abstention curve, i.e., the one remains flat for all abstention rates (ineffective abstention) + 4/ Computes the area under the three curves + 5/ Finally scales the raw AUC between the oracle and the flat AUCs to get normalized AUC + + Args: + conf_scores: Instance confidence scores used for abstention thresholding, with shape `(num_test_instances,)` + metrics: Metric evaluations at instance-level (e.g.: average precision, NDCG...), with shape `(num_test_instances,)` + abstention_rates: Target rates for the computation of the abstention curve + + Returns: + abst_nauc: Normalized area under the abstention curve (upper-bounded by 1) + """ + + def abstention_curve( + conf_scores: np.ndarray, + metrics: np.ndarray, + abstention_rates: np.ndarray = np.linspace(0, 1, 11)[:-1], + ) -> np.ndarray: + """Computes the raw abstention curve for a given set of evaluated instances and corresponding confidence scores + + Args: + conf_scores: Instance confidence scores used for abstention thresholding, with shape `(num_test_instances,)` + metrics: Metric evaluations at instance-level (e.g.: average precision, NDCG...), with shape `(num_test_instances,)` + abstention_rates: Target rates for the computation of the abstention curve + + Returns: + abst_curve: Abstention curve of length `len(abstention_rates)` + """ + conf_scores_argsort = np.argsort(conf_scores) + abst_curve = np.zeros(len(abstention_rates)) + + for i, rate in enumerate(abstention_rates): + num_instances_abst = min( + round(rate * len(conf_scores_argsort)), len(conf_scores) - 1 + ) + abst_curve[i] = metrics[conf_scores_argsort[num_instances_abst:]].mean() + + return abst_curve + + abst_curve = abstention_curve(conf_scores, metrics, abstention_rates) + or_curve = abstention_curve(metrics, metrics, abstention_rates) + abst_auc = auc(abstention_rates, abst_curve) + or_auc = auc(abstention_rates, or_curve) + flat_auc = or_curve[0] * (abstention_rates[-1] - abstention_rates[0]) + + if or_auc == flat_auc: + abst_nauc = np.nan + else: + abst_nauc = (abst_auc - flat_auc) / (or_auc - flat_auc) + + return abst_nauc diff --git a/geb/evaluators.py b/geb/evaluators.py new file mode 100644 index 0000000..5098970 --- /dev/null +++ b/geb/evaluators.py @@ -0,0 +1,839 @@ +""" +Evaluator objects for different evaluation types. 
+""" + +import logging +import random +from abc import ABC, abstractmethod +import heapq +from collections import defaultdict +import pytrec_eval +import numpy as np +import sklearn.cluster +import torch +from scipy.stats import pearsonr +from sklearn.linear_model import LogisticRegression +from sklearn.metrics import ( + accuracy_score, + average_precision_score, + classification_report, + f1_score, + precision_score, + recall_score, + label_ranking_average_precision_score, +) +from sklearn.metrics.cluster import v_measure_score +from sklearn.metrics.pairwise import ( + paired_cosine_distances, + paired_euclidean_distances, + paired_manhattan_distances, +) +from sklearn.multioutput import MultiOutputRegressor +from sklearn.preprocessing import MultiLabelBinarizer +from typing import Dict, List, Tuple + +from .eval_utils import ( + cos_sim, + dot_score, + mrr, + recall_cap, + hole, + confidence_scores, + nAUC, + top_k_accuracy, +) + + +class Evaluator(ABC): + """Base class for all evaluators + Extend this class and implement __call__ for custom evaluators. + """ + + def __init__(self, seed=42, **kwargs): + self.seed = seed + random.seed(self.seed) + np.random.seed(self.seed) + torch.manual_seed(self.seed) + torch.cuda.manual_seed_all(self.seed) + + @abstractmethod + def __call__(self, model): + """This is called during training to evaluate the model. + It returns scores. + + Parameters + ---------- + model: + the model to evaluate + """ + pass + + +logger = logging.getLogger(__name__) + + +class logRegClassificationEvaluator(Evaluator): + def __init__( + self, + embeds_train, + y_train, + embeds_test, + y_test, + max_iter=1000, + **kwargs, + ): + super().__init__(**kwargs) + self.embeds_train = embeds_train + self.y_train = y_train + self.embeds_test = embeds_test + self.y_test = y_test + + self.max_iter = max_iter + + def __call__(self): + scores = {} + clf = LogisticRegression( + random_state=self.seed, + n_jobs=-1, + max_iter=self.max_iter, + verbose=1 if logger.isEnabledFor(logging.DEBUG) else 0, + ) + logger.info(f"Encoding {len(self.embeds_train)} training embeds...") + X_train = np.asarray(self.embeds_train) + + logger.info(f"Encoding {len(self.embeds_test)} test embeds...") + X_test = np.asarray(self.embeds_test) + logger.info("Fitting logistic regression classifier...") + clf.fit(X_train, self.y_train) + logger.info("Evaluating...") + y_pred = clf.predict(X_test) + accuracy = accuracy_score(self.y_test, y_pred) + f1 = f1_score(self.y_test, y_pred, average="macro") + scores["accuracy"] = accuracy + scores["f1"] = f1 + + # if binary classification + if len(np.unique(self.y_train)) == 2: + ap = average_precision_score(self.y_test, y_pred) + scores["ap"] = ap + + return scores + + +class ClusteringEvaluator(Evaluator): + def __init__( + self, + embeds, + labels, + clustering_batch_size=500, + **kwargs, + ): + super().__init__(**kwargs) + self.embeds = embeds + self.labels = labels + self.clustering_batch_size = clustering_batch_size + + def __call__(self): + logger.info(f"Encoding {len(self.embeds)} embeds...") + corpus_embeddings = np.asarray(self.embeds) + + logger.info("Fitting Mini-Batch K-Means model...") + clustering_model = sklearn.cluster.MiniBatchKMeans( + n_clusters=len(set(self.labels)), + batch_size=self.clustering_batch_size, + n_init="auto", + ) + clustering_model.fit(corpus_embeddings) + cluster_assignment = clustering_model.labels_ + + logger.info("Evaluating...") + v_measure = v_measure_score(self.labels, cluster_assignment) + + return {"v_measure": v_measure} + + 
+class PairClassificationEvaluator(Evaluator): + """Evaluate a model based on the similarity of the embeddings by calculating the accuracy of identifying similar and + dissimilar embeds. + The metrics are the cosine similarity as well as euclidean and Manhattan distance + The returned score is the accuracy with a specified metric. + The results are written in a CSV. If a CSV already exists, then values are appended. + The labels need to be 0 for dissimilar pairs and 1 for similar pairs. + :param embeds1: The first column of embeds + :param embeds2: The second column of embeds + :param labels: labels[i] is the label for the pair (embeds1[i], embeds2[i]). Must be 0 or 1 + :param name: Name for the output + :param write_csv: Write results to a CSV file + """ + + def __init__(self, embeds1, embeds2, labels, **kwargs): + super().__init__(**kwargs) + self.embeds1 = embeds1 + self.embeds2 = embeds2 + self.labels = labels + + assert len(self.embeds1) == len(self.embeds2) + assert len(self.embeds1) == len(self.labels) + for label in labels: + assert label == 0 or label == 1 + + def __call__(self): + scores = self.compute_metrics() + # Compute the max of Average Precision (AP) over all distance metrics. + top_ap_score = max(score for k, score in scores.items() if k.endswith("_ap")) + scores["top_ap"] = top_ap_score + return scores + + def compute_metrics(self): + embeddings1 = np.array(self.embeds1) + embeddings2 = np.array(self.embeds2) + + logger.info("Computing similarity distances...") + cosine_scores = 1 - paired_cosine_distances(embeddings1, embeddings2) + manhattan_distances = paired_manhattan_distances(embeddings1, embeddings2) + euclidean_distances = paired_euclidean_distances(embeddings1, embeddings2) + + embeddings1_np = np.asarray(embeddings1) + embeddings2_np = np.asarray(embeddings2) + dot_scores = [ + np.dot(embeddings1_np[i], embeddings2_np[i]) + for i in range(len(embeddings1_np)) + ] + + logger.info("Computing metrics...") + labels = np.asarray(self.labels) + output_scores = {} + for short_name, name, scores, reverse in [ + ["cos_sim", "Cosine-Similarity", cosine_scores, True], + ["manhattan", "Manhattan-Distance", manhattan_distances, False], + ["euclidean", "Euclidean-Distance", euclidean_distances, False], + ["dot", "Dot-Product", dot_scores, True], + ]: + metrics = self._compute_metrics(scores, labels, reverse) + metrics = {short_name + "_" + k: v for k, v in metrics.items()} + output_scores.update(metrics) + + return output_scores + + @staticmethod + def _compute_metrics(scores, labels, high_score_more_similar): + """Compute the metrics for the given scores and labels. + + Args: + scores (`np.ndarray` of shape (n_pairs, )): The similarity/dissimilarity scores for the pairs. + labels (`np.ndarray` of shape (n_pairs, )): The labels for the pairs. + high_score_more_similar (`bool`): If true, then the higher the score, the more similar the pairs are. + + Returns: + `dict`: The metrics for the given scores and labels. 
+ """ + acc, acc_threshold = PairClassificationEvaluator.find_best_acc_and_threshold( + scores, labels, high_score_more_similar + ) + f1, precision, recall, f1_threshold = ( + PairClassificationEvaluator.find_best_f1_and_threshold( + scores, labels, high_score_more_similar + ) + ) + ap = PairClassificationEvaluator.ap_score( + scores, labels, high_score_more_similar + ) + + return { + "accuracy": acc, + "accuracy_threshold": acc_threshold, + "f1": f1, + "f1_threshold": f1_threshold, + "precision": precision, + "recall": recall, + "ap": ap, + } + + @staticmethod + def find_best_acc_and_threshold(scores, labels, high_score_more_similar: bool): + assert len(scores) == len(labels) + rows = list(zip(scores, labels)) + + rows = sorted(rows, key=lambda x: x[0], reverse=high_score_more_similar) + + max_acc = 0 + best_threshold = -1 + + positive_so_far = 0 + remaining_negatives = sum(np.array(labels) == 0) + + for i in range(len(rows) - 1): + score, label = rows[i] + if label == 1: + positive_so_far += 1 + else: + remaining_negatives -= 1 + + acc = (positive_so_far + remaining_negatives) / len(labels) + if acc > max_acc: + max_acc = acc + best_threshold = (rows[i][0] + rows[i + 1][0]) / 2 + + return max_acc, best_threshold + + @staticmethod + def find_best_f1_and_threshold(scores, labels, high_score_more_similar: bool): + assert len(scores) == len(labels) + + scores = np.asarray(scores) + labels = np.asarray(labels) + + rows = list(zip(scores, labels)) + + rows = sorted(rows, key=lambda x: x[0], reverse=high_score_more_similar) + + best_f1 = best_precision = best_recall = 0 + threshold = 0 + nextract = 0 + ncorrect = 0 + total_num_duplicates = sum(labels) + + for i in range(len(rows) - 1): + score, label = rows[i] + nextract += 1 + + if label == 1: + ncorrect += 1 + + if ncorrect > 0: + precision = ncorrect / nextract + recall = ncorrect / total_num_duplicates + f1 = 2 * precision * recall / (precision + recall) + if f1 > best_f1: + best_f1 = f1 + best_precision = precision + best_recall = recall + threshold = (rows[i][0] + rows[i + 1][0]) / 2 + + return best_f1, best_precision, best_recall, threshold + + @staticmethod + def ap_score(scores, labels, high_score_more_similar: bool): + return average_precision_score( + labels, scores * (1 if high_score_more_similar else -1) + ) + + +class MultiClassMultiOutputLogRegClassificationEvaluator(Evaluator): + def __init__( + self, + embeds_train, + y_train, + embeds_test, + y_test, + max_iter=1000, + **kwargs, + ): + super().__init__(**kwargs) + self.embeds_train = embeds_train + self.y_train = y_train + self.embeds_test = embeds_test + self.y_test = y_test + self.max_iter = max_iter + + def __call__(self): + scores = {} + mlb = MultiLabelBinarizer() + # all classes in y_train and y_test + + class_labels = list(self.y_train) + list(self.y_test) + labels = [class_label.split(", ") for class_label in class_labels] + mlb.fit(labels) + train_labels = [class_label.split(", ") for class_label in self.y_train] + test_labels = [class_label.split(", ") for class_label in self.y_test] + + y_train = mlb.transform(train_labels) + y_test = mlb.transform(test_labels) + clf = MultiOutputRegressor( + LogisticRegression( + random_state=self.seed, solver="lbfgs", max_iter=self.max_iter + ) + ).fit(self.embeds_train, y_train) + y_pred = clf.predict(self.embeds_test) + + results_dict = classification_report(y_test, y_pred, output_dict=True) + assert isinstance( + results_dict, dict + ), "Should always be true since `output_dict=True` is passed to 
sklearn.metric.classification_report" + scores["precision"] = results_dict["macro avg"]["precision"] + scores["recall"] = results_dict["macro avg"]["recall"] + scores["f1"] = results_dict["macro avg"]["f1-score"] + scores["accuracy"] = accuracy_score(y_test, y_pred) + + return scores + + +class MultiClassMultiOutputKNNClassificationEvaluator(Evaluator): + def __init__( + self, + embeds_train, + y_train, + embeds_test, + y_test, + n_neighbors=5, + **kwargs, + ): + super().__init__(**kwargs) + self.embeds_train = embeds_train + self.y_train = y_train + self.embeds_test = embeds_test + self.y_test = y_test + self.n_neighbors = n_neighbors + + def __call__(self): + scores = {} + + mlb = MultiLabelBinarizer() + class_labels = list(self.y_train) + list(self.y_test) + labels = [class_label.split(", ") for class_label in class_labels] + mlb.fit(labels) + train_labels = [class_label.split(", ") for class_label in self.y_train] + test_labels = [class_label.split(", ") for class_label in self.y_test] + + y_train = mlb.transform(train_labels) + y_test = mlb.transform(test_labels) + clf = sklearn.neighbors.KNeighborsClassifier( + n_neighbors=self.n_neighbors, metric="cosine" + ) + logger.info("Fitting KNN classifier...") + clf.fit(self.embeds_train, y_train) + logger.info("Evaluating...") + y_pred = clf.predict(self.embeds_test) + accuracy = accuracy_score(y_test, y_pred) + f1 = f1_score(y_test, y_pred, average="macro") + precision = precision_score(y_test, y_pred, average="macro") + recall = recall_score(y_test, y_pred, average="macro") + lrap = label_ranking_average_precision_score(y_test, y_pred) + scores["f1"] = f1 + scores["accuracy"] = accuracy + scores["precision"] = precision + scores["recall"] = recall + scores["lrap"] = lrap + + return scores + + +class BiGeneMiningEvaluator(Evaluator): + """ + BiGene Mining Evaluator, analogous to Bitext Mining Evaluator https://github.com/embeddings-benchmark/mteb/blob/main/mteb/evaluation/evaluators/BitextMiningEvaluator.py. + + If top_k > 1, then recall@k is also computed. + """ + + def __init__(self, embeds1, embeds2, top_k=1, **kwargs): + super().__init__(**kwargs) + self.n = len(embeds1) + self.embeds1 = np.array(embeds1) + self.embeds2 = np.array(embeds2) + self.gold = list(zip(range(self.n), range(self.n))) + self.top_k = top_k + + def __call__(self): + scores = self.compute_metrics() + return scores + + def compute_metrics(self): + logger.info(f"Finding nearest neighbors... with top_k={self.top_k}") + nearest_neighbors = self._similarity_search( + self.embeds1, self.embeds2, top_k=self.top_k + ) + + # Compute errors + logger.info("Computing metrics...") + labels = [] + predictions = [] + + # Get predictions and labels for top_k=1. + for i, x in enumerate(nearest_neighbors): + j = x[0]["corpus_id"] + predictions.append(j) + labels.append(self.gold[i][1]) + + scores = { + "precision": precision_score( + labels, predictions, zero_division=0, average="weighted" + ), + "recall": recall_score( + labels, predictions, zero_division=0, average="weighted" + ), + "f1": f1_score(labels, predictions, zero_division=0, average="weighted"), + "accuracy": accuracy_score(labels, predictions), + } + + if self.top_k > 1: + # Compute recall@k. 
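+            # A query counts as recalled if its aligned gold index appears among its top_k nearest neighbours.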
+ top_k_preds = [] + for i, x in enumerate(nearest_neighbors): + top_k_preds.append([pred["corpus_id"] for pred in x]) + top_k_recall = [ + self.gold[i][1] in top_k_pred + for i, top_k_pred in enumerate(top_k_preds) + ] + scores[f"recall_at_{self.top_k}"] = sum(top_k_recall) / len(top_k_recall) + return scores + + def _similarity_search( + self, + query_embeddings, + corpus_embeddings, + query_chunk_size=100, + corpus_chunk_size=500000, + top_k=1, + score_function=cos_sim, + ): + """This function performs a cosine similarity search between a list of query embeddings and a list of corpus embeddings. + It can be used for Information Retrieval / Semantic Search for corpora up to about 1 Million entries. + :param query_embeddings: A 2 dimensional tensor with the query embeddings. + :param corpus_embeddings: A 2 dimensional tensor with the corpus embeddings. + :param query_chunk_size: Process 100 queries simultaneously. Increasing that value increases the speed, but requires more memory. + :param corpus_chunk_size: Scans the corpus 50k entries at a time. Increasing that value increases the speed, but requires more memory. + :param top_k: Retrieve top k matching entries. + :param score_function: Function for computing scores. By default, cosine similarity. + :return: Returns a list with one entry for each query. Each entry is a list of dictionaries with the keys 'corpus_id' and 'score', sorted by decreasing cosine similarity scores. + """ + query_embeddings = torch.from_numpy(query_embeddings) + corpus_embeddings = torch.from_numpy(corpus_embeddings) + if len(query_embeddings.shape) == 1: + query_embeddings = query_embeddings.unsqueeze(0) + if len(corpus_embeddings.shape) == 1: + corpus_embeddings = corpus_embeddings.unsqueeze(0) + + # Check that corpus and queries are on the same device + if corpus_embeddings.device != query_embeddings.device: + query_embeddings = query_embeddings.to(corpus_embeddings.device) + + queries_result_list = [[] for _ in range(len(query_embeddings))] + + for query_start_idx in range(0, len(query_embeddings), query_chunk_size): + # Iterate over chunks of the corpus + for corpus_start_idx in range(0, len(corpus_embeddings), corpus_chunk_size): + # Compute cosine similarities + cos_scores = score_function( + query_embeddings[ + query_start_idx : query_start_idx + query_chunk_size + ], + corpus_embeddings[ + corpus_start_idx : corpus_start_idx + corpus_chunk_size + ], + ) + + # Get top-k scores + cos_scores_top_k_values, cos_scores_top_k_idx = torch.topk( + cos_scores, + min(top_k, len(cos_scores[0])), + dim=1, + largest=True, + sorted=False, + ) + cos_scores_top_k_values = cos_scores_top_k_values.cpu().tolist() + cos_scores_top_k_idx = cos_scores_top_k_idx.cpu().tolist() + + for query_itr in range(len(cos_scores)): + for sub_corpus_id, score in zip( + cos_scores_top_k_idx[query_itr], + cos_scores_top_k_values[query_itr], + ): + corpus_id = corpus_start_idx + sub_corpus_id + query_id = query_start_idx + query_itr + queries_result_list[query_id].append( + {"corpus_id": corpus_id, "score": score} + ) + + # Sort and strip to top_k results + for idx in range(len(queries_result_list)): + queries_result_list[idx] = sorted( + queries_result_list[idx], key=lambda x: x["score"], reverse=True + ) + queries_result_list[idx] = queries_result_list[idx][0:top_k] + + return queries_result_list + + +class EDSEvaluator(Evaluator): + """ + Evolutionary Distance Similarity Evaluator, analogous to Semantic Textual Similarity Evaluator. 
+ Adapted from https://github.com/embeddings-benchmark/mteb/blob/main/mteb/evaluation/evaluators/STSEvaluator.py + """ + + def __init__(self, embeds1, embeds2, gold_scores, **kwargs): + super().__init__(**kwargs) + self.embeds1 = embeds1 + self.embeds2 = embeds2 + self.gold_scores = gold_scores + + def __call__(self): + embeddings1 = np.array(self.embeds1) + embeddings2 = np.array(self.embeds2) + logger.info("Evaluating...") + cosine_scores = paired_cosine_distances(embeddings1, embeddings2) + manhattan_distances = paired_manhattan_distances(embeddings1, embeddings2) + euclidean_distances = paired_euclidean_distances(embeddings1, embeddings2) + + cosine_pearson, _ = pearsonr(self.gold_scores, cosine_scores) + manhattan_pearson, _ = pearsonr(self.gold_scores, manhattan_distances) + euclidean_pearson, _ = pearsonr(self.gold_scores, euclidean_distances) + + top_corr = max( + cosine_pearson, + manhattan_pearson, + euclidean_pearson, + ) + return { + "cos_sim": cosine_pearson, + "manhattan": manhattan_pearson, + "euclidean": euclidean_pearson, + "top_corr": top_corr, + } + + +class RetrievalEvaluator(Evaluator): + """Adapted from + https://github.com/embeddings-benchmark/mteb/blob/main/mteb/evaluation/evaluators/RetrievalEvaluator.py + """ + + def __init__( + self, + corpus_embeds, + query_embeds, + corpus_ids, + query_ids, + qrels: Dict[str, Dict[str, int]], + k_values: List[int] = [5, 10, 50], + score_function: str = "cos_sim", + corpus_chunk_size: int = 50000, + **kwargs, + ): + super().__init__(**kwargs) + self.corpus_embeds = corpus_embeds + self.query_embeds = query_embeds + self.corpus_ids = corpus_ids + self.query_ids = query_ids + self.qrels = qrels + self.k_values = k_values + self.top_k = max(k_values) if "top_k" not in kwargs else kwargs["top_k"] + self.score_function = score_function + self.score_functions = { + "cos_sim": cos_sim, + "dot": dot_score, + } + self.corpus_chunk_size = corpus_chunk_size + + def __call__(self): + results = self.search( + self.corpus_embeds, + self.query_embeds, + self.corpus_ids, + self.query_ids, + self.top_k, + self.score_function, + ) + ndcg, _map, recall, precision, naucs = self.evaluate( + self.qrels, results, self.k_values + ) + mrr, naucs_mrr = self.evaluate_custom(self.qrels, results, self.k_values, "mrr") + scores = { + **{f"ndcg_at_{k.split('@')[1]}": v for (k, v) in ndcg.items()}, + **{f"map_at_{k.split('@')[1]}": v for (k, v) in _map.items()}, + **{f"recall_at_{k.split('@')[1]}": v for (k, v) in recall.items()}, + **{f"precision_at_{k.split('@')[1]}": v for (k, v) in precision.items()}, + **{f"mrr_at_{k.split('@')[1]}": v for (k, v) in mrr.items()}, + **{ + k.replace("@", "_at_").replace("_P", "_precision").lower(): v + for k, v in naucs.items() + }, + **{ + k.replace("@", "_at_").replace("_P", "_precision").lower(): v + for k, v in naucs_mrr.items() + }, + } + return scores + + def search( + self, + corpus_embeds, + query_embeds, + corpus_ids, + query_ids, + top_k: int, + score_function: str, + return_sorted: bool = False, + **kwargs, + ) -> dict[str, dict[str, float]]: + # Create embeddings for all queries using model.encode() + # Runs semantic search against the corpus embeddings + # Returns a ranked list with the corpus ids + if score_function not in self.score_functions: + raise ValueError( + f"score function: {score_function} must be either (cos_sim) for cosine similarity or (dot) for dot product" + ) + # make query embeds and corpus embeds torch tensors + query_embeds = torch.from_numpy(query_embeds) + corpus_embeds = 
torch.from_numpy(corpus_embeds) + itr = range(0, len(corpus_embeds), self.corpus_chunk_size) + results = defaultdict(dict) + # Keep only the top-k docs for each query + result_heaps = defaultdict(list) + for batch_num, corpus_start_idx in enumerate(itr): + logger.info("Searching Batch {}/{}...".format(batch_num + 1, len(itr))) + corpus_end_idx = min( + corpus_start_idx + self.corpus_chunk_size, len(corpus_ids) + ) + sub_corpus_embeds = corpus_embeds[corpus_start_idx:corpus_end_idx] + # Compute similarites using either cosine-similarity or dot product + cos_scores = self.score_functions[score_function]( + query_embeds, sub_corpus_embeds + ) + cos_scores[torch.isnan(cos_scores)] = -1 + + # Get top-k values + cos_scores_top_k_values, cos_scores_top_k_idx = torch.topk( + cos_scores, + min( + top_k + 1, + len(cos_scores[1]) if len(cos_scores) > 1 else len(cos_scores[-1]), + ), + dim=1, + largest=True, + sorted=return_sorted, + ) + cos_scores_top_k_values = cos_scores_top_k_values.cpu().tolist() + cos_scores_top_k_idx = cos_scores_top_k_idx.cpu().tolist() + + for query_itr in range(len(query_embeds)): + query_id = query_ids[query_itr] + for sub_corpus_id, score in zip( + cos_scores_top_k_idx[query_itr], cos_scores_top_k_values[query_itr] + ): + corpus_id = corpus_ids[corpus_start_idx + sub_corpus_id] + if corpus_id != query_id: + if len(result_heaps[query_id]) < top_k: + # Push item on the heap + heapq.heappush(result_heaps[query_id], (score, corpus_id)) + else: + # If item is larger than the smallest in the heap, push it on the heap then pop the smallest element + heapq.heappushpop( + result_heaps[query_id], (score, corpus_id) + ) + + for qid in result_heaps: + for score, corpus_id in result_heaps[qid]: + results[qid][corpus_id] = score + + return results + + @staticmethod + def evaluate( + qrels: dict[str, dict[str, int]], + results: dict[str, dict[str, float]], + k_values: List[int], + ignore_identical_ids: bool = True, + ) -> Tuple[Dict[str, float], dict[str, float], dict[str, float], dict[str, float]]: + if ignore_identical_ids: + logger.info( + "For evaluation, we ignore identical query and document ids (default), please explicitly set ``ignore_identical_ids=False`` to ignore this." + ) + popped = [] + for qid, rels in results.items(): + for pid in list(rels): + if qid == pid: + results[qid].pop(pid) + popped.append(pid) + + all_ndcgs, all_aps, all_recalls, all_precisions = {}, {}, {}, {} + + for k in k_values: + all_ndcgs[f"NDCG@{k}"] = [] + all_aps[f"MAP@{k}"] = [] + all_recalls[f"Recall@{k}"] = [] + all_precisions[f"P@{k}"] = [] + + map_string = "map_cut." + ",".join([str(k) for k in k_values]) + ndcg_string = "ndcg_cut." + ",".join([str(k) for k in k_values]) + recall_string = "recall." + ",".join([str(k) for k in k_values]) + precision_string = "P." 
+ ",".join([str(k) for k in k_values]) + evaluator = pytrec_eval.RelevanceEvaluator( + qrels, {map_string, ndcg_string, recall_string, precision_string} + ) + scores = evaluator.evaluate(results) + + for query_id in scores.keys(): + for k in k_values: + all_ndcgs[f"NDCG@{k}"].append(scores[query_id]["ndcg_cut_" + str(k)]) + all_aps[f"MAP@{k}"].append(scores[query_id]["map_cut_" + str(k)]) + all_recalls[f"Recall@{k}"].append(scores[query_id]["recall_" + str(k)]) + all_precisions[f"P@{k}"].append(scores[query_id]["P_" + str(k)]) + ndcg, _map, recall, precision = ( + all_ndcgs.copy(), + all_aps.copy(), + all_recalls.copy(), + all_precisions.copy(), + ) + + for k in k_values: + ndcg[f"NDCG@{k}"] = round(sum(ndcg[f"NDCG@{k}"]) / len(scores), 5) + _map[f"MAP@{k}"] = round(sum(_map[f"MAP@{k}"]) / len(scores), 5) + recall[f"Recall@{k}"] = round(sum(recall[f"Recall@{k}"]) / len(scores), 5) + precision[f"P@{k}"] = round(sum(precision[f"P@{k}"]) / len(scores), 5) + naucs = RetrievalEvaluator.evaluate_abstention( + results, {**all_ndcgs, **all_aps, **all_recalls, **all_precisions} + ) + return ndcg, _map, recall, precision, naucs + + @staticmethod + def evaluate_abstention( + results: dict[str, dict[str, float]], + metric_scores: dict[str, list[float]], + ) -> Dict[str, float]: + """Computes normalized Area Under the Curve on a set of evaluated instances as presented in the paper https://arxiv.org/abs/2402.12997""" + all_sim_scores = [list(results[qid].values()) for qid in list(results.keys())] + all_conf_scores = [ + confidence_scores(sim_scores) for sim_scores in all_sim_scores + ] + conf_fcts = list(all_conf_scores[0].keys()) + all_conf_scores = { + fct: np.array([x[fct] for x in all_conf_scores]) for fct in conf_fcts + } + metric_scores = {k: np.array(v) for k, v in metric_scores.items()} + naucs = {} + + for metric_name, scores in metric_scores.items(): + for fct, conf_scores in all_conf_scores.items(): + naucs[f"nAUC_{metric_name}_{fct}"] = nAUC(conf_scores, scores) + + return naucs + + @staticmethod + def evaluate_custom( + qrels: dict[str, dict[str, int]], + results: dict[str, dict[str, float]], + k_values: List[int], + metric: str, + output_type: str = "all", + ) -> Tuple[Dict[str, float]]: + if metric.lower() in ["mrr", "mrr@k", "mrr_cut"]: + metric_scores = mrr(qrels, results, k_values, output_type) + + elif metric.lower() in ["recall_cap", "r_cap", "r_cap@k"]: + metric_scores = recall_cap(qrels, results, k_values, output_type) + + elif metric.lower() in ["hole", "hole@k"]: + metric_scores = hole(qrels, results, k_values, output_type) + + elif metric.lower() in [ + "acc", + "top_k_acc", + "accuracy", + "accuracy@k", + "top_k_accuracy", + ]: + metric_scores = top_k_accuracy(qrels, results, k_values, output_type) + + naucs = RetrievalEvaluator.evaluate_abstention(results, metric_scores) + metric_scores_avg = {k: sum(v) / len(v) for k, v in metric_scores.items()} + + return metric_scores_avg, naucs diff --git a/geb/geb.py b/geb/geb.py new file mode 100644 index 0000000..e6bf3df --- /dev/null +++ b/geb/geb.py @@ -0,0 +1,129 @@ +from itertools import chain +import logging +import os +import traceback +from typing import Any, List + +from rich.console import Console + +from .eval_utils import set_all_seeds +from .modality import Modality +from .models import BioSeqTransformer +from .tasks.tasks import Task + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +class GEB: + """GEB class to run the evaluation pipeline.""" + + def __init__(self, tasks: 
List[type[Task]], seed: int = 42): + self.tasks = tasks + set_all_seeds(seed) + + def print_selected_tasks(self): + """Print the selected tasks.""" + console = Console() + console.rule("[bold]Selected Tasks\n", style="grey15") + for task in self.tasks: + prefix = " - " + name = f"{task.metadata.display_name}" + category = f", [italic grey39]{task.metadata.type}[/]" + console.print(f"{prefix}{name}{category}") + console.print("\n") + + def run( + self, + model, # type encoder + output_folder: str = "results", + ): + """Run the evaluation pipeline on the selected tasks. + + Args: + model: Model to be used for evaluation + output_folder: Folder where the results will be saved. Default to 'results'. Where it will save the results in the format: + `{output_folder}/{model_name}/{model_revision}/{task_name}.json`. + + Returns: + A list of MTEBResults objects, one for each task evaluated. + """ + # Run selected tasks + self.print_selected_tasks() + results = [] + + for task in self.tasks: + logger.info( + f"\n\n********************** Evaluating {task.metadata.display_name} **********************" + ) + + try: + result = task().run(model) + except Exception as e: + logger.error(e) + logger.error(traceback.format_exc()) + logger.error(f"Error running task {task}") + continue + + results.append(result) + + save_path = get_output_folder(model.hf_name, task, output_folder) + with open(save_path, "w") as f_out: + f_out.write(result.model_dump_json(indent=2)) + return results + + +def get_model(model_name: str, **kwargs: Any) -> type[BioSeqTransformer]: + all_names = get_all_model_names() + for cls in BioSeqTransformer.__subclasses__(): + if model_name in cls.MODEL_NAMES: + return cls(model_name, **kwargs) + raise ValueError(f"Model {model_name} not found in {all_names}.") + + +def get_all_model_names() -> List[str]: + return list( + chain.from_iterable( + cls.MODEL_NAMES for cls in BioSeqTransformer.__subclasses__() + ) + ) + + +def get_all_task_names() -> List[str]: + return [task.metadata.id for task in get_all_tasks()] + + +def get_tasks_by_name(tasks: List[str]) -> List[type[Task]]: + return [_get_task(task) for task in tasks] + + +def get_tasks_by_modality(modality: Modality) -> List[type[Task]]: + return [task for task in get_all_tasks() if task.metadata.modality == modality] + + +def get_all_tasks() -> List[type[Task]]: + return Task.__subclasses__() + + +def _get_task(task_name: str) -> type[Task]: + logger.info(f"Getting task {task_name}") + for task in get_all_tasks(): + if task.metadata.id == task_name: + return task + + raise ValueError( + f"Task {task_name} not found, available tasks are: {[task.metadata.id for task in get_all_tasks()]}" + ) + + +def get_output_folder( + model_hf_name: str, task: type[Task], output_folder: str, create: bool = True +): + output_folder = os.path.join(output_folder, os.path.basename(model_hf_name)) + # create output folder if it does not exist + if create and not os.path.exists(output_folder): + os.makedirs(output_folder) + return os.path.join( + output_folder, + f"{task.metadata.id}.json", + ) diff --git a/geb/modality.py b/geb/modality.py new file mode 100644 index 0000000..146d1ad --- /dev/null +++ b/geb/modality.py @@ -0,0 +1,10 @@ +"""Defines the data modality enum.""" + +from enum import Enum + + +class Modality(Enum): + """Data modality, either DNA or protein sequence.""" + + PROTEIN = "protein" + DNA = "dna" diff --git a/geb/models.py b/geb/models.py new file mode 100644 index 0000000..1c46c81 --- /dev/null +++ b/geb/models.py @@ -0,0 +1,482 @@ +import 
logging +import re +from abc import ABC, abstractmethod +from functools import partial +from types import SimpleNamespace +from typing import Dict, List, Literal, Optional + +import numpy as np +import torch +import tqdm as tqdm +from datasets import Dataset +from torch import Tensor +from torch.nn import functional as F +from torch.utils.data import DataLoader +from transformers import ( + AutoConfig, + AutoModel, + AutoModelForCausalLM, + AutoModelForMaskedLM, + AutoTokenizer, + BatchEncoding, + DefaultDataCollator, + T5EncoderModel, + T5Tokenizer, +) +from transformers.modeling_outputs import BaseModelOutput + +from .modality import Modality +from .eval_utils import ForwardHook, pool + +logger = logging.getLogger(__name__) + + +class BioSeqTransformer(ABC): + """ + Abstract class to wrap models which map biological sequences (DNA/Prot) to embeddings. + Modelled after SentenceTransformer (https://github.com/UKPLab/sentence-transformers/blob/master/sentence_transformers/SentenceTransformer.py) + + Args: + model_name: Name or path to the pretrained model. + layers: List of model layers to probe. Can be integers or "mid" or "last". + devices: List of device ids for inference. If cuda is not available, will use cpu. + num_processes: Number of processes to use for data loading. + max_seq_length: Maximum sequence length of the input sequences. + l2_norm: If true, embeddings are L2-normalized before they are returned. + batch_size: Batch size for encoding. + pool_type: Pooling strategy to use. One of "mean", "max", "cls", "last". + """ + + def __init__( + self, + model_name: str, + layers: Optional[List[int] | Literal["mid"] | Literal["last"]] = None, + devices: List[int] = [0], + num_processes: int = 16, + max_seq_length: int = 1024, + l2_norm: bool = False, + batch_size: int = 128, + pool_type: str = "mean", + ): + super().__init__() + + self.id = self.__class__.__name__ + self.hf_name = model_name + self.encoder = self._load_model(model_name) + if not hasattr(self.encoder, "config"): + raise ValueError( + 'The model from `self._load_model()` must have a "config" attribute.' 
+ ) + self.config = self.encoder.config + self.tokenizer = self._get_tokenizer(model_name) + self.num_param = sum(p.numel() for p in self.encoder.parameters()) + self.data_collator = DefaultDataCollator() + self.gpu_count = len(devices) + self.l2_norm = l2_norm + + self.device = torch.device( + f"cuda:{devices[0]}" if torch.cuda.is_available() else "cpu" + ) + self.num_processes = num_processes + self.max_seq_length = max_seq_length + self.batch_size = batch_size + self.pool_type = pool_type + + if self.gpu_count > 1: + self.encoder = torch.nn.DataParallel(self.encoder, device_ids=devices) + self.encoder.to(self.device) + self.encoder.eval() + + mid_layer = self.num_layers // 2 + last_layer = self.num_layers - 1 + mid_layer_label = f"mid ({mid_layer})" + last_layer_label = f"last ({self.num_layers - 1})" + + if layers is None: + logger.debug(f"Using default layers: {mid_layer_label}, {last_layer_label}") + self.layers = [mid_layer, last_layer] + self.layer_labels = [mid_layer_label, last_layer_label] + elif layers == "mid": + self.layers = [mid_layer] + self.layer_labels = [mid_layer_label] + elif layers == "last": + self.layers = [last_layer] + self.layer_labels = [last_layer_label] + else: + self.layers = layers + self.layer_labels = [str(layer) for layer in layers] + + def _encode_single_batch(self, batch_dict: Dict[str, Tensor]): + """Returns the output embedding for the given batch with shape [batch, num_layers, D].""" + outputs = self.encoder(**batch_dict, output_hidden_states=True) + embeds = [outputs.hidden_states[layer] for layer in self.layers] + embeds = [ + pool(layer_embeds, batch_dict["attention_mask"], self.pool_type) + for layer_embeds in embeds + ] + # Stack with shape [B, num_layers, D]. + embeds = torch.stack(embeds, dim=1) + return embeds + + def _load_model(self, model_name): + return AutoModel.from_pretrained(model_name, trust_remote_code=True) + + def _get_tokenizer(self, model_name): + return AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) + + def _tokenize_func( + self, tokenizer, examples: Dict[str, List], max_seq_length: int + ) -> BatchEncoding: + batch_dict = tokenizer( + examples["input_seqs"], + max_length=max_seq_length, + padding=True, + truncation=True, + ) + return batch_dict + + @property + def metadata(self) -> Dict: + return { + "hf_name": self.hf_name, + "revision": "...", # TODO: Fix + "num_layers": self.num_layers, + "num_params": self.num_param, + "embed_dim": self.embed_dim, + } + + @property + @abstractmethod + def num_layers(self) -> int: + pass + + @property + @abstractmethod + def embed_dim(self) -> int: + pass + + @property + @abstractmethod + def modality(self) -> Modality: + pass + + @torch.no_grad() + def encode(self, sequences, **kwargs) -> np.ndarray: + """Returns a list of embeddings for the given sequences. + Args: + sequences (`List[str]`): List of sequences to encode + Returns: + `np.ndarray`: Embeddings for the given sequences of shape [num_sequences, num_layers, embedding_dim]. + """ + dataset = Dataset.from_dict({"input_seqs": sequences}) + dataset.set_transform( + partial( + self._tokenize_func, self.tokenizer, max_seq_length=self.max_seq_length + ) + ) + data_loader = DataLoader( + dataset, + batch_size=self.batch_size * self.gpu_count, + shuffle=False, + drop_last=False, + num_workers=self.num_processes, + collate_fn=self.data_collator, + pin_memory=True, + ) + + if max(self.layers) >= self.num_layers: + raise ValueError( + f"Layer {max(self.layers)} is not available in the model. 
Choose a layer between 0 and {self.num_layers - 1}" + ) + + encoded_embeds = [] + for batch_dict in tqdm.tqdm( + data_loader, desc="encoding", mininterval=10, disable=len(sequences) < 128 + ): + batch_dict = {k: v.to(self.device) for k, v in batch_dict.items()} + + embeds = self._encode_single_batch(batch_dict) + + if self.l2_norm: + embeds = F.normalize(embeds, p=2, dim=-1) + encoded_embeds.append(embeds.cpu().numpy()) + + return np.concatenate(encoded_embeds, axis=0) + + +class ESM(BioSeqTransformer): + """ESM model from https://huggingface.co/docs/transformers/en/model_doc/esm""" + + MODEL_NAMES = [ + "facebook/esm2_t6_8M_UR50D", + "facebook/esm2_t12_35M_UR50D", + "facebook/esm2_t30_150M_UR50D", + "facebook/esm2_t33_650M_UR50D", + "facebook/esm2_t36_3B_UR50D", + "facebook/esm2_t48_15B_UR50D", + ] + + @property + def modality(self) -> Modality: + return Modality.PROTEIN + + @property + def num_layers(self) -> int: + return self.config.num_hidden_layers + + @property + def embed_dim(self) -> int: + return self.config.hidden_size + + +class ESM3(BioSeqTransformer): + """ESM3 model from https://github.com/evolutionaryscale/esm""" + + MODEL_NAMES = ["esm3_sm_open_v1"] + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # Register forward hooks to store embeddings per layer. + self.hooks = [ + ForwardHook(self.encoder.transformer.blocks[layer]) for layer in self.layers + ] + + @property + def modality(self) -> Modality: + return Modality.PROTEIN + + @property + def num_layers(self) -> int: + return self.config.num_hidden_layers + + @property + def embed_dim(self) -> int: + return self.config.hidden_size + + def _load_model(self, model_name): + try: + from esm.models.esm3 import ESM3 as ModelESM3 + except ImportError: + raise ImportError( + "ESM3 is not installed. Please install it with `pip install esm`." + ) + model = ModelESM3.from_pretrained("esm3_sm_open_v1") + model.config = SimpleNamespace( + num_hidden_layers=len(model.transformer.blocks), + hidden_size=model.transformer.blocks[0].ffn[-1].out_features, + ) + return model + + def _get_tokenizer(self, model_name): + try: + from esm.tokenization.sequence_tokenizer import EsmSequenceTokenizer + except ImportError: + raise ImportError( + "ESM3 is not installed. Please install it with `pip install esm`." + ) + return EsmSequenceTokenizer() + + def _encode_single_batch(self, batch_dict: Dict[str, Tensor]): + _ = self.encoder.forward(sequence_tokens=batch_dict["input_ids"]) + embeds = [hook.output for hook in self.hooks] + embeds = [ + pool(layer_embeds, batch_dict["attention_mask"], self.pool_type) + for layer_embeds in embeds + ] + # Stack with shape [B, num_layers, D]. 
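+        # Cast to float32 after stacking, since ESM3 activations may be
+        # in reduced precision (e.g. bfloat16).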
+ embeds = torch.stack(embeds, dim=1) + embeds = embeds.to(torch.float32) + return embeds + + +class ProtT5(BioSeqTransformer): + """ProtT5 model from https://github.com/agemagician/ProtTrans""" + + MODEL_NAMES = [ + "Rostlab/prot_t5_xl_uniref50", + "Rostlab/prot_t5_xl_bfd", + "Rostlab/prot_t5_xxl_uniref50", + "Rostlab/prot_t5_xxl_bfd", + ] + + @property + def modality(self) -> Modality: + return Modality.PROTEIN + + @property + def num_layers(self) -> int: + return self.config.num_layers + + @property + def embed_dim(self) -> int: + return self.config.d_model + + def _load_model(self, model_name): + return T5EncoderModel.from_pretrained(model_name) + + def _get_tokenizer(self, model_name): + return T5Tokenizer.from_pretrained(model_name, do_lower_case=False) + + def _tokenize_func( + self, tokenizer, examples: Dict[str, List], max_seq_length: int + ) -> BatchEncoding: + example_sequences = examples["input_seqs"] + # Add space between amino acids to make sure they are tokenized correctly. + example_sequences = [" ".join(sequence) for sequence in example_sequences] + example_sequences = [ + re.sub(r"[UZOB]", "X", sequence) for sequence in example_sequences + ] + batch_dict = tokenizer( + example_sequences, + max_length=max_seq_length, + padding=True, + truncation=True, + add_special_tokens=True, + ) + + return batch_dict + + +class ProGen(BioSeqTransformer): + """ProGen models from https://github.com/salesforce/progen.""" + + MODEL_NAMES = [ + "hugohrban/progen2-small", + "hugohrban/progen2-medium", + "hugohrban/progen2-base", + "hugohrban/progen2-large", + "hugohrban/progen2-xlarge", + ] + + @property + def modality(self) -> Modality: + return Modality.PROTEIN + + @property + def num_layers(self) -> int: + return self.config.n_layer + + @property + def embed_dim(self) -> int: + return self.config.embed_dim + + def _load_model(self, model_name): + return AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True) + + def _get_tokenizer(self, model_name_or_path): + tokenizer = AutoTokenizer.from_pretrained( + model_name_or_path, trust_remote_code=True + ) + tokenizer.pad_token = "<|pad|>" + return tokenizer + + def _encode_single_batch(self, batch_dict: Dict[str, Tensor]): + """Returns the output embedding for the given batch with shape [batch, num_layers, D].""" + outputs: BaseModelOutput = self.encoder( + input_ids=batch_dict["input_ids"], + output_hidden_states=True, + use_cache=False, + ) + embeds = [outputs.hidden_states[layer] for layer in self.layers] + embeds = [ + pool(layer_embeds, batch_dict["attention_mask"], self.pool_type) + for layer_embeds in embeds + ] + # Stack with shape [B, num_layers, D]. + embeds = torch.stack(embeds, dim=1) + return embeds + + +class EvoModel(BioSeqTransformer): + """https://github.com/evo-design/evo.""" + + MODEL_NAMES = [ + "togethercomputer/evo-1-8k-base", + "togethercomputer/evo-1-131k-base", + ] + + @property + def modality(self) -> Modality: + return Modality.DNA + + @property + def num_layers(self) -> int: + return self.config.num_layers + + @property + def embed_dim(self) -> int: + return self.config.hidden_size + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # Register forward hooks to store embeddings per layer. + self.hooks = [] + for layer in self.layers: + # For the last layer, get the output of `backbone.norm`, which directly precedes `backbone.unembed`. + # This is equivalent to the approach in https://github.com/evo-design/evo/issues/32. 
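+            # Other layers hook the corresponding `backbone.blocks[layer]` output.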
+ if layer == self.num_layers - 1 or layer == -1: + self.hooks.append(ForwardHook(self.encoder.backbone.norm)) + else: + self.hooks.append(ForwardHook(self.encoder.backbone.blocks[layer])) + + def _load_model(self, model_name): + config = AutoConfig.from_pretrained( + model_name, trust_remote_code=True, revision="1.1_fix" + ) + model = AutoModelForCausalLM.from_pretrained( + model_name, config=config, trust_remote_code=True, revision="1.1_fix" + ) + return model + + def _get_tokenizer(self, model_name): + tokenizer = AutoTokenizer.from_pretrained( + model_name, revision="1.1_fix", trust_remote_code=True + ) + # Evo tokenizer is missing pad_token by default. + tokenizer.add_special_tokens({"pad_token": "N"}) + return tokenizer + + def _encode_single_batch(self, batch_dict: Dict[str, Tensor]): + _ = self.encoder(batch_dict["input_ids"], use_cache=False) + embeds = [hook.output for hook in self.hooks] + # The hook output for Evo middle layers is a tuple (embedding, inference_params=None). + embeds = [x[0] if isinstance(x, tuple) else x for x in embeds] + embeds = [ + pool(layer_embeds, batch_dict["attention_mask"], self.pool_type) + for layer_embeds in embeds + ] + # Stack with shape [B, num_layers, D]. + embeds = torch.stack(embeds, dim=1) + embeds = embeds.to(torch.float32) + return embeds + + +class NTModel(BioSeqTransformer): + """Nucleotide Transformer https://github.com/instadeepai/nucleotide-transformer""" + + MODEL_NAMES = [ + "InstaDeepAI/nucleotide-transformer-v2-50m-multi-species", + "InstaDeepAI/nucleotide-transformer-v2-100m-multi-species", + "InstaDeepAI/nucleotide-transformer-v2-250m-multi-species", + "InstaDeepAI/nucleotide-transformer-v2-500m-multi-species", + "InstaDeepAI/nucleotide-transformer-2.5b-multi-species", + ] + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.max_seq_length = self.tokenizer.model_max_length + + @property + def modality(self) -> Modality: + return Modality.DNA + + @property + def num_layers(self) -> int: + return self.config.num_hidden_layers + + @property + def embed_dim(self) -> int: + return self.config.hidden_size + + def _load_model(self, model_name): + return AutoModelForMaskedLM.from_pretrained(model_name, trust_remote_code=True) diff --git a/geb/results.py b/geb/results.py new file mode 100644 index 0000000..afa378f --- /dev/null +++ b/geb/results.py @@ -0,0 +1,50 @@ +from typing import List +from pydantic import BaseModel + + +class Metric(BaseModel): + metric_id: str + display_name: str + display_value: float + description: str + + +class TaskResult(BaseModel): + task_id: str + display_name: str + description: str + metrics: List[Metric] + + +class TaskResults(BaseModel): + results_id: str + description: str + task_results: List[TaskResult] + + +# Example task results that conforms to above data specification +mock_task_results = { + "results_id": "result_123", + "description": "Overall results of the tasks", + "task_results": [ + { + "task_id": "task_1", + "display_name": "Task 1", + "description": "Description of Task 1", + "metrics": [ + { + "metric_id": "metric_1", + "display_name": "Metric 1", + "display_value": "Value 1", + "description": "Description of Metric 1", + }, + { + "metric_id": "metric_2", + "display_name": "Metric 2", + "display_value": "Value 2", + "description": "Description of Metric 2", + }, + ], + }, + ], +} diff --git a/geb/tasks/__init__.py b/geb/tasks/__init__.py new file mode 100644 index 0000000..29bc042 --- /dev/null +++ b/geb/tasks/__init__.py @@ -0,0 +1,13 @@ +# ruff: noqa: 
F403 + +from .tasks import Task +from .eds_tasks import * +from .pair_classification_tasks import * +from .retrieval_tasks import * +from .classification_tasks import * +from .clustering_tasks import * +from .bigene_mining_tasks import * + +__all__ = [ + "Task", +] diff --git a/geb/tasks/bigene_mining_tasks.py b/geb/tasks/bigene_mining_tasks.py new file mode 100644 index 0000000..1b7607e --- /dev/null +++ b/geb/tasks/bigene_mining_tasks.py @@ -0,0 +1,77 @@ +""" +Bigene mining tasks are analogous to bitext matching tasks, but for genes. +Cosine similarity is used to mine genes of related functions from different organisms. +""" + +import logging +from collections import defaultdict + +from geb.evaluators import BiGeneMiningEvaluator +from geb.modality import Modality +from geb.models import BioSeqTransformer +from geb.tasks.tasks import Dataset, Task, TaskMetadata, TaskResult + +logger = logging.getLogger(__name__) + + +def run_bigene_mining_tasks( + model: BioSeqTransformer, metadata: TaskMetadata, top_k: int = 1 +) -> TaskResult: + """Evaluate bigene mining task. Utilizes the BiGeneMiningEvaluator.""" + if len(metadata.datasets) != 1: + raise ValueError("BiGeneMining tasks require 1 dataset.") + ds = metadata.datasets[0].load()["train"] + layer_results = defaultdict(dict) + embeds1 = model.encode(ds["Seq1"]) + embeds2 = model.encode(ds["Seq2"]) + for i, layer in enumerate(model.layers): + evaluator = BiGeneMiningEvaluator(embeds1[:, i], embeds2[:, i], top_k=top_k) + layer_results["layers"][layer] = evaluator() + logger.info( + f"Layer: {layer}, {metadata.display_name} matching results: {layer_results['layers'][layer]}" + ) + return TaskResult.from_dict(metadata, layer_results, model.metadata) + + +class BacArchBiGeneMining(Task): + metadata = TaskMetadata( + id="bacarch_bigene", + display_name="BacArch BiGene", + description="Evaluate on BacArch bigene matching task between bacterial (E.coli K-12) proteins and archaeal (Sulfolobus acidocaldarius DSM 639) proteins.", + type="bigene_mining", + modality=Modality.PROTEIN, + datasets=[ + Dataset( + path="tattabio/bac_arch_bigene", + revision="main", + ) + ], + primary_metric_id="f1", + ) + + def run(self, model: BioSeqTransformer) -> TaskResult: + return run_bigene_mining_tasks(model, self.metadata) + + +class ModACParalogyBiGeneMining(Task): + # ModAC Paralogy matching with top_k=1 is too strict (most models have accuracy < 0.1%) + # Instead use recall@5 as the main metric. + TOP_K = 5 + + metadata = TaskMetadata( + id="modac_paralogy_bigene", + display_name="ModAC Paralogy BiGene", + description="Evaluate on paralogy bitext matching task between paralogous protein (ModA and ModC).", + type="bigene_mining", + modality=Modality.PROTEIN, + datasets=[ + Dataset( + path="tattabio/modac_paralogy_bigene", + revision="main", + ) + ], + primary_metric_id="recall_at_5", + ) + + def run(self, model: BioSeqTransformer) -> TaskResult: + return run_bigene_mining_tasks(model, self.metadata, top_k=self.TOP_K) diff --git a/geb/tasks/classification_tasks.py b/geb/tasks/classification_tasks.py new file mode 100644 index 0000000..99baac7 --- /dev/null +++ b/geb/tasks/classification_tasks.py @@ -0,0 +1,208 @@ +""" +Classification tasks take in biological sequence and functional labels. +Multi-class and/or multi-label classification tasks are supported. 
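+Multi-class tasks are scored with a logistic regression probe; multi-label tasks (MIBiG)
+use a multi-output KNN classifier, with long DNA sequences split into max-length chunks
+whose embeddings are mean-pooled.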
+""" + +import logging +from collections import defaultdict + +import datasets +import numpy as np + +from geb.eval_utils import merge_split_elem_embeds +from geb.evaluators import ( + MultiClassMultiOutputKNNClassificationEvaluator, + logRegClassificationEvaluator, +) +from geb.modality import Modality +from geb.models import BioSeqTransformer +from geb.tasks.tasks import Dataset, Task, TaskMetadata, TaskResult + +logger = logging.getLogger(__name__) + + +def split_sequences( + ds: datasets.DatasetDict, max_seq_length: int +) -> datasets.DatasetDict: + """Split sequences into chunks of max_seq_length using datasets.Dataset.map().""" + + def _split_sequence(examples, max_seq_length): + assert ( + len(examples["Sequence"]) == 1 + ), "split map function should use batch size of 1." + example = {k: v[0] for k, v in examples.items()} + seq = example["Sequence"] + # Split by chunks of max_seq_length. + seq_split = [ + seq[i : i + max_seq_length] for i in range(0, len(seq), max_seq_length) + ] + # Repeat other fields by the number of splits. + example = { + k: [v] * len(seq_split) for k, v in example.items() if k != "Sequence" + } + example["Sequence"] = seq_split + return example + + ds = ds.map( + _split_sequence, + batched=True, + batch_size=1, + fn_kwargs={"max_seq_length": max_seq_length}, + keep_in_memory=True, + load_from_cache_file=False, + ) + return ds + + +def run_classification_task( + model: BioSeqTransformer, metadata: TaskMetadata +) -> TaskResult: + """Evaluate on classification tasks using logistic regression classifier.""" + ds = metadata.datasets[0].load() + layer_results = defaultdict(dict) + train_embeds = model.encode(ds["train"]["Sequence"]) + test_embeds = model.encode(ds["test"]["Sequence"]) + for i, layer in enumerate(model.layers): + layer_results["layers"][layer] = logRegClassificationEvaluator( + train_embeds[:, i], + ds["train"]["Label"], + test_embeds[:, i], + ds["test"]["Label"], + )() + logger.info( + f"Layer: {layer}, {metadata.display_name} results: {layer_results['layers'][layer]}" + ) + return TaskResult.from_dict(metadata, layer_results, model.metadata) + + +class EnzymeCommissionClassification(Task): + metadata = TaskMetadata( + id="ec_classification", + display_name="EC Classification", + description="Evaluate on Enzyme Commission number classification task.", + type="classification", + modality=Modality.PROTEIN, + datasets=[ + Dataset( + path="tattabio/ec_classification", + revision="main", + ) + ], + primary_metric_id="f1", + ) + + def run(self, model: BioSeqTransformer) -> TaskResult: + return run_classification_task(model, self.metadata) + + +class EnzymeCommissionDNAClassification(Task): + metadata = TaskMetadata( + id="ec_dna_classification", + display_name="EC Classification", + description="Evaluate on Enzyme Commission number classification task using DNA sequences.", + type="classification", + modality=Modality.DNA, + datasets=[ + Dataset( + path="tattabio/ec_classification_dna", + revision="main", + ) + ], + primary_metric_id="f1", + ) + + def run(self, model: BioSeqTransformer) -> TaskResult: + return run_classification_task(model, self.metadata) + + +class ConvergentEnzymesClassification(Task): + metadata = TaskMetadata( + id="convergent_enzymes_classification", + display_name="Convergent Enzymes Classification", + description="Evaluate on convergent enzymes classification task, where convergent enzymes are proteins with the same EC number but without blastp hits against each other", + type="classification", + modality=Modality.PROTEIN, + 
datasets=[ + Dataset( + path="tattabio/convergent_enzymes", + revision="main", + ) + ], + primary_metric_id="f1", + ) + + def run(self, model: BioSeqTransformer) -> TaskResult: + return run_classification_task(model, self.metadata) + + +def run_mibig_task(model: BioSeqTransformer, metadata: TaskMetadata) -> TaskResult: + """ + Evaluate on MIBIG classification tasks. Multiclass, multi-label KNN classification is used for evaluation. + """ + ds = metadata.datasets[0].load() + if metadata.modality == Modality.DNA: + # MIBiG DNA sequences can be very long. Instead of truncating to max_seq_length, + # split into multiple sequences and mean pool the resulting embeddings. + ds = split_sequences(ds, model.max_seq_length) + + layer_results = defaultdict(dict) + train_embeds = model.encode(ds["train"]["Sequence"]) + test_embeds = model.encode(ds["test"]["Sequence"]) + + train_ids = ds["train"]["Entry"] + test_ids = ds["test"]["Entry"] + train_labels = ds["train"]["class"] + test_labels = ds["test"]["class"] + train_id_to_label = {id: label for id, label in zip(train_ids, train_labels)} + test_id_to_label = {id: label for id, label in zip(test_ids, test_labels)} + # Mean pool embeds with the same ID. + train_ids, train_embeds = merge_split_elem_embeds(train_ids, train_embeds) + test_ids, test_embeds = merge_split_elem_embeds(test_ids, test_embeds) + # Gather the labels after merging by unique ID. + train_labels = np.array([train_id_to_label[id] for id in train_ids]) + test_labels = np.array([test_id_to_label[id] for id in test_ids]) + + for i, layer in enumerate(model.layers): + evaluator = MultiClassMultiOutputKNNClassificationEvaluator( + train_embeds[:, i], train_labels, test_embeds[:, i], test_labels + ) + layer_results["layers"][layer] = evaluator() + logger.info( + f"Layer: {layer}, MIBiG classification results: {layer_results['layers'][layer]}" + ) + return TaskResult.from_dict(metadata, layer_results, model.metadata) + + +class MIBiGProteinClassification(Task): + metadata = TaskMetadata( + id="MIBIG_protein_classification", + display_name="MIBiG Classification", + description="Biosynthetic Gene cluster classification using protein sequences on MIBIG dataset.", + type="classification", + modality=Modality.PROTEIN, + datasets=[Dataset(path="tattabio/mibig_classification_prot", revision="main")], + primary_metric_id="f1", + ) + + def run(self, model: BioSeqTransformer) -> TaskResult: + return run_mibig_task(model, self.metadata) + + +class MIBiGDNAClassification(Task): + metadata = TaskMetadata( + id="MIBIG_dna_classification", + display_name="MIBiG Classification", + description="Biosynthetic Gene cluster classification using DNA sequences on MIBIG dataset.", + type="classification", + modality=Modality.DNA, + datasets=[ + Dataset( + path="tattabio/mibig_classification_dna", + revision="main", + ) + ], + primary_metric_id="f1", + ) + + def run(self, model: BioSeqTransformer) -> TaskResult: + return run_mibig_task(model, self.metadata) diff --git a/geb/tasks/clustering_tasks.py b/geb/tasks/clustering_tasks.py new file mode 100644 index 0000000..c12ddad --- /dev/null +++ b/geb/tasks/clustering_tasks.py @@ -0,0 +1,70 @@ +""" +Biological sequences are clustered and performance is determined by how well clustering matches assigned labels. 
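+Embeddings are clustered per layer and scored with the ClusteringEvaluator (primary metric: V-measure).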
+""" + +import logging +from collections import defaultdict + +from geb.evaluators import ClusteringEvaluator +from geb.modality import Modality +from geb.models import BioSeqTransformer +from geb.tasks.tasks import Dataset, Task, TaskMetadata, TaskResult + +logger = logging.getLogger(__name__) + + +def run_clustering_task(model: BioSeqTransformer, metadata: TaskMetadata) -> TaskResult: + """Evaluate clustering task. Utilizes the ClusteringEvaluator.""" + if len(metadata.datasets) != 1: + raise ValueError("Clustering tasks require 1 dataset.") + ds = metadata.datasets[0].load()["train"] + embeds = model.encode(ds["Sequence"]) + layer_results = defaultdict(dict) + for i, layer in enumerate(model.layers): + labels = ds["Label"] + evaluator = ClusteringEvaluator(embeds[:, i], labels) + layer_results["layers"][layer] = evaluator() + logger.info( + f"Layer: {layer}, {metadata.display_name} results: {layer_results['layers'][layer]}" + ) + return TaskResult.from_dict(metadata, layer_results, model.metadata) + + +class RNAclustering(Task): + metadata = TaskMetadata( + id="ecoli_rna_clustering", + display_name="E.coli RNA Clustering", + description="Evaluate on RNA clustering task for sRNA/tRNA/rRNA segments in E.coli K-12.", + type="clustering", + modality=Modality.DNA, + datasets=[ + Dataset( + path="tattabio/e_coli_rnas", + revision="main", + ) + ], + primary_metric_id="v_measure", + ) + + def run(self, model: BioSeqTransformer) -> TaskResult: + return run_clustering_task(model, self.metadata) + + +class MopBClustering(Task): + metadata = TaskMetadata( + id="mopb_clustering", + display_name="MopB Clustering", + description="Evaluate on MopB clustering task.", + type="clustering", + modality=Modality.PROTEIN, + datasets=[ + Dataset( + path="tattabio/mopb_clustering", + revision="main", + ) + ], + primary_metric_id="v_measure", + ) + + def run(self, model: BioSeqTransformer) -> TaskResult: + return run_clustering_task(model, self.metadata) diff --git a/geb/tasks/eds_tasks.py b/geb/tasks/eds_tasks.py new file mode 100644 index 0000000..849bed9 --- /dev/null +++ b/geb/tasks/eds_tasks.py @@ -0,0 +1,198 @@ +""" +Evolutionary Distance Similarity (EDS) tasks compare embedding distances to continuous evolutionary distances. +The label distances are typically derived from phylogenetic trees. +""" + +import logging +from collections import defaultdict + +import numpy as np +import pandas as pd + +from geb.evaluators import EDSEvaluator +from geb.modality import Modality +from geb.models import BioSeqTransformer +from geb.tasks.tasks import Dataset, Task, TaskMetadata, TaskResult + +logger = logging.getLogger(__name__) + + +def run_eds_task(model: BioSeqTransformer, metadata: TaskMetadata) -> TaskResult: + """Evaluate phylogeny distance correlation task. 
Utilizes the Evolutionary Distance Similarity (EDS) evaluator.""" + if len(metadata.datasets) != 2: + raise ValueError("Phylogeny tasks require 2 datasets: sequences and distances.") + + ds = metadata.datasets[0].load()["train"] + distance_df = metadata.datasets[1].load()["train"].to_pandas() + assert isinstance( + distance_df, pd.DataFrame + ), f"Expected DataFrame, got {type(distance_df)}" + + id_index_dict = {k: i for i, k in enumerate(ds["Entry"])} + distance_df["embeds1"] = None + distance_df["embeds2"] = None + test_embeds = model.encode(ds["Sequence"]) + layer_results = defaultdict(dict) + for i, layer in enumerate(model.layers): + for row_idx, row in distance_df.iterrows(): + id1 = row["ID1"] + id2 = row["ID2"] + embedding1 = test_embeds[id_index_dict[id1], i] + embedding2 = test_embeds[id_index_dict[id2], i] + distance_df.at[row_idx, "embeds1"] = embedding1 + distance_df.at[row_idx, "embeds2"] = embedding2 + embeds1 = np.array(distance_df["embeds1"].tolist()) + embeds2 = np.array(distance_df["embeds2"].tolist()) + dists = np.array(distance_df["distance"].tolist()) + evaluator = EDSEvaluator(embeds1, embeds2, dists) + layer_results["layers"][layer] = evaluator() + # log results + logger.info( + f"Layer: {layer}, {metadata.display_name} distance correlation results: {layer_results['layers'][layer]}" + ) + + return TaskResult.from_dict(metadata, layer_results, model.metadata) + + +class RpobBacPhylogeny(Task): + metadata = TaskMetadata( + id="rpob_bac_phylogeny", + display_name="RpoB Bacterial Phylogeny", + description="Evaluate on RpoB phylogeny distance correlation task for Bacterial sequences.", + type="eds", + modality=Modality.PROTEIN, + datasets=[ + Dataset( + path="tattabio/rpob_bac_phylogeny_sequences", + revision="main", + ), + Dataset( + path="tattabio/rpob_bac_phylogeny_distances", + revision="main", + ), + ], + primary_metric_id="top_corr", + ) + + def run(self, model: BioSeqTransformer) -> TaskResult: + return run_eds_task(model, self.metadata) + + +class RpobArchPhylogeny(Task): + metadata = TaskMetadata( + id="rpob_arch_phylogeny", + display_name="RpoB Archaeal Phylogeny", + description="Evaluate on RpoB phylogeny distance correlation task for Archaeal sequences.", + type="eds", + modality=Modality.PROTEIN, + datasets=[ + Dataset( + path="tattabio/rpob_arch_phylogeny_sequences", + revision="main", + ), + Dataset( + path="tattabio/rpob_arch_phylogeny_distances", + revision="main", + ), + ], + primary_metric_id="top_corr", + ) + + def run(self, model: BioSeqTransformer) -> TaskResult: + return run_eds_task(model, self.metadata) + + +class FeFePhylogeny(Task): + metadata = TaskMetadata( + id="fefe_phylogeny", + display_name="FeFeHydrogenase Phylogeny", + description="Evaluate on FeFeHydrogenase phylogeny distance correlation task.", + type="eds", + modality=Modality.PROTEIN, + datasets=[ + Dataset( + path="tattabio/fefe_phylogeny_sequences", + revision="main", + ), + Dataset( + path="tattabio/fefe_phylogeny_distances", + revision="main", + ), + ], + primary_metric_id="top_corr", + ) + + def run(self, model: BioSeqTransformer) -> TaskResult: + return run_eds_task(model, self.metadata) + + +class Bac16SPhylogeny(Task): + metadata = TaskMetadata( + id="bac_16S_phylogeny", + display_name="16S Bacterial Phylogeny", + description="Evaluate on 16S Bacterial phylogeny distance correlation task.", + type="eds", + modality=Modality.DNA, + datasets=[ + Dataset( + path="tattabio/bac_16S_sequences", + revision="main", + ), + Dataset( + path="tattabio/bac_16S_distances", + 
revision="main", + ), + ], + primary_metric_id="top_corr", + ) + + def run(self, model: BioSeqTransformer) -> TaskResult: + return run_eds_task(model, self.metadata) + + +class Arch16SPhylogeny(Task): + metadata = TaskMetadata( + id="arch_16S_phylogeny", + display_name="16S Archaeal Phylogeny", + description="Evaluate on 16S Archaeal phylogeny distance correlation task.", + type="eds", + modality=Modality.DNA, + datasets=[ + Dataset( + path="tattabio/arch_16S_sequences", + revision="main", + ), + Dataset( + path="tattabio/arch_16S_distances", + revision="main", + ), + ], + primary_metric_id="top_corr", + ) + + def run(self, model: BioSeqTransformer) -> TaskResult: + return run_eds_task(model, self.metadata) + + +class Euk18SPhylogeny(Task): + metadata = TaskMetadata( + id="euk_18S_phylogeny", + display_name="18S Eukaryotic Phylogeny", + description="Evaluate on 18S Eukaryotic phylogeny distance correlation task.", + type="eds", + modality=Modality.DNA, + datasets=[ + Dataset( + path="tattabio/euk_18S_sequences", + revision="main", + ), + Dataset( + path="tattabio/euk_18S_distances", + revision="main", + ), + ], + primary_metric_id="top_corr", + ) + + def run(self, model: BioSeqTransformer) -> TaskResult: + return run_eds_task(model, self.metadata) diff --git a/geb/tasks/pair_classification_tasks.py b/geb/tasks/pair_classification_tasks.py new file mode 100644 index 0000000..316b201 --- /dev/null +++ b/geb/tasks/pair_classification_tasks.py @@ -0,0 +1,96 @@ +""" +Pair classification tasks evaluating distances between functionally relevant gene pairs. +For instance, distance thresholds distinguish between co-transcribed and non-co-transcribed gene pairs. +""" + +import logging +from collections import defaultdict + +from geb.evaluators import PairClassificationEvaluator +from geb.modality import Modality +from geb.models import BioSeqTransformer +from geb.tasks.tasks import Dataset, Task, TaskMetadata, TaskResult + +from ..eval_utils import paired_dataset + +logger = logging.getLogger(__name__) + + +def run_pair_classification_task( + model: BioSeqTransformer, metadata: TaskMetadata +) -> TaskResult: + """Evaluate pair classification task. 
Utilizes the PairClassificationEvaluator.""" + if len(metadata.datasets) != 1: + raise ValueError("Pair classification tasks require 1 dataset.") + ds = metadata.datasets[0].load()["train"] + embeds = model.encode(ds["Sequence"]) + layer_results = defaultdict(dict) + for i, layer in enumerate(model.layers): + labels = ds["Label"] + embeds1, embeds2, labels = paired_dataset(labels, embeds[:, i]) + evaluator = PairClassificationEvaluator(embeds1, embeds2, labels) + layer_results["layers"][layer] = evaluator() + logger.info( + f"Layer: {layer}, {metadata.display_name} classification results: {layer_results['layers'][layer]}" + ) + return TaskResult.from_dict(metadata, layer_results, model.metadata) + + +class EcoliOperon(Task): + metadata = TaskMetadata( + id="ecoli_operonic_pair", + display_name="E.coli Operonic Pair", + description="Evaluate on E.coli K-12 operonic pair classification task.", + type="pair_classification", + modality=Modality.PROTEIN, + datasets=[ + Dataset( + path="tattabio/ecoli_operonic_pair", + revision="main", + ) + ], + primary_metric_id="top_ap", + ) + + def run(self, model: BioSeqTransformer) -> TaskResult: + return run_pair_classification_task(model, self.metadata) + + +class CyanoOperonPair(Task): + metadata = TaskMetadata( + id="cyano_operonic_pair", + display_name="Cyano Operonic Pair", + description="Evaluate on Cyano operonic pair classification task.", + type="pair_classification", + modality=Modality.PROTEIN, + datasets=[ + Dataset( + path="tattabio/cyano_operonic_pair", + revision="main", + ) + ], + primary_metric_id="top_ap", + ) + + def run(self, model: BioSeqTransformer) -> TaskResult: + return run_pair_classification_task(model, self.metadata) + + +class VibrioOperonPair(Task): + metadata = TaskMetadata( + id="vibrio_operonic_pair", + display_name="Vibrio Operonic Pair", + description="Evaluate on Vibrio operonic pair classification task.", + type="pair_classification", + modality=Modality.PROTEIN, + datasets=[ + Dataset( + path="tattabio/vibrio_operonic_pair", + revision="main", + ) + ], + primary_metric_id="top_ap", + ) + + def run(self, model: BioSeqTransformer) -> TaskResult: + return run_pair_classification_task(model, self.metadata) diff --git a/geb/tasks/retrieval_tasks.py b/geb/tasks/retrieval_tasks.py new file mode 100644 index 0000000..132754e --- /dev/null +++ b/geb/tasks/retrieval_tasks.py @@ -0,0 +1,98 @@ +""" +Retrieval tasks find functionally relevant genes in a corpus of genes based on a query gene. +Typically corpus is derived from a different phylogenetic group than the query genes. +""" + +import logging +from collections import defaultdict + +from geb.evaluators import RetrievalEvaluator +from geb.modality import Modality +from geb.models import BioSeqTransformer +from geb.tasks.tasks import Dataset, Task, TaskMetadata, TaskResult + +logger = logging.getLogger(__name__) + + +def run_retrieval_task(model: BioSeqTransformer, metadata: TaskMetadata) -> TaskResult: + """Evaluate retrieval task. 
Utilizes the Retrieval evaluator."""
+    if len(metadata.datasets) != 2:
+        raise ValueError("Retrieval tasks require 2 datasets: corpus/query and qrels.")
+    corpus_ds = metadata.datasets[0].load()["train"]
+    query_ds = metadata.datasets[0].load()["test"]
+    qrels = metadata.datasets[1].load()
+    corpus_embeds = model.encode(corpus_ds["Sequence"])
+    query_embeds = model.encode(query_ds["Sequence"])
+    qrels_dict = defaultdict(dict)
+
+    def qrels_dict_init(row):
+        qrels_dict[str(row["query_id"])][str(row["corpus_id"])] = int(row["fuzz_ratio"])
+
+    # Populate `qrels_dict` from the dataset.
+    # See https://github.com/cvangysel/pytrec_eval for qrels format.
+    qrels.map(qrels_dict_init)
+    qrels = qrels_dict
+    layer_results = defaultdict(dict)
+    for i, layer in enumerate(model.layers):
+        evaluator = RetrievalEvaluator(
+            corpus_embeds[:, i],
+            query_embeds[:, i],
+            corpus_ds["Entry"],
+            query_ds["Entry"],
+            qrels,
+        )
+        layer_results["layers"][layer] = evaluator()
+        logger.info(
+            f"Layer: {layer}, Retrieval results: {layer_results['layers'][layer]}"
+        )
+    return TaskResult.from_dict(metadata, layer_results, model.metadata)
+
+
+class ArchRetrieval(Task):
+    metadata = TaskMetadata(
+        id="arch_retrieval",
+        display_name="Arch Retrieval",
+        description="Retrieves bacterial proteins with similar SwissProt annotations to a query archaeal protein.",
+        type="retrieval",
+        modality=Modality.PROTEIN,
+        datasets=[
+            Dataset(
+                path="tattabio/arch_retrieval",
+                revision="main",
+            ),
+            Dataset(
+                path="tattabio/arch_retrieval_qrels",
+                description="Relevance between query and corpus proteins",
+                revision="main",
+            ),
+        ],
+        primary_metric_id="map_at_5",
+    )
+
+    def run(self, model: BioSeqTransformer) -> TaskResult:
+        return run_retrieval_task(model, self.metadata)
+
+
+class EukRetrieval(Task):
+    metadata = TaskMetadata(
+        id="euk_retrieval",
+        display_name="Euk Retrieval",
+        description="Retrieves bacterial proteins with similar SwissProt annotations to a query eukaryotic protein.",
+        type="retrieval",
+        modality=Modality.PROTEIN,
+        datasets=[
+            Dataset(
+                path="tattabio/euk_retrieval",
+                revision="main",
+            ),
+            Dataset(
+                path="tattabio/euk_retrieval_qrels",
+                description="Relevance between query and corpus proteins",
+                revision="main",
+            ),
+        ],
+        primary_metric_id="map_at_5",
+    )
+
+    def run(self, model: BioSeqTransformer) -> TaskResult:
+        return run_retrieval_task(model, self.metadata)
diff --git a/geb/tasks/tasks.py b/geb/tasks/tasks.py
new file mode 100644
index 0000000..6f7dc76
--- /dev/null
+++ b/geb/tasks/tasks.py
@@ -0,0 +1,136 @@
+"""Task functions for evaluation.
+# TODO: Add dataset revisions.
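+Defines the Task, TaskMetadata, and TaskResult models shared by all benchmark tasks.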
+""" + +import logging +from abc import ABC, abstractmethod +from typing import Any, Dict, List, Literal, Optional + +import datasets +from pydantic import BaseModel, model_validator + +from ..modality import Modality +from ..models import BioSeqTransformer + +logging.basicConfig(level=logging.INFO) + +TaskType = Literal[ + "classification", + "pair_classification", + "clustering", + "eds", + "bigene_mining", + "retrieval", +] + + +class TaskMetric(BaseModel): + id: str + display_name: str + description: Optional[str] = None + value: float = 0.0 + + +class LayerResult(BaseModel): + layer_number: int + layer_display_name: str + metrics: List[TaskMetric] + + +class TaskResult(BaseModel): + task: "TaskMetadata" + # TODO: Convert model to ModelMetadata + model: Dict[str, Any] + results: List[LayerResult] + + @model_validator(mode="after") + def check_valid_primary_metric(self): + for result in self.results: + if all( + metric.id != self.task.primary_metric_id for metric in result.metrics + ): + raise ValueError( + f"Primary metric {self.task.primary_metric_id} not found in results.metrics" + ) + return self + + @staticmethod + def from_dict( + task_metadata: "TaskMetadata", + layer_results: Dict[str, Any], + model_metadata: Dict[str, Any], + ): + return TaskResult( + task=task_metadata, + model=model_metadata, + results=list( + LayerResult( + layer_number=int(layer), + layer_display_name=str(layer), + metrics=[ + TaskMetric(id=metric, display_name=metric, value=value) + for metric, value in metrics.items() + ], + ) + for layer, metrics in layer_results["layers"].items() + ), + ) + + +class Dataset(BaseModel): + path: str + revision: str + + def load(self) -> datasets.DatasetDict: + ds = datasets.load_dataset(self.path, revision=self.revision) + if not isinstance(ds, datasets.DatasetDict): + raise ValueError( + f"Dataset {self.path} is not a datasets.DatasetDict object." + ) + return ds + + +class TaskMetadata(BaseModel): + id: str + display_name: str + description: str + modality: Modality + type: TaskType + # List of datasets used by the task. + # Each dataset is a dict of all arguments to pass to `datasets.load_dataset()`. + datasets: List[Dataset] + primary_metric_id: str + + +class Task(ABC): + metadata: TaskMetadata + + @abstractmethod + def run( + self, model: BioSeqTransformer, layers: Optional[List[int]] = None + ) -> TaskResult: + pass + + +class noop(Task): + metadata = TaskMetadata( + id="noop", + display_name="NoOp Task", + description="This task is used for testing and does nothing.", + type="classification", + modality=Modality.PROTEIN, + datasets=[ + Dataset( + path="", + revision="main", + ) + ], + primary_metric_id="f1", + ) + + def run(self, model: BioSeqTransformer) -> TaskResult: + return TaskResult.from_dict( + self.metadata, + {"layers": {32: {"accuracy": 0.5, "f1": 0.5}}}, + model.metadata, + ) diff --git a/plot_benchmarks.py b/plot_benchmarks.py new file mode 100644 index 0000000..1d75ceb --- /dev/null +++ b/plot_benchmarks.py @@ -0,0 +1,152 @@ +""" +Given a directory of results, plot the benchmarks for each task as a bar chart and line chart. 
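+Example: `python plot_benchmarks.py -d results -t ec_classification,mopb_clustering -o benchmarks.png`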
+""" + +import argparse +import os +from typing import Optional + +import matplotlib.pyplot as plt +import pandas as pd +import seaborn as sns + +from geb.geb import get_all_tasks, get_output_folder, get_tasks_by_name +from geb.tasks.tasks import TaskResult + +ALL_TASKS = [task.metadata.id for task in get_all_tasks()] + + +def plot_benchmarks( + results_dir, + task_ids: Optional[list[str]] = None, + output="benchmarks.png", + model_substring=None, +): + models = os.listdir(results_dir) + all_results = [] + tasks = get_all_tasks() if task_ids is None else get_tasks_by_name(task_ids) + for model_name in models: + if model_substring is not None and all( + substr not in model_name for substr in model_substring + ): + continue + + for task in tasks: + if task.metadata.display_name == "NoOp Task": + continue + filepath = get_output_folder(model_name, task, results_dir, create=False) + # if the file does not exist, skip + if not os.path.exists(filepath): + continue + + with open(filepath) as f: + task_result = TaskResult.model_validate_json(f.read()) + num_params = task_result.model["num_params"] + primary_metric_id = task_result.task.primary_metric_id + main_scores = [ + metric.value + for layer_result in task_result.results + for metric in layer_result.metrics + if metric.id == primary_metric_id + ] + best_score = max(main_scores) + all_results.append( + { + "task": task.metadata.display_name, + "model": model_name, + "num_params": num_params, + "score": best_score, + } + ) + + results_df = pd.DataFrame(all_results) + # order the models by ascending number of parameters + results_df["num_params"] = results_df["num_params"].astype(int) + results_df = results_df.sort_values(by="num_params") + # number of tasks + n_tasks = len(set(results_df["task"])) + + _, ax = plt.subplots(2, n_tasks, figsize=(5 * n_tasks, 10)) + + for i, task in enumerate(set(results_df["task"])): + if n_tasks > 1: + sns.barplot( + x="model", + y="score", + data=results_df[results_df["task"] == task], + ax=ax[0][i], + ) + ax[0][i].set_title(task) + # rotate the x axis labels + + for tick in ax[0][i].get_xticklabels(): + tick.set_rotation(90) + else: + sns.barplot( + x="model", + y="score", + data=results_df[results_df["task"] == task], + ax=ax[0], + ) + ax[0].set_title(task) + # rotate the x axis labels + for tick in ax[0].get_xticklabels(): + tick.set_rotation(90) + + # make a line graph with number of parameters on x axis for each task in the second row of figures + for i, task in enumerate(set(results_df["task"])): + if n_tasks > 1: + sns.lineplot( + x="num_params", + y="score", + data=results_df[results_df["task"] == task], + ax=ax[1][i], + ) + ax[1][i].set_title(task) + ax[1][i].set_xlabel("Number of parameters") + else: + sns.lineplot( + x="num_params", + y="score", + data=results_df[results_df["task"] == task], + ax=ax[1], + ) + ax[1].set_title(task) + ax[1].set_xlabel("Number of parameters") + + plt.tight_layout() + plt.savefig(output) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "-d", + "--results_dir", + type=str, + default="results", + help="Directory containing the results of the benchmarking", + ) + parser.add_argument( + "-t", + "--tasks", + type=lambda s: [item for item in s.split(",")], + default=None, + help=f"Comma separated list of tasks to plot. Choose from {ALL_TASKS} or do not specify to plot all tasks. 
", + ) + parser.add_argument( + "-o", + "--output", + type=str, + default="benchmarks.png", + help="Output file for the plot", + ) + parser.add_argument( + "--model_substring", + type=lambda s: [item for item in s.split(",")], + default=None, + help="Comma separated list of model substrings. Only plot results for models containing this substring", + ) + args = parser.parse_args() + + plot_benchmarks(args.results_dir, args.tasks, args.output, args.model_substring) diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..47b60d3 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,113 @@ +[build-system] +requires = ["setuptools>=42", "wheel"] +build-backend = "setuptools.build_meta" + +[project] +name = "dgeb" +version = "0.0.0" +description = "Diverse Genomic Embedding Benchmark" +readme = "README.md" +license = { file = "LICENSE" } +keywords = ["scientific software", "genomic embeddings", "machine learning", "benchmark"] +classifiers = [ + "Development Status :: 2 - Pre-Alpha", + "Environment :: Console", + "Intended Audience :: Developers", + "Intended Audience :: Information Technology", + "Intended Audience :: Science/Research", + "License :: OSI Approved :: Apache Software License", + "Operating System :: OS Independent", + "Programming Language :: Python", +] +dependencies = [ + "datasets>=2.20.0", + "matplotlib>=3.9.0", + "numpy>=2.0.0", + "pandas>=2.2.2", + "pydantic>=2.7.4", + "pytrec_eval>=0.5", + "rich>=13.7.1", + "scikit_learn>=1.5.0", + "scipy>=1.13.1", + "seaborn>=0.13.2", + "torch>=2.3.1", + "tqdm>=4.66.4", + "transformers>=4.41.2" +] + +[project.urls] +homepage = "https://github.com/TattaBio/DGEB" +"Huggingface Organization" = "https://huggingface.co/tattabio" +"Source Code" = "https://github.com/TattaBio/DGEB" + +[project.optional-dependencies] +dev = ["ruff>=0.0.254", "pytest", "pytest-xdist"] + +[tool.setuptools.packages.find] +exclude = ["tests", "results"] + +[tool.setuptools.package-data] +"*" = ["*.json"] + +[tool.ruff] +target-version = "py38" +exclude = [ + ".venv", + "build/" +] +line-length = 88 +indent-width = 4 + +[tool.semantic_release] +version_toml = ["pyproject.toml:project.version"] +build_command = "python -m pip install build; python -m build" +commit_message = "{version}\n\nAutomatically generated by python-semantic-release" +logging_use_named_masks = false +major_on_zero = true +allow_zero_version = true +no_git_verify = false +tag_format = "v{version}" + +[tool.semantic_release.branches.main] +match = "(main|master)" +prerelease_token = "rc" +prerelease = false + +[tool.semantic_release.changelog] +template_dir = "templates" +changelog_file = "CHANGELOG.md" +exclude_commit_patterns = [] + +[tool.semantic_release.changelog.environment] +block_start_string = "{%" +block_end_string = "%}" +variable_start_string = "{{" +variable_end_string = "}}" +comment_start_string = "{#" +comment_end_string = "#}" +trim_blocks = false +lstrip_blocks = false +newline_sequence = "\n" +keep_trailing_newline = false +extensions = [] +autoescape = true + +[tool.semantic_release.commit_author] +env = "GIT_COMMIT_AUTHOR" +default = "semantic-release " + +[tool.semantic_release.commit_parser_options] +allowed_tags = ["build", "chore", "ci", "docs", "feat", "fix", "perf", "style", "refactor", "test"] +minor_tags = ["feat"] +patch_tags = ["fix", "perf"] +default_bump_level = 0 + +[tool.semantic_release.remote] +name = "origin" +type = "github" +ignore_token_for_push = false +insecure = false + +[tool.semantic_release.publish] +dist_glob_patterns = ["dist/*"] 
+upload_to_vcs_release = true diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..6bf7651 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,13 @@ +datasets>=2.20.0 +matplotlib>=3.9.0 +numpy>=2.0.0 +pandas>=2.2.2 +pydantic>=2.7.4 +pytrec_eval>=0.5 +rich>=13.7.1 +scikit_learn>=1.5.0 +scipy>=1.13.1 +seaborn>=0.13.2 +torch>=2.3.1 +tqdm>=4.66.4 +transformers>=4.41.2 diff --git a/ruff.toml b/ruff.toml new file mode 100644 index 0000000..5ff69db --- /dev/null +++ b/ruff.toml @@ -0,0 +1,8 @@ +exclude = [ + ".venv", + "build/", +] +# Same as Black. +line-length = 88 +indent-width = 4 + diff --git a/run_geb.py b/run_geb.py new file mode 100644 index 0000000..ffd7767 --- /dev/null +++ b/run_geb.py @@ -0,0 +1,135 @@ +""" +Main command to run genomic embedding benchmarks (GEB) on a model. +example command to run GEB: +python run_geb.py -m facebook/esm2_t6_8M_UR50D +""" + +import argparse +import logging +import os +import geb + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +ALL_TASK_NAMES = geb.get_all_task_names() +ALL_MODEL_NAMES = geb.get_all_model_names() + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument( + "-m", + "--model", + type=str, + default=None, + help=f"Model to evaluate. Choose from {ALL_MODEL_NAMES}", + ) + parser.add_argument( + "-t", + "--tasks", + type=lambda s: [item for item in s.split(",")], + default=None, + help=f"Comma separated tasks to evaluate on. Choose from {ALL_TASK_NAMES} or do not specify to evaluate on all tasks", + ) + parser.add_argument( + "-l", + "--layers", + type=str, + default=None, + help="Layer to evaluate. Comma separated list of integers or 'mid' and 'last'. Default is 'mid,last'", + ) + parser.add_argument( + "--devices", + type=str, + default="0", + help="Comma separated list of GPU device ids to use. Default is 0 (if GPUs are detected).", + ) + parser.add_argument( + "--output_folder", + type=str, + default=None, + help="Output directory for results. Will default to results/model_name if not set.", + ) + parser.add_argument( + "-v", "--verbosity", type=int, default=2, help="Verbosity level" + ) + parser.add_argument( + "-b", "--batch_size", type=int, default=64, help="Batch size for evaluation" + ) + parser.add_argument( + "--max_seq_len", + type=int, + default=1024, + help="Maximum sequence length for model, default is 1024.", + ) + parser.add_argument( + "--pool_type", + type=str, + default="mean", + help="Pooling type for model, choose from mean, max, cls, last. Default is mean.", + ) + + args = parser.parse_args() + + # set logging based on verbosity level + if args.verbosity == 0: + logging.getLogger("geb").setLevel(logging.CRITICAL) + elif args.verbosity == 1: + logging.getLogger("geb").setLevel(logging.WARNING) + elif args.verbosity == 2: + logging.getLogger("geb").setLevel(logging.INFO) + elif args.verbosity == 3: + logging.getLogger("geb").setLevel(logging.DEBUG) + + if args.model is None: + raise ValueError("Please specify a model using the -m or --model argument") + + # make sure that devices are comma separated list of integers + try: + devices = [int(device) for device in args.devices.split(",")] + except ValueError: + raise ValueError("Devices must be comma separated list of integers") + + layers = args.layers + if layers: + if layers not in ["mid", "last"]: + # Layers should be list of integers. 
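+            # e.g. "--layers 8,16" selects hidden states 8 and 16.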
+            try:
+                layers = [int(layer) for layer in layers.split(",")]
+            except ValueError:
+                raise ValueError("Layers must be a comma separated list of integers.")
+
+    model_name = args.model.split("/")[-1]
+    output_folder = args.output_folder
+    if output_folder is None:
+        # Default results root; `GEB.run()` appends the model name.
+        output_folder = "results"
+    # Results are written to {output_folder}/{model_name}/{task_id}.json,
+    # and the folder is created by `GEB.run()` if it does not exist.
+    logger.info(f"Results will be saved to {os.path.join(output_folder, model_name)}")
+
+    # Load the model by name.
+    model = geb.get_model(
+        model_name=args.model,
+        layers=layers,
+        devices=devices,
+        max_seq_length=args.max_seq_len,
+        batch_size=args.batch_size,
+        pool_type=args.pool_type,
+    )
+
+    all_tasks_for_modality = geb.get_tasks_by_modality(model.modality)
+
+    if args.tasks:
+        task_list = geb.get_tasks_by_name(args.tasks)
+        if not all(task.metadata.modality == model.modality for task in task_list):
+            raise ValueError(f"Tasks must match the model modality: {model.modality}")
+    else:
+        task_list = all_tasks_for_modality
+    evaluation = geb.GEB(tasks=task_list)
+    _ = evaluation.run(model, output_folder=output_folder)
+
+
+if __name__ == "__main__":
+    main()