diff --git a/.editorconfig b/.editorconfig
new file mode 100644
index 0000000..efb1e94
--- /dev/null
+++ b/.editorconfig
@@ -0,0 +1,12 @@
+# top-most EditorConfig file
+root = true
+
+# Unix-style newlines with a newline ending every file
+[*]
+end_of_line = lf
+insert_final_newline = true
+
+[*.py]
+charset = utf-8
+indent_style = space
+indent_size = 4
diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
new file mode 100644
index 0000000..bb313d0
--- /dev/null
+++ b/.github/workflows/ci.yml
@@ -0,0 +1,34 @@
+name: CI for GEB
+
+on:
+  push:
+    branches: ["**"]
+  pull_request:
+    branches: ["**"]
+
+permissions:
+  id-token: write
+  contents: read
+  actions: write
+  pull-requests: read
+
+concurrency:
+  group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
+  cancel-in-progress: true
+
+jobs:
+  ruff:
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v3
+      - uses: actions/setup-python@v4
+        with:
+          python-version: "3.11"
+      - uses: yezz123/setup-uv@v4
+        with:
+          uv-venv: ".geb_venv"
+      - run: uv pip install ruff
+      - run: ruff format --check .
+      - run: ruff check .
+      # TODO: pytest
+      # TODO: pyright
diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml
new file mode 100644
index 0000000..f349ef9
--- /dev/null
+++ b/.github/workflows/release.yml
@@ -0,0 +1,51 @@
+# This workflow will
+# - Find the latest version tag based on the commit history
+# - Create a git tag for the new version
+# - Update the version number in pyproject.toml based on the commit history
+# - Upload the package to PyPI
+# - Create a release on GitHub
+
+# This workflow requires the following to be set up:
+# - a GitHub personal access token with the `repo` scope called `RELEASE`
+# - trusted publishing using PyPI, as described here: https://blog.pypi.org/posts/2023-04-20-introducing-trusted-publishers/
+
+name: Release
+on:
+  push:
+    branches:
+      - main
+
+jobs:
+  release:
+    runs-on: ubuntu-latest
+    concurrency: release
+    permissions:
+      id-token: write # IMPORTANT: this permission is mandatory for trusted publishing using PyPI
+      contents: write
+
+    steps:
+      - uses: actions/checkout@v4
+        with:
+          fetch-depth: 0
+
+      - name: Python Semantic Release
+        id: release
+        uses: python-semantic-release/python-semantic-release@v9.8.3
+        with:
+          github_token: ${{ secrets.GITHUB_TOKEN }}
+
+      - name: Publish package distributions to PyPI
+        uses: pypa/gh-action-pypi-publish@v1.9.0
+        if: steps.release.outputs.released == 'true'
+        with:
+          repository-url: https://test.pypi.org/legacy/
+        # This action supports PyPI's trusted publishing implementation, which allows authentication to PyPI without a manually
+        # configured API token or username/password combination. To perform trusted publishing with this action, your project's
+        # publisher must already be configured on PyPI.
+
+      - name: Publish package distributions to GitHub Releases
+        uses: python-semantic-release/upload-to-gh-release@v9.8.3
+        if: steps.release.outputs.released == 'true'
+        with:
+          github_token: ${{ secrets.GITHUB_TOKEN }}
+          tag: ${{ steps.release.outputs.tag }}
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..bb077e1
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,6 @@
+.venv/
+__pycache__/
+.vscode/
+build/
+dist/
+*egg-info/
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000..f49a4e1
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,201 @@
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
\ No newline at end of file
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..a77f005
--- /dev/null
+++ b/README.md
@@ -0,0 +1,120 @@
+# Genomic Embedding Benchmark
+
+Installation | Usage | Leaderboard | Documentation | Citing
+
+## Installation
+
+TODO(joshua):
+```bash
+pip install geb
+```
+
+## Usage
+
+- Using the Python script (see [run_geb.py](https://github.com/tattabio/geb/blob/main/run_geb.py)):
+
+```bash
+python run_geb.py --model facebook/esm2_t6_8M_UR50D
+```
+
+
+- Using the Python API:
+
+```py
+import geb
+
+model = geb.get_model("facebook/esm2_t6_8M_UR50D")
+tasks = geb.get_tasks_by_modality(geb.Modality.PROTEIN)
+evaluation = geb.GEB(tasks=tasks)
+evaluation.run(model, output_folder="results")
+```
+
+
+### Using a custom model
+
+Custom models should subclass the `geb.models.BioSeqTransformer` abstract class and specify the modality, number of layers, and embedding dimension. See [models.py](https://github.com/tattabio/geb/blob/main/geb/models.py) for additional examples of custom model loading and inference.
+
+
+```python
+import geb
+from geb.models import BioSeqTransformer
+from geb.modality import Modality
+
+class MyModel(BioSeqTransformer):
+
+    @property
+    def modality(self) -> Modality:
+        return Modality.PROTEIN
+
+    @property
+    def num_layers(self) -> int:
+        return self.config.num_hidden_layers
+
+    @property
+    def embed_dim(self) -> int:
+        return self.config.hidden_size
+
+
+model = MyModel()
+tasks = geb.get_tasks_by_modality(model.modality)
+evaluation = geb.GEB(tasks=tasks)
+evaluation.run(model)
+```
+
+### Evaluating on a custom dataset
+
+TODO(andre): Update this section
+
+To evaluate on a custom dataset, define your own task by subclassing `AbsTask` and implementing `run`:
+
+```python
+from typing import List, Optional
+
+import geb
+from geb.models import BioSeqTransformer
+from geb.tasks import AbsTask
+from geb.tasks.tasks import TaskResult
+
+
+class MyCustomTask(AbsTask):
+    def run(
+        self, model: BioSeqTransformer, layers: Optional[List[int]] = None
+    ) -> TaskResult:
+        pass
+
+model = geb.models.ESM("facebook/esm2_t6_8M_UR50D")
+evaluation = geb.GEB(tasks=[MyCustomTask()])
+evaluation.run(model)
+```
+
+
+
+## Citing
+
+GEB was introduced in "[GEB: Genomic Embedding Benchmark]()". If you use GEB in your work, please cite:
+
+TODO(andre): bibtex
+
+Works that have used GEB for benchmarking can be found on the [leaderboard](https://huggingface.co/spaces/tattabio/GEB/leaderboard).
diff --git a/benchmarks.png b/benchmarks.png
new file mode 100644
index 0000000..e7c6c76
Binary files /dev/null and b/benchmarks.png differ
diff --git a/docs/images/tatta_logo.png b/docs/images/tatta_logo.png
new file mode 100644
index 0000000..76220bd
Binary files /dev/null and b/docs/images/tatta_logo.png differ
diff --git a/geb/__init__.py b/geb/__init__.py
new file mode 100644
index 0000000..a3d4a78
--- /dev/null
+++ b/geb/__init__.py
@@ -0,0 +1,28 @@
+from geb.geb import (
+ GEB,
+ get_all_tasks,
+ get_output_folder,
+ get_all_task_names,
+ get_tasks_by_name,
+ get_tasks_by_modality,
+ get_all_model_names,
+ get_model,
+)
+from geb.tasks.tasks import TaskResult
+from geb.modality import Modality
+
+# Importing without setting `__all__` produces a Ruff F401 error:
+# "imported but unused; consider removing, adding to __all__, or using a redundant alias"
+# See https://docs.astral.sh/ruff/rules/unused-import/#why-is-this-bad
+__all__ = [
+ "GEB",
+ "get_all_tasks",
+ "get_all_task_names",
+ "get_tasks_by_name",
+ "get_tasks_by_modality",
+ "get_all_model_names",
+ "get_model",
+ "get_output_folder",
+ "TaskResult",
+ "Modality",
+]
diff --git a/geb/eval_utils.py b/geb/eval_utils.py
new file mode 100644
index 0000000..7b5f630
--- /dev/null
+++ b/geb/eval_utils.py
@@ -0,0 +1,394 @@
+"""Utility functions for evaluation."""
+
+from typing import Any, Dict, List
+import json
+import torch
+import random
+import numpy as np
+from sklearn.metrics import auc
+
+
+class ForwardHook:
+ """Pytorch forward hook class to store outputs of intermediate layers."""
+
+ def __init__(self, module: torch.nn.Module):
+ self.hook = module.register_forward_hook(self.hook_fn)
+ self.output = None
+
+ def hook_fn(self, module, input, output):
+ self.output = output
+
+ def close(self):
+ self.hook.remove()
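+
+
+# A minimal usage sketch for ForwardHook; the module path below is hypothetical:
+#
+#   hook = ForwardHook(model.encoder.layers[3])
+#   _ = model(input_ids)   # the forward pass populates hook.output
+#   hidden = hook.output   # activations of the hooked layer
+#   hook.close()           # remove the hook when done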
+
+
+def pool(
+ last_hidden_states: torch.Tensor, attention_mask: torch.Tensor, pool_type: str
+) -> torch.Tensor:
+ """Pool embeddings across the sequence length dimension."""
+ assert (
+ last_hidden_states.ndim == 3
+ ), f"Expected hidden_states to have shape [batch, seq_len, D], got shape: {last_hidden_states.shape}"
+ assert (
+ attention_mask.ndim == 2
+ ), f"Expected attention_mask to have shape [batch, seq_len], got shape: {attention_mask.shape}"
+ last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
+ if pool_type == "mean":
+ emb = last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
+ elif pool_type == "max":
+ emb = last_hidden.max(dim=1)[0]
+ elif pool_type == "cls":
+ emb = last_hidden[:, 0]
+ elif pool_type == "last":
+ emb = last_hidden[torch.arange(last_hidden.size(0)), attention_mask.sum(1) - 1]
+ else:
+ raise ValueError(f"pool_type {pool_type} not supported")
+ return emb
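+
+
+# Sketch of pool() on toy tensors (values are illustrative; only shapes matter):
+#
+#   hidden = torch.randn(2, 4, 8)                       # [batch, seq_len, D]
+#   mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])   # [batch, seq_len]
+#   emb = pool(hidden, mask, pool_type="mean")          # -> shape [2, 8]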
+
+
+def set_all_seeds(seed):
+ random.seed(seed)
+ np.random.seed(seed)
+ torch.manual_seed(seed)
+ torch.cuda.manual_seed(seed)
+ torch.backends.cudnn.deterministic = True
+
+
+def write_results_to_json(results: Dict[str, Any], results_path: str):
+ """Write results dict to a json file."""
+ with open(results_path, "w") as f:
+ json.dump(results, f, indent=4)
+
+
+def merge_split_elem_embeds(ids, embeds, preserve_order: bool = False):
+ """Merge embeddings with the same id by mean-pooling and optionally preserve order in which they appear.
+
+ Args:
+ ids: Array of string ids, [batch].
+ embeds: Array of embeddings, [batch, ...].
+
+ Returns:
+ ids: Unique ids, [unique_batch].
+ embeds: Array of embeddings, [unique_batch, ...].
+ """
+ unique_ids, indices = np.unique(ids, return_inverse=True)
+ shape_no_batch = embeds.shape[1:]
+ sums = np.zeros([unique_ids.size, *shape_no_batch], dtype=embeds.dtype)
+ counts = np.bincount(indices, minlength=unique_ids.size)
+ np.add.at(sums, indices, embeds)
+ # Add trailing dimensions to counts.
+ counts = counts[(...,) + (None,) * len(shape_no_batch)]
+ mean_pooled = sums / counts
+ # Preserve the order of the input ids.
+ if preserve_order:
+ order = []
+ for id in unique_ids:
+ idx = np.where(ids == id)[0][0]
+ order.append(idx)
+ re_order = np.argsort(order)
+ unique_ids = unique_ids[re_order]
+ mean_pooled = mean_pooled[re_order]
+ return unique_ids, mean_pooled
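+
+
+# Toy example: the two chunks of id "a" are mean-pooled into a single row.
+#
+#   ids = np.array(["a", "b", "a"])
+#   embeds = np.array([[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]])
+#   merge_split_elem_embeds(ids, embeds)
+#   # -> (array(['a', 'b'], ...), array([[2., 2.], [2., 2.]]))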
+
+
+def paired_dataset(labels, embeds):
+ """Creates a paired dataset for consecutive operonic gene pairs."""
+ embeds1 = embeds[:-1]
+ embeds2 = embeds[1:]
+ labels = labels[:-1]
+ return embeds1, embeds2, labels
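+
+
+# Sketch: for embeds [e0, e1, e2], paired_dataset yields the pairs (e0, e1) and
+# (e1, e2) with labels[0] and labels[1], i.e. each gene paired with its successor.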
+
+
+def cos_sim(a, b):
+ """Computes the cosine similarity cos_sim(a[i], b[j]) for all i and j.
+
+ Return:
+ Matrix with res[i][j] = cos_sim(a[i], b[j])
+ """ # noqa: D402
+ if not isinstance(a, torch.Tensor):
+ a = torch.tensor(a)
+
+ if not isinstance(b, torch.Tensor):
+ b = torch.tensor(b)
+
+ if len(a.shape) == 1:
+ a = a.unsqueeze(0)
+
+ if len(b.shape) == 1:
+ b = b.unsqueeze(0)
+
+ a_norm = torch.nn.functional.normalize(a, p=2, dim=1)
+ b_norm = torch.nn.functional.normalize(b, p=2, dim=1)
+ return torch.mm(a_norm, b_norm.transpose(0, 1))
+
+
+def dot_score(a: torch.Tensor, b: torch.Tensor):
+ """Computes the dot-product dot_prod(a[i], b[j]) for all i and j.
+ :return: Matrix with res[i][j] = dot_prod(a[i], b[j])
+ """
+ if not isinstance(a, torch.Tensor):
+ a = torch.tensor(a)
+
+ if not isinstance(b, torch.Tensor):
+ b = torch.tensor(b)
+
+ if len(a.shape) == 1:
+ a = a.unsqueeze(0)
+
+ if len(b.shape) == 1:
+ b = b.unsqueeze(0)
+
+ return torch.mm(a, b.transpose(0, 1))
+
+
+# From https://github.com/beir-cellar/beir/blob/f062f038c4bfd19a8ca942a9910b1e0d218759d4/beir/retrieval/custom_metrics.py#L4
+def mrr(
+ qrels: dict[str, dict[str, int]],
+ results: dict[str, dict[str, float]],
+ k_values: List[int],
+ output_type: str = "mean",
+) -> Dict[str, float]:
+ MRR = {}
+
+ for k in k_values:
+ MRR[f"MRR@{k}"] = []
+
+ k_max, top_hits = max(k_values), {}
+
+ for query_id, doc_scores in results.items():
+ top_hits[query_id] = sorted(
+ doc_scores.items(), key=lambda item: item[1], reverse=True
+ )[0:k_max]
+
+ for query_id in top_hits:
+ query_relevant_docs = set(
+ [doc_id for doc_id in qrels[query_id] if qrels[query_id][doc_id] > 0]
+ )
+ for k in k_values:
+ rr = 0
+ for rank, hit in enumerate(top_hits[query_id][0:k]):
+ if hit[0] in query_relevant_docs:
+ rr = 1.0 / (rank + 1)
+ break
+ MRR[f"MRR@{k}"].append(rr)
+
+ if output_type == "mean":
+ for k in k_values:
+ MRR[f"MRR@{k}"] = round(sum(MRR[f"MRR@{k}"]) / len(qrels), 5)
+
+ elif output_type == "all":
+ pass
+
+ return MRR
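+
+
+# Expected input format, as a minimal sketch (ids and scores are illustrative):
+#
+#   qrels = {"q1": {"d1": 1, "d2": 0}}        # query id -> doc id -> relevance
+#   results = {"q1": {"d1": 0.9, "d3": 0.4}}  # query id -> doc id -> score
+#   mrr(qrels, results, k_values=[1, 5])      # -> {"MRR@1": 1.0, "MRR@5": 1.0}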
+
+
+# From https://github.com/embeddings-benchmark/mteb/blob/8178981fd8fcd546d7031afe61a083d13c41520f/mteb/evaluation/evaluators/utils.py
+def recall_cap(
+ qrels: dict[str, dict[str, int]],
+ results: dict[str, dict[str, float]],
+ k_values: List[int],
+ output_type: str = "mean",
+) -> Dict[str, float]:
+ capped_recall = {}
+
+ for k in k_values:
+ capped_recall[f"R_cap@{k}"] = []
+
+ k_max = max(k_values)
+
+ for query_id, doc_scores in results.items():
+ top_hits = sorted(doc_scores.items(), key=lambda item: item[1], reverse=True)[
+ 0:k_max
+ ]
+ query_relevant_docs = [
+ doc_id for doc_id in qrels[query_id] if qrels[query_id][doc_id] > 0
+ ]
+ for k in k_values:
+ retrieved_docs = [
+ row[0] for row in top_hits[0:k] if qrels[query_id].get(row[0], 0) > 0
+ ]
+ denominator = min(len(query_relevant_docs), k)
+ capped_recall[f"R_cap@{k}"].append(len(retrieved_docs) / denominator)
+
+ if output_type == "mean":
+ for k in k_values:
+ capped_recall[f"R_cap@{k}"] = round(
+ sum(capped_recall[f"R_cap@{k}"]) / len(qrels), 5
+ )
+
+ elif output_type == "all":
+ pass
+
+ return capped_recall
+
+
+# From https://github.com/embeddings-benchmark/mteb/blob/8178981fd8fcd546d7031afe61a083d13c41520f/mteb/evaluation/evaluators/utils.py
+def hole(
+ qrels: dict[str, dict[str, int]],
+ results: dict[str, dict[str, float]],
+ k_values: List[int],
+ output_type: str = "mean",
+) -> Dict[str, float]:
+ Hole = {}
+
+ for k in k_values:
+ Hole[f"Hole@{k}"] = []
+
+ annotated_corpus = set()
+ for _, docs in qrels.items():
+ for doc_id, score in docs.items():
+ annotated_corpus.add(doc_id)
+
+ k_max = max(k_values)
+
+ for _, scores in results.items():
+ top_hits = sorted(scores.items(), key=lambda item: item[1], reverse=True)[
+ 0:k_max
+ ]
+ for k in k_values:
+ hole_docs = [
+ row[0] for row in top_hits[0:k] if row[0] not in annotated_corpus
+ ]
+ Hole[f"Hole@{k}"].append(len(hole_docs) / k)
+
+ if output_type == "mean":
+ for k in k_values:
+ Hole[f"Hole@{k}"] = round(Hole[f"Hole@{k}"] / len(qrels), 5)
+
+ elif output_type == "all":
+ pass
+
+ return Hole
+
+
+# From https://github.com/embeddings-benchmark/mteb/blob/8178981fd8fcd546d7031afe61a083d13c41520f/mteb/evaluation/evaluators/utils.py
+def top_k_accuracy(
+ qrels: dict[str, dict[str, int]],
+ results: dict[str, dict[str, float]],
+ k_values: List[int],
+ output_type: str = "mean",
+) -> Dict[str, float]:
+ top_k_acc = {}
+
+ for k in k_values:
+ top_k_acc[f"Accuracy@{k}"] = []
+
+ k_max, top_hits = max(k_values), {}
+
+ for query_id, doc_scores in results.items():
+ top_hits[query_id] = [
+ item[0]
+ for item in sorted(
+ doc_scores.items(), key=lambda item: item[1], reverse=True
+ )[0:k_max]
+ ]
+
+ for query_id in top_hits:
+ query_relevant_docs = set(
+ [doc_id for doc_id in qrels[query_id] if qrels[query_id][doc_id] > 0]
+ )
+ for k in k_values:
+ for relevant_doc_id in query_relevant_docs:
+ if relevant_doc_id in top_hits[query_id][0:k]:
+ top_k_acc[f"Accuracy@{k}"].append(1.0)
+ break
+
+ if output_type == "mean":
+ for k in k_values:
+ top_k_acc[f"Accuracy@{k}"] = round(
+ top_k_acc[f"Accuracy@{k}"] / len(qrels), 5
+ )
+
+ elif output_type == "all":
+ pass
+
+ return top_k_acc
+
+
+# From https://github.com/embeddings-benchmark/mteb/blob/8178981fd8fcd546d7031afe61a083d13c41520f/mteb/evaluation/evaluators/utils.py
+def confidence_scores(sim_scores: List[float]) -> Dict[str, float]:
+ """Computes confidence scores for a single instance = (query, positives, negatives)
+
+ Args:
+ sim_scores: Query-documents similarity scores with length `num_pos+num_neg`
+
+ Returns:
+ conf_scores:
+ - `max`: Maximum similarity score
+ - `std`: Standard deviation of similarity scores
+ - `diff1`: Difference between highest and second highest similarity scores
+ """
+ sim_scores_sorted = sorted(sim_scores)[::-1]
+
+ cs_max = sim_scores_sorted[0]
+ cs_std = np.std(sim_scores)
+ if len(sim_scores) > 1:
+ cs_diff1 = sim_scores_sorted[0] - sim_scores_sorted[1]
+ elif len(sim_scores) == 1:
+ cs_diff1 = 0.0
+
+ conf_scores = {"max": cs_max, "std": cs_std, "diff1": cs_diff1}
+
+ return conf_scores
+
+
+# From https://github.com/embeddings-benchmark/mteb/blob/8178981fd8fcd546d7031afe61a083d13c41520f/mteb/evaluation/evaluators/utils.py
+def nAUC(
+ conf_scores: np.ndarray,
+ metrics: np.ndarray,
+ abstention_rates: np.ndarray = np.linspace(0, 1, 11)[:-1],
+) -> float:
+ """Computes normalized Area Under the Curve on a set of evaluated instances as presented in the paper https://arxiv.org/abs/2402.12997
+ 1/ Computes the raw abstention curve, i.e., the average evaluation metric at different abstention rates determined by the confidence scores
+ 2/ Computes the oracle abstention curve, i.e., the best theoretical abstention curve (e.g.: at a 10% abstention rate, the oracle abstains on the bottom-10% instances with regard to the evaluation metric)
+    3/ Computes the flat abstention curve, i.e., the one that remains flat for all abstention rates (ineffective abstention)
+ 4/ Computes the area under the three curves
+ 5/ Finally scales the raw AUC between the oracle and the flat AUCs to get normalized AUC
+
+ Args:
+ conf_scores: Instance confidence scores used for abstention thresholding, with shape `(num_test_instances,)`
+ metrics: Metric evaluations at instance-level (e.g.: average precision, NDCG...), with shape `(num_test_instances,)`
+ abstention_rates: Target rates for the computation of the abstention curve
+
+ Returns:
+ abst_nauc: Normalized area under the abstention curve (upper-bounded by 1)
+ """
+
+ def abstention_curve(
+ conf_scores: np.ndarray,
+ metrics: np.ndarray,
+ abstention_rates: np.ndarray = np.linspace(0, 1, 11)[:-1],
+ ) -> np.ndarray:
+ """Computes the raw abstention curve for a given set of evaluated instances and corresponding confidence scores
+
+ Args:
+ conf_scores: Instance confidence scores used for abstention thresholding, with shape `(num_test_instances,)`
+ metrics: Metric evaluations at instance-level (e.g.: average precision, NDCG...), with shape `(num_test_instances,)`
+ abstention_rates: Target rates for the computation of the abstention curve
+
+ Returns:
+ abst_curve: Abstention curve of length `len(abstention_rates)`
+ """
+ conf_scores_argsort = np.argsort(conf_scores)
+ abst_curve = np.zeros(len(abstention_rates))
+
+ for i, rate in enumerate(abstention_rates):
+ num_instances_abst = min(
+ round(rate * len(conf_scores_argsort)), len(conf_scores) - 1
+ )
+ abst_curve[i] = metrics[conf_scores_argsort[num_instances_abst:]].mean()
+
+ return abst_curve
+
+ abst_curve = abstention_curve(conf_scores, metrics, abstention_rates)
+ or_curve = abstention_curve(metrics, metrics, abstention_rates)
+ abst_auc = auc(abstention_rates, abst_curve)
+ or_auc = auc(abstention_rates, or_curve)
+ flat_auc = or_curve[0] * (abstention_rates[-1] - abstention_rates[0])
+
+ if or_auc == flat_auc:
+ abst_nauc = np.nan
+ else:
+ abst_nauc = (abst_auc - flat_auc) / (or_auc - flat_auc)
+
+ return abst_nauc
diff --git a/geb/evaluators.py b/geb/evaluators.py
new file mode 100644
index 0000000..5098970
--- /dev/null
+++ b/geb/evaluators.py
@@ -0,0 +1,839 @@
+"""
+Evaluator objects for different evaluation types.
+"""
+
+import logging
+import random
+from abc import ABC, abstractmethod
+import heapq
+from collections import defaultdict
+import pytrec_eval
+import numpy as np
+import sklearn.cluster
+import sklearn.neighbors
+import torch
+from scipy.stats import pearsonr
+from sklearn.linear_model import LogisticRegression
+from sklearn.metrics import (
+ accuracy_score,
+ average_precision_score,
+ classification_report,
+ f1_score,
+ precision_score,
+ recall_score,
+ label_ranking_average_precision_score,
+)
+from sklearn.metrics.cluster import v_measure_score
+from sklearn.metrics.pairwise import (
+ paired_cosine_distances,
+ paired_euclidean_distances,
+ paired_manhattan_distances,
+)
+from sklearn.multioutput import MultiOutputRegressor
+from sklearn.preprocessing import MultiLabelBinarizer
+from typing import Dict, List, Tuple
+
+from .eval_utils import (
+ cos_sim,
+ dot_score,
+ mrr,
+ recall_cap,
+ hole,
+ confidence_scores,
+ nAUC,
+ top_k_accuracy,
+)
+
+
+class Evaluator(ABC):
+ """Base class for all evaluators
+ Extend this class and implement __call__ for custom evaluators.
+ """
+
+ def __init__(self, seed=42, **kwargs):
+ self.seed = seed
+ random.seed(self.seed)
+ np.random.seed(self.seed)
+ torch.manual_seed(self.seed)
+ torch.cuda.manual_seed_all(self.seed)
+
+ @abstractmethod
+ def __call__(self, model):
+ """This is called during training to evaluate the model.
+ It returns scores.
+
+ Parameters
+ ----------
+ model:
+ the model to evaluate
+ """
+ pass
+
+
+logger = logging.getLogger(__name__)
+
+
+class logRegClassificationEvaluator(Evaluator):
+ def __init__(
+ self,
+ embeds_train,
+ y_train,
+ embeds_test,
+ y_test,
+ max_iter=1000,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ self.embeds_train = embeds_train
+ self.y_train = y_train
+ self.embeds_test = embeds_test
+ self.y_test = y_test
+
+ self.max_iter = max_iter
+
+ def __call__(self):
+ scores = {}
+ clf = LogisticRegression(
+ random_state=self.seed,
+ n_jobs=-1,
+ max_iter=self.max_iter,
+ verbose=1 if logger.isEnabledFor(logging.DEBUG) else 0,
+ )
+ logger.info(f"Encoding {len(self.embeds_train)} training embeds...")
+ X_train = np.asarray(self.embeds_train)
+
+ logger.info(f"Encoding {len(self.embeds_test)} test embeds...")
+ X_test = np.asarray(self.embeds_test)
+ logger.info("Fitting logistic regression classifier...")
+ clf.fit(X_train, self.y_train)
+ logger.info("Evaluating...")
+ y_pred = clf.predict(X_test)
+ accuracy = accuracy_score(self.y_test, y_pred)
+ f1 = f1_score(self.y_test, y_pred, average="macro")
+ scores["accuracy"] = accuracy
+ scores["f1"] = f1
+
+ # if binary classification
+ if len(np.unique(self.y_train)) == 2:
+ ap = average_precision_score(self.y_test, y_pred)
+ scores["ap"] = ap
+
+ return scores
+
+
+class ClusteringEvaluator(Evaluator):
+ def __init__(
+ self,
+ embeds,
+ labels,
+ clustering_batch_size=500,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ self.embeds = embeds
+ self.labels = labels
+ self.clustering_batch_size = clustering_batch_size
+
+ def __call__(self):
+ logger.info(f"Encoding {len(self.embeds)} embeds...")
+ corpus_embeddings = np.asarray(self.embeds)
+
+ logger.info("Fitting Mini-Batch K-Means model...")
+ clustering_model = sklearn.cluster.MiniBatchKMeans(
+ n_clusters=len(set(self.labels)),
+ batch_size=self.clustering_batch_size,
+ n_init="auto",
+ )
+ clustering_model.fit(corpus_embeddings)
+ cluster_assignment = clustering_model.labels_
+
+ logger.info("Evaluating...")
+ v_measure = v_measure_score(self.labels, cluster_assignment)
+
+ return {"v_measure": v_measure}
+
+
+class PairClassificationEvaluator(Evaluator):
+ """Evaluate a model based on the similarity of the embeddings by calculating the accuracy of identifying similar and
+ dissimilar embeds.
+ The metrics are the cosine similarity as well as euclidean and Manhattan distance
+ The returned score is the accuracy with a specified metric.
+ The results are written in a CSV. If a CSV already exists, then values are appended.
+ The labels need to be 0 for dissimilar pairs and 1 for similar pairs.
+ :param embeds1: The first column of embeds
+ :param embeds2: The second column of embeds
+ :param labels: labels[i] is the label for the pair (embeds1[i], embeds2[i]). Must be 0 or 1
+ :param name: Name for the output
+ :param write_csv: Write results to a CSV file
+ """
+
+ def __init__(self, embeds1, embeds2, labels, **kwargs):
+ super().__init__(**kwargs)
+ self.embeds1 = embeds1
+ self.embeds2 = embeds2
+ self.labels = labels
+
+ assert len(self.embeds1) == len(self.embeds2)
+ assert len(self.embeds1) == len(self.labels)
+ for label in labels:
+ assert label == 0 or label == 1
+
+ def __call__(self):
+ scores = self.compute_metrics()
+ # Compute the max of Average Precision (AP) over all distance metrics.
+ top_ap_score = max(score for k, score in scores.items() if k.endswith("_ap"))
+ scores["top_ap"] = top_ap_score
+ return scores
+
+ def compute_metrics(self):
+ embeddings1 = np.array(self.embeds1)
+ embeddings2 = np.array(self.embeds2)
+
+ logger.info("Computing similarity distances...")
+ cosine_scores = 1 - paired_cosine_distances(embeddings1, embeddings2)
+ manhattan_distances = paired_manhattan_distances(embeddings1, embeddings2)
+ euclidean_distances = paired_euclidean_distances(embeddings1, embeddings2)
+
+        dot_scores = [
+            np.dot(embeddings1[i], embeddings2[i]) for i in range(len(embeddings1))
+        ]
+
+ logger.info("Computing metrics...")
+ labels = np.asarray(self.labels)
+ output_scores = {}
+ for short_name, name, scores, reverse in [
+ ["cos_sim", "Cosine-Similarity", cosine_scores, True],
+ ["manhattan", "Manhattan-Distance", manhattan_distances, False],
+ ["euclidean", "Euclidean-Distance", euclidean_distances, False],
+ ["dot", "Dot-Product", dot_scores, True],
+ ]:
+ metrics = self._compute_metrics(scores, labels, reverse)
+ metrics = {short_name + "_" + k: v for k, v in metrics.items()}
+ output_scores.update(metrics)
+
+ return output_scores
+
+ @staticmethod
+ def _compute_metrics(scores, labels, high_score_more_similar):
+ """Compute the metrics for the given scores and labels.
+
+ Args:
+ scores (`np.ndarray` of shape (n_pairs, )): The similarity/dissimilarity scores for the pairs.
+ labels (`np.ndarray` of shape (n_pairs, )): The labels for the pairs.
+ high_score_more_similar (`bool`): If true, then the higher the score, the more similar the pairs are.
+
+ Returns:
+ `dict`: The metrics for the given scores and labels.
+ """
+ acc, acc_threshold = PairClassificationEvaluator.find_best_acc_and_threshold(
+ scores, labels, high_score_more_similar
+ )
+ f1, precision, recall, f1_threshold = (
+ PairClassificationEvaluator.find_best_f1_and_threshold(
+ scores, labels, high_score_more_similar
+ )
+ )
+ ap = PairClassificationEvaluator.ap_score(
+ scores, labels, high_score_more_similar
+ )
+
+ return {
+ "accuracy": acc,
+ "accuracy_threshold": acc_threshold,
+ "f1": f1,
+ "f1_threshold": f1_threshold,
+ "precision": precision,
+ "recall": recall,
+ "ap": ap,
+ }
+
+ @staticmethod
+ def find_best_acc_and_threshold(scores, labels, high_score_more_similar: bool):
+ assert len(scores) == len(labels)
+ rows = list(zip(scores, labels))
+
+ rows = sorted(rows, key=lambda x: x[0], reverse=high_score_more_similar)
+
+ max_acc = 0
+ best_threshold = -1
+
+ positive_so_far = 0
+ remaining_negatives = sum(np.array(labels) == 0)
+
+ for i in range(len(rows) - 1):
+ score, label = rows[i]
+ if label == 1:
+ positive_so_far += 1
+ else:
+ remaining_negatives -= 1
+
+ acc = (positive_so_far + remaining_negatives) / len(labels)
+ if acc > max_acc:
+ max_acc = acc
+ best_threshold = (rows[i][0] + rows[i + 1][0]) / 2
+
+ return max_acc, best_threshold
+
+ @staticmethod
+ def find_best_f1_and_threshold(scores, labels, high_score_more_similar: bool):
+ assert len(scores) == len(labels)
+
+ scores = np.asarray(scores)
+ labels = np.asarray(labels)
+
+ rows = list(zip(scores, labels))
+
+ rows = sorted(rows, key=lambda x: x[0], reverse=high_score_more_similar)
+
+ best_f1 = best_precision = best_recall = 0
+ threshold = 0
+ nextract = 0
+ ncorrect = 0
+ total_num_duplicates = sum(labels)
+
+ for i in range(len(rows) - 1):
+ score, label = rows[i]
+ nextract += 1
+
+ if label == 1:
+ ncorrect += 1
+
+ if ncorrect > 0:
+ precision = ncorrect / nextract
+ recall = ncorrect / total_num_duplicates
+ f1 = 2 * precision * recall / (precision + recall)
+ if f1 > best_f1:
+ best_f1 = f1
+ best_precision = precision
+ best_recall = recall
+ threshold = (rows[i][0] + rows[i + 1][0]) / 2
+
+ return best_f1, best_precision, best_recall, threshold
+
+ @staticmethod
+ def ap_score(scores, labels, high_score_more_similar: bool):
+ return average_precision_score(
+ labels, scores * (1 if high_score_more_similar else -1)
+ )
+
+
+class MultiClassMultiOutputLogRegClassificationEvaluator(Evaluator):
+ def __init__(
+ self,
+ embeds_train,
+ y_train,
+ embeds_test,
+ y_test,
+ max_iter=1000,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ self.embeds_train = embeds_train
+ self.y_train = y_train
+ self.embeds_test = embeds_test
+ self.y_test = y_test
+ self.max_iter = max_iter
+
+ def __call__(self):
+ scores = {}
+ mlb = MultiLabelBinarizer()
+ # all classes in y_train and y_test
+
+ class_labels = list(self.y_train) + list(self.y_test)
+ labels = [class_label.split(", ") for class_label in class_labels]
+ mlb.fit(labels)
+ train_labels = [class_label.split(", ") for class_label in self.y_train]
+ test_labels = [class_label.split(", ") for class_label in self.y_test]
+
+ y_train = mlb.transform(train_labels)
+ y_test = mlb.transform(test_labels)
+ clf = MultiOutputRegressor(
+ LogisticRegression(
+ random_state=self.seed, solver="lbfgs", max_iter=self.max_iter
+ )
+ ).fit(self.embeds_train, y_train)
+ y_pred = clf.predict(self.embeds_test)
+
+ results_dict = classification_report(y_test, y_pred, output_dict=True)
+ assert isinstance(
+ results_dict, dict
+ ), "Should always be true since `output_dict=True` is passed to sklearn.metric.classification_report"
+ scores["precision"] = results_dict["macro avg"]["precision"]
+ scores["recall"] = results_dict["macro avg"]["recall"]
+ scores["f1"] = results_dict["macro avg"]["f1-score"]
+ scores["accuracy"] = accuracy_score(y_test, y_pred)
+
+ return scores
+
+
+class MultiClassMultiOutputKNNClassificationEvaluator(Evaluator):
+ def __init__(
+ self,
+ embeds_train,
+ y_train,
+ embeds_test,
+ y_test,
+ n_neighbors=5,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ self.embeds_train = embeds_train
+ self.y_train = y_train
+ self.embeds_test = embeds_test
+ self.y_test = y_test
+ self.n_neighbors = n_neighbors
+
+ def __call__(self):
+ scores = {}
+
+ mlb = MultiLabelBinarizer()
+ class_labels = list(self.y_train) + list(self.y_test)
+ labels = [class_label.split(", ") for class_label in class_labels]
+ mlb.fit(labels)
+ train_labels = [class_label.split(", ") for class_label in self.y_train]
+ test_labels = [class_label.split(", ") for class_label in self.y_test]
+
+ y_train = mlb.transform(train_labels)
+ y_test = mlb.transform(test_labels)
+ clf = sklearn.neighbors.KNeighborsClassifier(
+ n_neighbors=self.n_neighbors, metric="cosine"
+ )
+ logger.info("Fitting KNN classifier...")
+ clf.fit(self.embeds_train, y_train)
+ logger.info("Evaluating...")
+ y_pred = clf.predict(self.embeds_test)
+ accuracy = accuracy_score(y_test, y_pred)
+ f1 = f1_score(y_test, y_pred, average="macro")
+ precision = precision_score(y_test, y_pred, average="macro")
+ recall = recall_score(y_test, y_pred, average="macro")
+ lrap = label_ranking_average_precision_score(y_test, y_pred)
+ scores["f1"] = f1
+ scores["accuracy"] = accuracy
+ scores["precision"] = precision
+ scores["recall"] = recall
+ scores["lrap"] = lrap
+
+ return scores
+
+
+class BiGeneMiningEvaluator(Evaluator):
+ """
+ BiGene Mining Evaluator, analogous to Bitext Mining Evaluator https://github.com/embeddings-benchmark/mteb/blob/main/mteb/evaluation/evaluators/BitextMiningEvaluator.py.
+
+ If top_k > 1, then recall@k is also computed.
+ """
+
+ def __init__(self, embeds1, embeds2, top_k=1, **kwargs):
+ super().__init__(**kwargs)
+ self.n = len(embeds1)
+ self.embeds1 = np.array(embeds1)
+ self.embeds2 = np.array(embeds2)
+ self.gold = list(zip(range(self.n), range(self.n)))
+ self.top_k = top_k
+
+ def __call__(self):
+ scores = self.compute_metrics()
+ return scores
+
+ def compute_metrics(self):
+ logger.info(f"Finding nearest neighbors... with top_k={self.top_k}")
+ nearest_neighbors = self._similarity_search(
+ self.embeds1, self.embeds2, top_k=self.top_k
+ )
+
+ # Compute errors
+ logger.info("Computing metrics...")
+ labels = []
+ predictions = []
+
+ # Get predictions and labels for top_k=1.
+ for i, x in enumerate(nearest_neighbors):
+ j = x[0]["corpus_id"]
+ predictions.append(j)
+ labels.append(self.gold[i][1])
+
+ scores = {
+ "precision": precision_score(
+ labels, predictions, zero_division=0, average="weighted"
+ ),
+ "recall": recall_score(
+ labels, predictions, zero_division=0, average="weighted"
+ ),
+ "f1": f1_score(labels, predictions, zero_division=0, average="weighted"),
+ "accuracy": accuracy_score(labels, predictions),
+ }
+
+ if self.top_k > 1:
+ # Compute recall@k.
+ top_k_preds = []
+ for i, x in enumerate(nearest_neighbors):
+ top_k_preds.append([pred["corpus_id"] for pred in x])
+ top_k_recall = [
+ self.gold[i][1] in top_k_pred
+ for i, top_k_pred in enumerate(top_k_preds)
+ ]
+ scores[f"recall_at_{self.top_k}"] = sum(top_k_recall) / len(top_k_recall)
+ return scores
+
+ def _similarity_search(
+ self,
+ query_embeddings,
+ corpus_embeddings,
+ query_chunk_size=100,
+ corpus_chunk_size=500000,
+ top_k=1,
+ score_function=cos_sim,
+ ):
+ """This function performs a cosine similarity search between a list of query embeddings and a list of corpus embeddings.
+ It can be used for Information Retrieval / Semantic Search for corpora up to about 1 Million entries.
+ :param query_embeddings: A 2 dimensional tensor with the query embeddings.
+ :param corpus_embeddings: A 2 dimensional tensor with the corpus embeddings.
+        :param query_chunk_size: Process this many queries simultaneously. Increasing the value increases speed but requires more memory.
+        :param corpus_chunk_size: Scan the corpus in chunks of this many entries. Increasing the value increases speed but requires more memory.
+ :param top_k: Retrieve top k matching entries.
+ :param score_function: Function for computing scores. By default, cosine similarity.
+ :return: Returns a list with one entry for each query. Each entry is a list of dictionaries with the keys 'corpus_id' and 'score', sorted by decreasing cosine similarity scores.
+ """
+ query_embeddings = torch.from_numpy(query_embeddings)
+ corpus_embeddings = torch.from_numpy(corpus_embeddings)
+ if len(query_embeddings.shape) == 1:
+ query_embeddings = query_embeddings.unsqueeze(0)
+ if len(corpus_embeddings.shape) == 1:
+ corpus_embeddings = corpus_embeddings.unsqueeze(0)
+
+ # Check that corpus and queries are on the same device
+ if corpus_embeddings.device != query_embeddings.device:
+ query_embeddings = query_embeddings.to(corpus_embeddings.device)
+
+ queries_result_list = [[] for _ in range(len(query_embeddings))]
+
+ for query_start_idx in range(0, len(query_embeddings), query_chunk_size):
+ # Iterate over chunks of the corpus
+ for corpus_start_idx in range(0, len(corpus_embeddings), corpus_chunk_size):
+ # Compute cosine similarities
+ cos_scores = score_function(
+ query_embeddings[
+ query_start_idx : query_start_idx + query_chunk_size
+ ],
+ corpus_embeddings[
+ corpus_start_idx : corpus_start_idx + corpus_chunk_size
+ ],
+ )
+
+ # Get top-k scores
+ cos_scores_top_k_values, cos_scores_top_k_idx = torch.topk(
+ cos_scores,
+ min(top_k, len(cos_scores[0])),
+ dim=1,
+ largest=True,
+ sorted=False,
+ )
+ cos_scores_top_k_values = cos_scores_top_k_values.cpu().tolist()
+ cos_scores_top_k_idx = cos_scores_top_k_idx.cpu().tolist()
+
+ for query_itr in range(len(cos_scores)):
+ for sub_corpus_id, score in zip(
+ cos_scores_top_k_idx[query_itr],
+ cos_scores_top_k_values[query_itr],
+ ):
+ corpus_id = corpus_start_idx + sub_corpus_id
+ query_id = query_start_idx + query_itr
+ queries_result_list[query_id].append(
+ {"corpus_id": corpus_id, "score": score}
+ )
+
+ # Sort and strip to top_k results
+ for idx in range(len(queries_result_list)):
+ queries_result_list[idx] = sorted(
+ queries_result_list[idx], key=lambda x: x["score"], reverse=True
+ )
+ queries_result_list[idx] = queries_result_list[idx][0:top_k]
+
+ return queries_result_list
+
+
+class EDSEvaluator(Evaluator):
+ """
+ Evolutionary Distance Similarity Evaluator, analogous to Semantic Textual Similarity Evaluator.
+ Adapted from https://github.com/embeddings-benchmark/mteb/blob/main/mteb/evaluation/evaluators/STSEvaluator.py
+ """
+
+ def __init__(self, embeds1, embeds2, gold_scores, **kwargs):
+ super().__init__(**kwargs)
+ self.embeds1 = embeds1
+ self.embeds2 = embeds2
+ self.gold_scores = gold_scores
+
+ def __call__(self):
+ embeddings1 = np.array(self.embeds1)
+ embeddings2 = np.array(self.embeds2)
+ logger.info("Evaluating...")
+ cosine_scores = paired_cosine_distances(embeddings1, embeddings2)
+ manhattan_distances = paired_manhattan_distances(embeddings1, embeddings2)
+ euclidean_distances = paired_euclidean_distances(embeddings1, embeddings2)
+
+ cosine_pearson, _ = pearsonr(self.gold_scores, cosine_scores)
+ manhattan_pearson, _ = pearsonr(self.gold_scores, manhattan_distances)
+ euclidean_pearson, _ = pearsonr(self.gold_scores, euclidean_distances)
+
+ top_corr = max(
+ cosine_pearson,
+ manhattan_pearson,
+ euclidean_pearson,
+ )
+ return {
+ "cos_sim": cosine_pearson,
+ "manhattan": manhattan_pearson,
+ "euclidean": euclidean_pearson,
+ "top_corr": top_corr,
+ }
+
+
+class RetrievalEvaluator(Evaluator):
+ """Adapted from
+ https://github.com/embeddings-benchmark/mteb/blob/main/mteb/evaluation/evaluators/RetrievalEvaluator.py
+ """
+
+ def __init__(
+ self,
+ corpus_embeds,
+ query_embeds,
+ corpus_ids,
+ query_ids,
+ qrels: Dict[str, Dict[str, int]],
+ k_values: List[int] = [5, 10, 50],
+ score_function: str = "cos_sim",
+ corpus_chunk_size: int = 50000,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ self.corpus_embeds = corpus_embeds
+ self.query_embeds = query_embeds
+ self.corpus_ids = corpus_ids
+ self.query_ids = query_ids
+ self.qrels = qrels
+ self.k_values = k_values
+ self.top_k = max(k_values) if "top_k" not in kwargs else kwargs["top_k"]
+ self.score_function = score_function
+ self.score_functions = {
+ "cos_sim": cos_sim,
+ "dot": dot_score,
+ }
+ self.corpus_chunk_size = corpus_chunk_size
+
+ def __call__(self):
+ results = self.search(
+ self.corpus_embeds,
+ self.query_embeds,
+ self.corpus_ids,
+ self.query_ids,
+ self.top_k,
+ self.score_function,
+ )
+ ndcg, _map, recall, precision, naucs = self.evaluate(
+ self.qrels, results, self.k_values
+ )
+ mrr, naucs_mrr = self.evaluate_custom(self.qrels, results, self.k_values, "mrr")
+ scores = {
+ **{f"ndcg_at_{k.split('@')[1]}": v for (k, v) in ndcg.items()},
+ **{f"map_at_{k.split('@')[1]}": v for (k, v) in _map.items()},
+ **{f"recall_at_{k.split('@')[1]}": v for (k, v) in recall.items()},
+ **{f"precision_at_{k.split('@')[1]}": v for (k, v) in precision.items()},
+ **{f"mrr_at_{k.split('@')[1]}": v for (k, v) in mrr.items()},
+ **{
+ k.replace("@", "_at_").replace("_P", "_precision").lower(): v
+ for k, v in naucs.items()
+ },
+ **{
+ k.replace("@", "_at_").replace("_P", "_precision").lower(): v
+ for k, v in naucs_mrr.items()
+ },
+ }
+ return scores
+
+ def search(
+ self,
+ corpus_embeds,
+ query_embeds,
+ corpus_ids,
+ query_ids,
+ top_k: int,
+ score_function: str,
+ return_sorted: bool = False,
+ **kwargs,
+ ) -> dict[str, dict[str, float]]:
+ # Create embeddings for all queries using model.encode()
+ # Runs semantic search against the corpus embeddings
+ # Returns a ranked list with the corpus ids
+ if score_function not in self.score_functions:
+ raise ValueError(
+ f"score function: {score_function} must be either (cos_sim) for cosine similarity or (dot) for dot product"
+ )
+ # make query embeds and corpus embeds torch tensors
+ query_embeds = torch.from_numpy(query_embeds)
+ corpus_embeds = torch.from_numpy(corpus_embeds)
+ itr = range(0, len(corpus_embeds), self.corpus_chunk_size)
+ results = defaultdict(dict)
+ # Keep only the top-k docs for each query
+ result_heaps = defaultdict(list)
+ for batch_num, corpus_start_idx in enumerate(itr):
+ logger.info("Searching Batch {}/{}...".format(batch_num + 1, len(itr)))
+ corpus_end_idx = min(
+ corpus_start_idx + self.corpus_chunk_size, len(corpus_ids)
+ )
+ sub_corpus_embeds = corpus_embeds[corpus_start_idx:corpus_end_idx]
+ # Compute similarites using either cosine-similarity or dot product
+ cos_scores = self.score_functions[score_function](
+ query_embeds, sub_corpus_embeds
+ )
+ cos_scores[torch.isnan(cos_scores)] = -1
+
+ # Get top-k values
+ cos_scores_top_k_values, cos_scores_top_k_idx = torch.topk(
+ cos_scores,
+                min(top_k + 1, len(sub_corpus_embeds)),
+ dim=1,
+ largest=True,
+ sorted=return_sorted,
+ )
+ cos_scores_top_k_values = cos_scores_top_k_values.cpu().tolist()
+ cos_scores_top_k_idx = cos_scores_top_k_idx.cpu().tolist()
+
+ for query_itr in range(len(query_embeds)):
+ query_id = query_ids[query_itr]
+ for sub_corpus_id, score in zip(
+ cos_scores_top_k_idx[query_itr], cos_scores_top_k_values[query_itr]
+ ):
+ corpus_id = corpus_ids[corpus_start_idx + sub_corpus_id]
+ if corpus_id != query_id:
+ if len(result_heaps[query_id]) < top_k:
+ # Push item on the heap
+ heapq.heappush(result_heaps[query_id], (score, corpus_id))
+ else:
+ # If item is larger than the smallest in the heap, push it on the heap then pop the smallest element
+ heapq.heappushpop(
+ result_heaps[query_id], (score, corpus_id)
+ )
+
+ for qid in result_heaps:
+ for score, corpus_id in result_heaps[qid]:
+ results[qid][corpus_id] = score
+
+ return results
+
+ @staticmethod
+ def evaluate(
+ qrels: dict[str, dict[str, int]],
+ results: dict[str, dict[str, float]],
+ k_values: List[int],
+ ignore_identical_ids: bool = True,
+    ) -> Tuple[
+        Dict[str, float],
+        Dict[str, float],
+        Dict[str, float],
+        Dict[str, float],
+        Dict[str, float],
+    ]:
+ if ignore_identical_ids:
+ logger.info(
+ "For evaluation, we ignore identical query and document ids (default), please explicitly set ``ignore_identical_ids=False`` to ignore this."
+ )
+ popped = []
+ for qid, rels in results.items():
+ for pid in list(rels):
+ if qid == pid:
+ results[qid].pop(pid)
+ popped.append(pid)
+
+ all_ndcgs, all_aps, all_recalls, all_precisions = {}, {}, {}, {}
+
+ for k in k_values:
+ all_ndcgs[f"NDCG@{k}"] = []
+ all_aps[f"MAP@{k}"] = []
+ all_recalls[f"Recall@{k}"] = []
+ all_precisions[f"P@{k}"] = []
+
+ map_string = "map_cut." + ",".join([str(k) for k in k_values])
+ ndcg_string = "ndcg_cut." + ",".join([str(k) for k in k_values])
+ recall_string = "recall." + ",".join([str(k) for k in k_values])
+ precision_string = "P." + ",".join([str(k) for k in k_values])
+ evaluator = pytrec_eval.RelevanceEvaluator(
+ qrels, {map_string, ndcg_string, recall_string, precision_string}
+ )
+ scores = evaluator.evaluate(results)
+
+ for query_id in scores.keys():
+ for k in k_values:
+ all_ndcgs[f"NDCG@{k}"].append(scores[query_id]["ndcg_cut_" + str(k)])
+ all_aps[f"MAP@{k}"].append(scores[query_id]["map_cut_" + str(k)])
+ all_recalls[f"Recall@{k}"].append(scores[query_id]["recall_" + str(k)])
+ all_precisions[f"P@{k}"].append(scores[query_id]["P_" + str(k)])
+ ndcg, _map, recall, precision = (
+ all_ndcgs.copy(),
+ all_aps.copy(),
+ all_recalls.copy(),
+ all_precisions.copy(),
+ )
+
+ for k in k_values:
+ ndcg[f"NDCG@{k}"] = round(sum(ndcg[f"NDCG@{k}"]) / len(scores), 5)
+ _map[f"MAP@{k}"] = round(sum(_map[f"MAP@{k}"]) / len(scores), 5)
+ recall[f"Recall@{k}"] = round(sum(recall[f"Recall@{k}"]) / len(scores), 5)
+ precision[f"P@{k}"] = round(sum(precision[f"P@{k}"]) / len(scores), 5)
+ naucs = RetrievalEvaluator.evaluate_abstention(
+ results, {**all_ndcgs, **all_aps, **all_recalls, **all_precisions}
+ )
+ return ndcg, _map, recall, precision, naucs
+
+ @staticmethod
+ def evaluate_abstention(
+ results: dict[str, dict[str, float]],
+ metric_scores: dict[str, list[float]],
+ ) -> Dict[str, float]:
+ """Computes normalized Area Under the Curve on a set of evaluated instances as presented in the paper https://arxiv.org/abs/2402.12997"""
+ all_sim_scores = [list(results[qid].values()) for qid in list(results.keys())]
+ all_conf_scores = [
+ confidence_scores(sim_scores) for sim_scores in all_sim_scores
+ ]
+ conf_fcts = list(all_conf_scores[0].keys())
+ all_conf_scores = {
+ fct: np.array([x[fct] for x in all_conf_scores]) for fct in conf_fcts
+ }
+ metric_scores = {k: np.array(v) for k, v in metric_scores.items()}
+ naucs = {}
+
+ for metric_name, scores in metric_scores.items():
+ for fct, conf_scores in all_conf_scores.items():
+ naucs[f"nAUC_{metric_name}_{fct}"] = nAUC(conf_scores, scores)
+
+ return naucs
+
+ @staticmethod
+ def evaluate_custom(
+ qrels: dict[str, dict[str, int]],
+ results: dict[str, dict[str, float]],
+ k_values: List[int],
+ metric: str,
+ output_type: str = "all",
+    ) -> Tuple[Dict[str, float], Dict[str, float]]:
+ if metric.lower() in ["mrr", "mrr@k", "mrr_cut"]:
+ metric_scores = mrr(qrels, results, k_values, output_type)
+
+ elif metric.lower() in ["recall_cap", "r_cap", "r_cap@k"]:
+ metric_scores = recall_cap(qrels, results, k_values, output_type)
+
+ elif metric.lower() in ["hole", "hole@k"]:
+ metric_scores = hole(qrels, results, k_values, output_type)
+
+ elif metric.lower() in [
+ "acc",
+ "top_k_acc",
+ "accuracy",
+ "accuracy@k",
+ "top_k_accuracy",
+ ]:
+            metric_scores = top_k_accuracy(qrels, results, k_values, output_type)
+
+        else:
+            raise ValueError(f"Unknown metric: {metric}")
+
+        naucs = RetrievalEvaluator.evaluate_abstention(results, metric_scores)
+ metric_scores_avg = {k: sum(v) / len(v) for k, v in metric_scores.items()}
+
+ return metric_scores_avg, naucs
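+
+
+# Minimal usage sketch, assuming precomputed numpy embeddings and BEIR-style qrels
+# (all variable names below are illustrative):
+#
+#   evaluator = RetrievalEvaluator(
+#       corpus_embeds, query_embeds, corpus_ids, query_ids, qrels, k_values=[5, 10]
+#   )
+#   scores = evaluator()  # {"ndcg_at_5": ..., "map_at_5": ..., "mrr_at_5": ..., ...}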
diff --git a/geb/geb.py b/geb/geb.py
new file mode 100644
index 0000000..e6bf3df
--- /dev/null
+++ b/geb/geb.py
@@ -0,0 +1,129 @@
+from itertools import chain
+import logging
+import os
+import traceback
+from typing import Any, List
+
+from rich.console import Console
+
+from .eval_utils import set_all_seeds
+from .modality import Modality
+from .models import BioSeqTransformer
+from .tasks.tasks import Task
+
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+
+class GEB:
+ """GEB class to run the evaluation pipeline."""
+
+ def __init__(self, tasks: List[type[Task]], seed: int = 42):
+ self.tasks = tasks
+ set_all_seeds(seed)
+
+ def print_selected_tasks(self):
+ """Print the selected tasks."""
+ console = Console()
+ console.rule("[bold]Selected Tasks\n", style="grey15")
+ for task in self.tasks:
+ prefix = " - "
+ name = f"{task.metadata.display_name}"
+ category = f", [italic grey39]{task.metadata.type}[/]"
+ console.print(f"{prefix}{name}{category}")
+ console.print("\n")
+
+ def run(
+ self,
+        model: BioSeqTransformer,
+ output_folder: str = "results",
+ ):
+ """Run the evaluation pipeline on the selected tasks.
+
+ Args:
+ model: Model to be used for evaluation
+            output_folder: Folder where the results will be saved. Defaults to 'results'.
+                Results are written in the format:
+                `{output_folder}/{model_name}/{task_id}.json`.
+
+        Returns:
+            A list of TaskResult objects, one for each task evaluated.
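+
+        Example (a minimal sketch; assumes the small ESM2 checkpoint is available):
+
+            import geb
+            model = geb.get_model("facebook/esm2_t6_8M_UR50D")
+            tasks = geb.get_tasks_by_modality(model.modality)
+            results = geb.GEB(tasks=tasks).run(model, output_folder="results")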
+ """
+ # Run selected tasks
+ self.print_selected_tasks()
+ results = []
+
+ for task in self.tasks:
+ logger.info(
+ f"\n\n********************** Evaluating {task.metadata.display_name} **********************"
+ )
+
+ try:
+ result = task().run(model)
+ except Exception as e:
+ logger.error(e)
+ logger.error(traceback.format_exc())
+ logger.error(f"Error running task {task}")
+ continue
+
+ results.append(result)
+
+ save_path = get_output_folder(model.hf_name, task, output_folder)
+ with open(save_path, "w") as f_out:
+ f_out.write(result.model_dump_json(indent=2))
+ return results
+
+
+def get_model(model_name: str, **kwargs: Any) -> BioSeqTransformer:
+ all_names = get_all_model_names()
+ for cls in BioSeqTransformer.__subclasses__():
+ if model_name in cls.MODEL_NAMES:
+ return cls(model_name, **kwargs)
+ raise ValueError(f"Model {model_name} not found in {all_names}.")
+
+
+def get_all_model_names() -> List[str]:
+ return list(
+ chain.from_iterable(
+ cls.MODEL_NAMES for cls in BioSeqTransformer.__subclasses__()
+ )
+ )
+
+
+def get_all_task_names() -> List[str]:
+ return [task.metadata.id for task in get_all_tasks()]
+
+
+def get_tasks_by_name(tasks: List[str]) -> List[type[Task]]:
+ return [_get_task(task) for task in tasks]
+
+
+def get_tasks_by_modality(modality: Modality) -> List[type[Task]]:
+ return [task for task in get_all_tasks() if task.metadata.modality == modality]
+
+
+def get_all_tasks() -> List[type[Task]]:
+ return Task.__subclasses__()
+
+
+def _get_task(task_name: str) -> type[Task]:
+ logger.info(f"Getting task {task_name}")
+ for task in get_all_tasks():
+ if task.metadata.id == task_name:
+ return task
+
+ raise ValueError(
+ f"Task {task_name} not found, available tasks are: {[task.metadata.id for task in get_all_tasks()]}"
+ )
+
+
+def get_output_folder(
+ model_hf_name: str, task: type[Task], output_folder: str, create: bool = True
+):
+ output_folder = os.path.join(output_folder, os.path.basename(model_hf_name))
+ # create output folder if it does not exist
+ if create and not os.path.exists(output_folder):
+ os.makedirs(output_folder)
+ return os.path.join(
+ output_folder,
+ f"{task.metadata.id}.json",
+ )
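+
+
+# For example, a hypothetical call
+# `get_output_folder("facebook/esm2_t6_8M_UR50D", EnzymeCommissionClassification, "results")`
+# returns "results/esm2_t6_8M_UR50D/ec_classification.json".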
diff --git a/geb/modality.py b/geb/modality.py
new file mode 100644
index 0000000..146d1ad
--- /dev/null
+++ b/geb/modality.py
@@ -0,0 +1,10 @@
+"""Defines the data modality enum."""
+
+from enum import Enum
+
+
+class Modality(Enum):
+ """Data modality, either DNA or protein sequence."""
+
+ PROTEIN = "protein"
+ DNA = "dna"
diff --git a/geb/models.py b/geb/models.py
new file mode 100644
index 0000000..1c46c81
--- /dev/null
+++ b/geb/models.py
@@ -0,0 +1,482 @@
+import logging
+import re
+from abc import ABC, abstractmethod
+from functools import partial
+from types import SimpleNamespace
+from typing import Dict, List, Literal, Optional
+
+import numpy as np
+import torch
+import tqdm
+from datasets import Dataset
+from torch import Tensor
+from torch.nn import functional as F
+from torch.utils.data import DataLoader
+from transformers import (
+ AutoConfig,
+ AutoModel,
+ AutoModelForCausalLM,
+ AutoModelForMaskedLM,
+ AutoTokenizer,
+ BatchEncoding,
+ DefaultDataCollator,
+ T5EncoderModel,
+ T5Tokenizer,
+)
+from transformers.modeling_outputs import BaseModelOutput
+
+from .modality import Modality
+from .eval_utils import ForwardHook, pool
+
+logger = logging.getLogger(__name__)
+
+
+class BioSeqTransformer(ABC):
+ """
+ Abstract class to wrap models which map biological sequences (DNA/Prot) to embeddings.
+ Modelled after SentenceTransformer (https://github.com/UKPLab/sentence-transformers/blob/master/sentence_transformers/SentenceTransformer.py)
+
+ Args:
+ model_name: Name or path to the pretrained model.
+ layers: List of model layers to probe. Can be integers or "mid" or "last".
+ devices: List of device ids for inference. If cuda is not available, will use cpu.
+ num_processes: Number of processes to use for data loading.
+ max_seq_length: Maximum sequence length of the input sequences.
+ l2_norm: If true, embeddings are L2-normalized before they are returned.
+ batch_size: Batch size for encoding.
+ pool_type: Pooling strategy to use. One of "mean", "max", "cls", "last".
+ """
+
+ def __init__(
+ self,
+ model_name: str,
+ layers: Optional[List[int] | Literal["mid"] | Literal["last"]] = None,
+ devices: List[int] = [0],
+ num_processes: int = 16,
+ max_seq_length: int = 1024,
+ l2_norm: bool = False,
+ batch_size: int = 128,
+ pool_type: str = "mean",
+ ):
+ super().__init__()
+
+ self.id = self.__class__.__name__
+ self.hf_name = model_name
+ self.encoder = self._load_model(model_name)
+ if not hasattr(self.encoder, "config"):
+ raise ValueError(
+ 'The model from `self._load_model()` must have a "config" attribute.'
+ )
+ self.config = self.encoder.config
+ self.tokenizer = self._get_tokenizer(model_name)
+ self.num_param = sum(p.numel() for p in self.encoder.parameters())
+ self.data_collator = DefaultDataCollator()
+ self.gpu_count = len(devices)
+ self.l2_norm = l2_norm
+
+ self.device = torch.device(
+ f"cuda:{devices[0]}" if torch.cuda.is_available() else "cpu"
+ )
+ self.num_processes = num_processes
+ self.max_seq_length = max_seq_length
+ self.batch_size = batch_size
+ self.pool_type = pool_type
+
+ if self.gpu_count > 1:
+ self.encoder = torch.nn.DataParallel(self.encoder, device_ids=devices)
+ self.encoder.to(self.device)
+ self.encoder.eval()
+
+ mid_layer = self.num_layers // 2
+ last_layer = self.num_layers - 1
+ mid_layer_label = f"mid ({mid_layer})"
+ last_layer_label = f"last ({self.num_layers - 1})"
+
+ if layers is None:
+ logger.debug(f"Using default layers: {mid_layer_label}, {last_layer_label}")
+ self.layers = [mid_layer, last_layer]
+ self.layer_labels = [mid_layer_label, last_layer_label]
+ elif layers == "mid":
+ self.layers = [mid_layer]
+ self.layer_labels = [mid_layer_label]
+ elif layers == "last":
+ self.layers = [last_layer]
+ self.layer_labels = [last_layer_label]
+ else:
+ self.layers = layers
+ self.layer_labels = [str(layer) for layer in layers]
+
+ def _encode_single_batch(self, batch_dict: Dict[str, Tensor]):
+ """Returns the output embedding for the given batch with shape [batch, num_layers, D]."""
+ outputs = self.encoder(**batch_dict, output_hidden_states=True)
+ embeds = [outputs.hidden_states[layer] for layer in self.layers]
+ embeds = [
+ pool(layer_embeds, batch_dict["attention_mask"], self.pool_type)
+ for layer_embeds in embeds
+ ]
+ # Stack with shape [B, num_layers, D].
+ embeds = torch.stack(embeds, dim=1)
+ return embeds
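+
+    # `pool` (imported from eval_utils, not shown in this diff) is assumed to
+    # reduce [B, T, D] token embeddings to [B, D] according to `pool_type`.
+    # Mean pooling, for instance, would be roughly:
+    #     masked = embeds * attention_mask.unsqueeze(-1)
+    #     pooled = masked.sum(dim=1) / attention_mask.sum(dim=1, keepdim=True)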
+
+ def _load_model(self, model_name):
+ return AutoModel.from_pretrained(model_name, trust_remote_code=True)
+
+ def _get_tokenizer(self, model_name):
+ return AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
+
+ def _tokenize_func(
+ self, tokenizer, examples: Dict[str, List], max_seq_length: int
+ ) -> BatchEncoding:
+ batch_dict = tokenizer(
+ examples["input_seqs"],
+ max_length=max_seq_length,
+ padding=True,
+ truncation=True,
+ )
+ return batch_dict
+
+ @property
+ def metadata(self) -> Dict:
+ return {
+ "hf_name": self.hf_name,
+ "revision": "...", # TODO: Fix
+ "num_layers": self.num_layers,
+ "num_params": self.num_param,
+ "embed_dim": self.embed_dim,
+ }
+
+ @property
+ @abstractmethod
+ def num_layers(self) -> int:
+ pass
+
+ @property
+ @abstractmethod
+ def embed_dim(self) -> int:
+ pass
+
+ @property
+ @abstractmethod
+ def modality(self) -> Modality:
+ pass
+
+ @torch.no_grad()
+ def encode(self, sequences, **kwargs) -> np.ndarray:
+ """Returns a list of embeddings for the given sequences.
+ Args:
+ sequences (`List[str]`): List of sequences to encode
+ Returns:
+ `np.ndarray`: Embeddings for the given sequences of shape [num_sequences, num_layers, embedding_dim].
+ """
+ dataset = Dataset.from_dict({"input_seqs": sequences})
+ dataset.set_transform(
+ partial(
+ self._tokenize_func, self.tokenizer, max_seq_length=self.max_seq_length
+ )
+ )
+ data_loader = DataLoader(
+ dataset,
+ batch_size=self.batch_size * self.gpu_count,
+ shuffle=False,
+ drop_last=False,
+ num_workers=self.num_processes,
+ collate_fn=self.data_collator,
+ pin_memory=True,
+ )
+
+ if max(self.layers) >= self.num_layers:
+ raise ValueError(
+ f"Layer {max(self.layers)} is not available in the model. Choose a layer between 0 and {self.num_layers - 1}"
+ )
+
+ encoded_embeds = []
+ for batch_dict in tqdm.tqdm(
+ data_loader, desc="encoding", mininterval=10, disable=len(sequences) < 128
+ ):
+ batch_dict = {k: v.to(self.device) for k, v in batch_dict.items()}
+
+ embeds = self._encode_single_batch(batch_dict)
+
+ if self.l2_norm:
+ embeds = F.normalize(embeds, p=2, dim=-1)
+ encoded_embeds.append(embeds.cpu().numpy())
+
+ return np.concatenate(encoded_embeds, axis=0)
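+
+    # Usage sketch (assuming the 8M ESM2 checkpoint, which has embed_dim=320):
+    #     model = ESM("facebook/esm2_t6_8M_UR50D", layers="last")
+    #     embeds = model.encode(["MKTAYIAK", "MAHHHHHH"])  # shape [2, 1, 320]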
+
+
+class ESM(BioSeqTransformer):
+ """ESM model from https://huggingface.co/docs/transformers/en/model_doc/esm"""
+
+ MODEL_NAMES = [
+ "facebook/esm2_t6_8M_UR50D",
+ "facebook/esm2_t12_35M_UR50D",
+ "facebook/esm2_t30_150M_UR50D",
+ "facebook/esm2_t33_650M_UR50D",
+ "facebook/esm2_t36_3B_UR50D",
+ "facebook/esm2_t48_15B_UR50D",
+ ]
+
+ @property
+ def modality(self) -> Modality:
+ return Modality.PROTEIN
+
+ @property
+ def num_layers(self) -> int:
+ return self.config.num_hidden_layers
+
+ @property
+ def embed_dim(self) -> int:
+ return self.config.hidden_size
+
+
+class ESM3(BioSeqTransformer):
+ """ESM3 model from https://github.com/evolutionaryscale/esm"""
+
+ MODEL_NAMES = ["esm3_sm_open_v1"]
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ # Register forward hooks to store embeddings per layer.
+ self.hooks = [
+ ForwardHook(self.encoder.transformer.blocks[layer]) for layer in self.layers
+ ]
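+
+    # `ForwardHook` (from eval_utils, not shown in this diff) is assumed to wrap
+    # `module.register_forward_hook` and stash the latest output, roughly:
+    #     class ForwardHook:
+    #         def __init__(self, module):
+    #             self.output = None
+    #             module.register_forward_hook(
+    #                 lambda mod, inputs, output: setattr(self, "output", output)
+    #             )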
+
+ @property
+ def modality(self) -> Modality:
+ return Modality.PROTEIN
+
+ @property
+ def num_layers(self) -> int:
+ return self.config.num_hidden_layers
+
+ @property
+ def embed_dim(self) -> int:
+ return self.config.hidden_size
+
+ def _load_model(self, model_name):
+ try:
+ from esm.models.esm3 import ESM3 as ModelESM3
+ except ImportError:
+ raise ImportError(
+ "ESM3 is not installed. Please install it with `pip install esm`."
+ )
+        model = ModelESM3.from_pretrained(model_name)
+ model.config = SimpleNamespace(
+ num_hidden_layers=len(model.transformer.blocks),
+ hidden_size=model.transformer.blocks[0].ffn[-1].out_features,
+ )
+ return model
+
+ def _get_tokenizer(self, model_name):
+ try:
+ from esm.tokenization.sequence_tokenizer import EsmSequenceTokenizer
+ except ImportError:
+ raise ImportError(
+ "ESM3 is not installed. Please install it with `pip install esm`."
+ )
+ return EsmSequenceTokenizer()
+
+ def _encode_single_batch(self, batch_dict: Dict[str, Tensor]):
+ _ = self.encoder.forward(sequence_tokens=batch_dict["input_ids"])
+ embeds = [hook.output for hook in self.hooks]
+ embeds = [
+ pool(layer_embeds, batch_dict["attention_mask"], self.pool_type)
+ for layer_embeds in embeds
+ ]
+ # Stack with shape [B, num_layers, D].
+ embeds = torch.stack(embeds, dim=1)
+ embeds = embeds.to(torch.float32)
+ return embeds
+
+
+class ProtT5(BioSeqTransformer):
+ """ProtT5 model from https://github.com/agemagician/ProtTrans"""
+
+ MODEL_NAMES = [
+ "Rostlab/prot_t5_xl_uniref50",
+ "Rostlab/prot_t5_xl_bfd",
+ "Rostlab/prot_t5_xxl_uniref50",
+ "Rostlab/prot_t5_xxl_bfd",
+ ]
+
+ @property
+ def modality(self) -> Modality:
+ return Modality.PROTEIN
+
+ @property
+ def num_layers(self) -> int:
+ return self.config.num_layers
+
+ @property
+ def embed_dim(self) -> int:
+ return self.config.d_model
+
+ def _load_model(self, model_name):
+ return T5EncoderModel.from_pretrained(model_name)
+
+ def _get_tokenizer(self, model_name):
+ return T5Tokenizer.from_pretrained(model_name, do_lower_case=False)
+
+ def _tokenize_func(
+ self, tokenizer, examples: Dict[str, List], max_seq_length: int
+ ) -> BatchEncoding:
+ example_sequences = examples["input_seqs"]
+ # Add space between amino acids to make sure they are tokenized correctly.
+ example_sequences = [" ".join(sequence) for sequence in example_sequences]
+ example_sequences = [
+ re.sub(r"[UZOB]", "X", sequence) for sequence in example_sequences
+ ]
+ batch_dict = tokenizer(
+ example_sequences,
+ max_length=max_seq_length,
+ padding=True,
+ truncation=True,
+ add_special_tokens=True,
+ )
+
+ return batch_dict
+
+
+class ProGen(BioSeqTransformer):
+ """ProGen models from https://github.com/salesforce/progen."""
+
+ MODEL_NAMES = [
+ "hugohrban/progen2-small",
+ "hugohrban/progen2-medium",
+ "hugohrban/progen2-base",
+ "hugohrban/progen2-large",
+ "hugohrban/progen2-xlarge",
+ ]
+
+ @property
+ def modality(self) -> Modality:
+ return Modality.PROTEIN
+
+ @property
+ def num_layers(self) -> int:
+ return self.config.n_layer
+
+ @property
+ def embed_dim(self) -> int:
+ return self.config.embed_dim
+
+ def _load_model(self, model_name):
+ return AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
+
+ def _get_tokenizer(self, model_name_or_path):
+ tokenizer = AutoTokenizer.from_pretrained(
+ model_name_or_path, trust_remote_code=True
+ )
+ tokenizer.pad_token = "<|pad|>"
+ return tokenizer
+
+ def _encode_single_batch(self, batch_dict: Dict[str, Tensor]):
+ """Returns the output embedding for the given batch with shape [batch, num_layers, D]."""
+ outputs: BaseModelOutput = self.encoder(
+ input_ids=batch_dict["input_ids"],
+ output_hidden_states=True,
+ use_cache=False,
+ )
+ embeds = [outputs.hidden_states[layer] for layer in self.layers]
+ embeds = [
+ pool(layer_embeds, batch_dict["attention_mask"], self.pool_type)
+ for layer_embeds in embeds
+ ]
+ # Stack with shape [B, num_layers, D].
+ embeds = torch.stack(embeds, dim=1)
+ return embeds
+
+
+class EvoModel(BioSeqTransformer):
+ """https://github.com/evo-design/evo."""
+
+ MODEL_NAMES = [
+ "togethercomputer/evo-1-8k-base",
+ "togethercomputer/evo-1-131k-base",
+ ]
+
+ @property
+ def modality(self) -> Modality:
+ return Modality.DNA
+
+ @property
+ def num_layers(self) -> int:
+ return self.config.num_layers
+
+ @property
+ def embed_dim(self) -> int:
+ return self.config.hidden_size
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ # Register forward hooks to store embeddings per layer.
+ self.hooks = []
+ for layer in self.layers:
+ # For the last layer, get the output of `backbone.norm`, which directly precedes `backbone.unembed`.
+ # This is equivalent to the approach in https://github.com/evo-design/evo/issues/32.
+ if layer == self.num_layers - 1 or layer == -1:
+ self.hooks.append(ForwardHook(self.encoder.backbone.norm))
+ else:
+ self.hooks.append(ForwardHook(self.encoder.backbone.blocks[layer]))
+
+ def _load_model(self, model_name):
+ config = AutoConfig.from_pretrained(
+ model_name, trust_remote_code=True, revision="1.1_fix"
+ )
+ model = AutoModelForCausalLM.from_pretrained(
+ model_name, config=config, trust_remote_code=True, revision="1.1_fix"
+ )
+ return model
+
+ def _get_tokenizer(self, model_name):
+ tokenizer = AutoTokenizer.from_pretrained(
+ model_name, revision="1.1_fix", trust_remote_code=True
+ )
+ # Evo tokenizer is missing pad_token by default.
+ tokenizer.add_special_tokens({"pad_token": "N"})
+ return tokenizer
+
+ def _encode_single_batch(self, batch_dict: Dict[str, Tensor]):
+ _ = self.encoder(batch_dict["input_ids"], use_cache=False)
+ embeds = [hook.output for hook in self.hooks]
+ # The hook output for Evo middle layers is a tuple (embedding, inference_params=None).
+ embeds = [x[0] if isinstance(x, tuple) else x for x in embeds]
+ embeds = [
+ pool(layer_embeds, batch_dict["attention_mask"], self.pool_type)
+ for layer_embeds in embeds
+ ]
+ # Stack with shape [B, num_layers, D].
+ embeds = torch.stack(embeds, dim=1)
+ embeds = embeds.to(torch.float32)
+ return embeds
+
+
+class NTModel(BioSeqTransformer):
+ """Nucleotide Transformer https://github.com/instadeepai/nucleotide-transformer"""
+
+ MODEL_NAMES = [
+ "InstaDeepAI/nucleotide-transformer-v2-50m-multi-species",
+ "InstaDeepAI/nucleotide-transformer-v2-100m-multi-species",
+ "InstaDeepAI/nucleotide-transformer-v2-250m-multi-species",
+ "InstaDeepAI/nucleotide-transformer-v2-500m-multi-species",
+ "InstaDeepAI/nucleotide-transformer-2.5b-multi-species",
+ ]
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.max_seq_length = self.tokenizer.model_max_length
+
+ @property
+ def modality(self) -> Modality:
+ return Modality.DNA
+
+ @property
+ def num_layers(self) -> int:
+ return self.config.num_hidden_layers
+
+ @property
+ def embed_dim(self) -> int:
+ return self.config.hidden_size
+
+ def _load_model(self, model_name):
+ return AutoModelForMaskedLM.from_pretrained(model_name, trust_remote_code=True)
diff --git a/geb/results.py b/geb/results.py
new file mode 100644
index 0000000..afa378f
--- /dev/null
+++ b/geb/results.py
@@ -0,0 +1,50 @@
+from typing import List
+from pydantic import BaseModel
+
+
+class Metric(BaseModel):
+ metric_id: str
+ display_name: str
+ display_value: float
+ description: str
+
+
+class TaskResult(BaseModel):
+ task_id: str
+ display_name: str
+ description: str
+ metrics: List[Metric]
+
+
+class TaskResults(BaseModel):
+ results_id: str
+ description: str
+ task_results: List[TaskResult]
+
+
+# Example task results that conform to the data specification above.
+mock_task_results = {
+ "results_id": "result_123",
+ "description": "Overall results of the tasks",
+ "task_results": [
+ {
+ "task_id": "task_1",
+ "display_name": "Task 1",
+ "description": "Description of Task 1",
+ "metrics": [
+ {
+ "metric_id": "metric_1",
+ "display_name": "Metric 1",
+ "display_value": "Value 1",
+ "description": "Description of Metric 1",
+ },
+ {
+ "metric_id": "metric_2",
+ "display_name": "Metric 2",
+ "display_value": "Value 2",
+ "description": "Description of Metric 2",
+ },
+ ],
+ },
+ ],
+}
diff --git a/geb/tasks/__init__.py b/geb/tasks/__init__.py
new file mode 100644
index 0000000..29bc042
--- /dev/null
+++ b/geb/tasks/__init__.py
@@ -0,0 +1,13 @@
+# ruff: noqa: F403
+
+from .tasks import Task
+from .eds_tasks import *
+from .pair_classification_tasks import *
+from .retrieval_tasks import *
+from .classification_tasks import *
+from .clustering_tasks import *
+from .bigene_mining_tasks import *
+
+__all__ = [
+ "Task",
+]
diff --git a/geb/tasks/bigene_mining_tasks.py b/geb/tasks/bigene_mining_tasks.py
new file mode 100644
index 0000000..1b7607e
--- /dev/null
+++ b/geb/tasks/bigene_mining_tasks.py
@@ -0,0 +1,77 @@
+"""
+Bigene mining tasks are analogous to bitext matching tasks, but for genes.
+Cosine similarity is used to mine genes with related functions across organisms.
+"""
+
+import logging
+from collections import defaultdict
+
+from geb.evaluators import BiGeneMiningEvaluator
+from geb.modality import Modality
+from geb.models import BioSeqTransformer
+from geb.tasks.tasks import Dataset, Task, TaskMetadata, TaskResult
+
+logger = logging.getLogger(__name__)
+
+
+def run_bigene_mining_tasks(
+ model: BioSeqTransformer, metadata: TaskMetadata, top_k: int = 1
+) -> TaskResult:
+ """Evaluate bigene mining task. Utilizes the BiGeneMiningEvaluator."""
+ if len(metadata.datasets) != 1:
+ raise ValueError("BiGeneMining tasks require 1 dataset.")
+ ds = metadata.datasets[0].load()["train"]
+ layer_results = defaultdict(dict)
+ embeds1 = model.encode(ds["Seq1"])
+ embeds2 = model.encode(ds["Seq2"])
+ for i, layer in enumerate(model.layers):
+ evaluator = BiGeneMiningEvaluator(embeds1[:, i], embeds2[:, i], top_k=top_k)
+ layer_results["layers"][layer] = evaluator()
+ logger.info(
+ f"Layer: {layer}, {metadata.display_name} matching results: {layer_results['layers'][layer]}"
+ )
+ return TaskResult.from_dict(metadata, layer_results, model.metadata)
+
+
+class BacArchBiGeneMining(Task):
+ metadata = TaskMetadata(
+ id="bacarch_bigene",
+ display_name="BacArch BiGene",
+ description="Evaluate on BacArch bigene matching task between bacterial (E.coli K-12) proteins and archaeal (Sulfolobus acidocaldarius DSM 639) proteins.",
+ type="bigene_mining",
+ modality=Modality.PROTEIN,
+ datasets=[
+ Dataset(
+ path="tattabio/bac_arch_bigene",
+ revision="main",
+ )
+ ],
+ primary_metric_id="f1",
+ )
+
+ def run(self, model: BioSeqTransformer) -> TaskResult:
+ return run_bigene_mining_tasks(model, self.metadata)
+
+
+class ModACParalogyBiGeneMining(Task):
+ # ModAC Paralogy matching with top_k=1 is too strict (most models have accuracy < 0.1%)
+ # Instead use recall@5 as the main metric.
+ TOP_K = 5
+
+ metadata = TaskMetadata(
+ id="modac_paralogy_bigene",
+ display_name="ModAC Paralogy BiGene",
+ description="Evaluate on paralogy bitext matching task between paralogous protein (ModA and ModC).",
+ type="bigene_mining",
+ modality=Modality.PROTEIN,
+ datasets=[
+ Dataset(
+ path="tattabio/modac_paralogy_bigene",
+ revision="main",
+ )
+ ],
+ primary_metric_id="recall_at_5",
+ )
+
+ def run(self, model: BioSeqTransformer) -> TaskResult:
+ return run_bigene_mining_tasks(model, self.metadata, top_k=self.TOP_K)
diff --git a/geb/tasks/classification_tasks.py b/geb/tasks/classification_tasks.py
new file mode 100644
index 0000000..99baac7
--- /dev/null
+++ b/geb/tasks/classification_tasks.py
@@ -0,0 +1,208 @@
+"""
+Classification tasks take in biological sequences and functional labels.
+Multi-class and/or multi-label classification tasks are supported.
+"""
+
+import logging
+from collections import defaultdict
+
+import datasets
+import numpy as np
+
+from geb.eval_utils import merge_split_elem_embeds
+from geb.evaluators import (
+ MultiClassMultiOutputKNNClassificationEvaluator,
+ logRegClassificationEvaluator,
+)
+from geb.modality import Modality
+from geb.models import BioSeqTransformer
+from geb.tasks.tasks import Dataset, Task, TaskMetadata, TaskResult
+
+logger = logging.getLogger(__name__)
+
+
+def split_sequences(
+ ds: datasets.DatasetDict, max_seq_length: int
+) -> datasets.DatasetDict:
+ """Split sequences into chunks of max_seq_length using datasets.Dataset.map()."""
+
+ def _split_sequence(examples, max_seq_length):
+ assert (
+ len(examples["Sequence"]) == 1
+ ), "split map function should use batch size of 1."
+ example = {k: v[0] for k, v in examples.items()}
+ seq = example["Sequence"]
+ # Split by chunks of max_seq_length.
+ seq_split = [
+ seq[i : i + max_seq_length] for i in range(0, len(seq), max_seq_length)
+ ]
+ # Repeat other fields by the number of splits.
+ example = {
+ k: [v] * len(seq_split) for k, v in example.items() if k != "Sequence"
+ }
+ example["Sequence"] = seq_split
+ return example
+
+ ds = ds.map(
+ _split_sequence,
+ batched=True,
+ batch_size=1,
+ fn_kwargs={"max_seq_length": max_seq_length},
+ keep_in_memory=True,
+ load_from_cache_file=False,
+ )
+ return ds
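+
+
+# For example, with max_seq_length=3, a row {"Sequence": "ATGCATG", "Entry": "x"}
+# becomes three rows with sequences "ATG", "CAT", and "G", each keeping Entry="x".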
+
+
+def run_classification_task(
+ model: BioSeqTransformer, metadata: TaskMetadata
+) -> TaskResult:
+ """Evaluate on classification tasks using logistic regression classifier."""
+ ds = metadata.datasets[0].load()
+ layer_results = defaultdict(dict)
+ train_embeds = model.encode(ds["train"]["Sequence"])
+ test_embeds = model.encode(ds["test"]["Sequence"])
+ for i, layer in enumerate(model.layers):
+ layer_results["layers"][layer] = logRegClassificationEvaluator(
+ train_embeds[:, i],
+ ds["train"]["Label"],
+ test_embeds[:, i],
+ ds["test"]["Label"],
+ )()
+ logger.info(
+ f"Layer: {layer}, {metadata.display_name} results: {layer_results['layers'][layer]}"
+ )
+ return TaskResult.from_dict(metadata, layer_results, model.metadata)
+
+
+class EnzymeCommissionClassification(Task):
+ metadata = TaskMetadata(
+ id="ec_classification",
+ display_name="EC Classification",
+ description="Evaluate on Enzyme Commission number classification task.",
+ type="classification",
+ modality=Modality.PROTEIN,
+ datasets=[
+ Dataset(
+ path="tattabio/ec_classification",
+ revision="main",
+ )
+ ],
+ primary_metric_id="f1",
+ )
+
+ def run(self, model: BioSeqTransformer) -> TaskResult:
+ return run_classification_task(model, self.metadata)
+
+
+class EnzymeCommissionDNAClassification(Task):
+ metadata = TaskMetadata(
+ id="ec_dna_classification",
+ display_name="EC Classification",
+ description="Evaluate on Enzyme Commission number classification task using DNA sequences.",
+ type="classification",
+ modality=Modality.DNA,
+ datasets=[
+ Dataset(
+ path="tattabio/ec_classification_dna",
+ revision="main",
+ )
+ ],
+ primary_metric_id="f1",
+ )
+
+ def run(self, model: BioSeqTransformer) -> TaskResult:
+ return run_classification_task(model, self.metadata)
+
+
+class ConvergentEnzymesClassification(Task):
+ metadata = TaskMetadata(
+ id="convergent_enzymes_classification",
+ display_name="Convergent Enzymes Classification",
+ description="Evaluate on convergent enzymes classification task, where convergent enzymes are proteins with the same EC number but without blastp hits against each other",
+ type="classification",
+ modality=Modality.PROTEIN,
+ datasets=[
+ Dataset(
+ path="tattabio/convergent_enzymes",
+ revision="main",
+ )
+ ],
+ primary_metric_id="f1",
+ )
+
+ def run(self, model: BioSeqTransformer) -> TaskResult:
+ return run_classification_task(model, self.metadata)
+
+
+def run_mibig_task(model: BioSeqTransformer, metadata: TaskMetadata) -> TaskResult:
+ """
+ Evaluate on MIBIG classification tasks. Multiclass, multi-label KNN classification is used for evaluation.
+ """
+ ds = metadata.datasets[0].load()
+ if metadata.modality == Modality.DNA:
+ # MIBiG DNA sequences can be very long. Instead of truncating to max_seq_length,
+ # split into multiple sequences and mean pool the resulting embeddings.
+ ds = split_sequences(ds, model.max_seq_length)
+
+ layer_results = defaultdict(dict)
+ train_embeds = model.encode(ds["train"]["Sequence"])
+ test_embeds = model.encode(ds["test"]["Sequence"])
+
+ train_ids = ds["train"]["Entry"]
+ test_ids = ds["test"]["Entry"]
+ train_labels = ds["train"]["class"]
+ test_labels = ds["test"]["class"]
+ train_id_to_label = {id: label for id, label in zip(train_ids, train_labels)}
+ test_id_to_label = {id: label for id, label in zip(test_ids, test_labels)}
+ # Mean pool embeds with the same ID.
+ train_ids, train_embeds = merge_split_elem_embeds(train_ids, train_embeds)
+ test_ids, test_embeds = merge_split_elem_embeds(test_ids, test_embeds)
+ # Gather the labels after merging by unique ID.
+ train_labels = np.array([train_id_to_label[id] for id in train_ids])
+ test_labels = np.array([test_id_to_label[id] for id in test_ids])
+
+ for i, layer in enumerate(model.layers):
+ evaluator = MultiClassMultiOutputKNNClassificationEvaluator(
+ train_embeds[:, i], train_labels, test_embeds[:, i], test_labels
+ )
+ layer_results["layers"][layer] = evaluator()
+ logger.info(
+ f"Layer: {layer}, MIBiG classification results: {layer_results['layers'][layer]}"
+ )
+ return TaskResult.from_dict(metadata, layer_results, model.metadata)
+
+
+class MIBiGProteinClassification(Task):
+ metadata = TaskMetadata(
+ id="MIBIG_protein_classification",
+ display_name="MIBiG Classification",
+ description="Biosynthetic Gene cluster classification using protein sequences on MIBIG dataset.",
+ type="classification",
+ modality=Modality.PROTEIN,
+ datasets=[Dataset(path="tattabio/mibig_classification_prot", revision="main")],
+ primary_metric_id="f1",
+ )
+
+ def run(self, model: BioSeqTransformer) -> TaskResult:
+ return run_mibig_task(model, self.metadata)
+
+
+class MIBiGDNAClassification(Task):
+ metadata = TaskMetadata(
+ id="MIBIG_dna_classification",
+ display_name="MIBiG Classification",
+ description="Biosynthetic Gene cluster classification using DNA sequences on MIBIG dataset.",
+ type="classification",
+ modality=Modality.DNA,
+ datasets=[
+ Dataset(
+ path="tattabio/mibig_classification_dna",
+ revision="main",
+ )
+ ],
+ primary_metric_id="f1",
+ )
+
+ def run(self, model: BioSeqTransformer) -> TaskResult:
+ return run_mibig_task(model, self.metadata)
diff --git a/geb/tasks/clustering_tasks.py b/geb/tasks/clustering_tasks.py
new file mode 100644
index 0000000..c12ddad
--- /dev/null
+++ b/geb/tasks/clustering_tasks.py
@@ -0,0 +1,70 @@
+"""
+Biological sequences are clustered, and performance is measured by how well the clustering matches the assigned labels.
+"""
+
+import logging
+from collections import defaultdict
+
+from geb.evaluators import ClusteringEvaluator
+from geb.modality import Modality
+from geb.models import BioSeqTransformer
+from geb.tasks.tasks import Dataset, Task, TaskMetadata, TaskResult
+
+logger = logging.getLogger(__name__)
+
+
+def run_clustering_task(model: BioSeqTransformer, metadata: TaskMetadata) -> TaskResult:
+ """Evaluate clustering task. Utilizes the ClusteringEvaluator."""
+ if len(metadata.datasets) != 1:
+ raise ValueError("Clustering tasks require 1 dataset.")
+ ds = metadata.datasets[0].load()["train"]
+ embeds = model.encode(ds["Sequence"])
+ layer_results = defaultdict(dict)
+ for i, layer in enumerate(model.layers):
+ labels = ds["Label"]
+ evaluator = ClusteringEvaluator(embeds[:, i], labels)
+ layer_results["layers"][layer] = evaluator()
+ logger.info(
+ f"Layer: {layer}, {metadata.display_name} results: {layer_results['layers'][layer]}"
+ )
+ return TaskResult.from_dict(metadata, layer_results, model.metadata)
+
+
+class RNAclustering(Task):
+ metadata = TaskMetadata(
+ id="ecoli_rna_clustering",
+ display_name="E.coli RNA Clustering",
+ description="Evaluate on RNA clustering task for sRNA/tRNA/rRNA segments in E.coli K-12.",
+ type="clustering",
+ modality=Modality.DNA,
+ datasets=[
+ Dataset(
+ path="tattabio/e_coli_rnas",
+ revision="main",
+ )
+ ],
+ primary_metric_id="v_measure",
+ )
+
+ def run(self, model: BioSeqTransformer) -> TaskResult:
+ return run_clustering_task(model, self.metadata)
+
+
+class MopBClustering(Task):
+ metadata = TaskMetadata(
+ id="mopb_clustering",
+ display_name="MopB Clustering",
+ description="Evaluate on MopB clustering task.",
+ type="clustering",
+ modality=Modality.PROTEIN,
+ datasets=[
+ Dataset(
+ path="tattabio/mopb_clustering",
+ revision="main",
+ )
+ ],
+ primary_metric_id="v_measure",
+ )
+
+ def run(self, model: BioSeqTransformer) -> TaskResult:
+ return run_clustering_task(model, self.metadata)
diff --git a/geb/tasks/eds_tasks.py b/geb/tasks/eds_tasks.py
new file mode 100644
index 0000000..849bed9
--- /dev/null
+++ b/geb/tasks/eds_tasks.py
@@ -0,0 +1,198 @@
+"""
+Evolutionary Distance Similarity (EDS) tasks compare embedding distances to continuous evolutionary distances.
+The label distances are typically derived from phylogenetic trees.
+"""
+
+import logging
+from collections import defaultdict
+
+import numpy as np
+import pandas as pd
+
+from geb.evaluators import EDSEvaluator
+from geb.modality import Modality
+from geb.models import BioSeqTransformer
+from geb.tasks.tasks import Dataset, Task, TaskMetadata, TaskResult
+
+logger = logging.getLogger(__name__)
+
+
+def run_eds_task(model: BioSeqTransformer, metadata: TaskMetadata) -> TaskResult:
+ """Evaluate phylogeny distance correlation task. Utilizes the Evolutionary Distance Similarity (EDS) evaluator."""
+ if len(metadata.datasets) != 2:
+ raise ValueError("Phylogeny tasks require 2 datasets: sequences and distances.")
+
+ ds = metadata.datasets[0].load()["train"]
+ distance_df = metadata.datasets[1].load()["train"].to_pandas()
+ assert isinstance(
+ distance_df, pd.DataFrame
+ ), f"Expected DataFrame, got {type(distance_df)}"
+
+ id_index_dict = {k: i for i, k in enumerate(ds["Entry"])}
+ distance_df["embeds1"] = None
+ distance_df["embeds2"] = None
+ test_embeds = model.encode(ds["Sequence"])
+ layer_results = defaultdict(dict)
+ for i, layer in enumerate(model.layers):
+ for row_idx, row in distance_df.iterrows():
+ id1 = row["ID1"]
+ id2 = row["ID2"]
+ embedding1 = test_embeds[id_index_dict[id1], i]
+ embedding2 = test_embeds[id_index_dict[id2], i]
+ distance_df.at[row_idx, "embeds1"] = embedding1
+ distance_df.at[row_idx, "embeds2"] = embedding2
+ embeds1 = np.array(distance_df["embeds1"].tolist())
+ embeds2 = np.array(distance_df["embeds2"].tolist())
+ dists = np.array(distance_df["distance"].tolist())
+ evaluator = EDSEvaluator(embeds1, embeds2, dists)
+ layer_results["layers"][layer] = evaluator()
+ # log results
+ logger.info(
+ f"Layer: {layer}, {metadata.display_name} distance correlation results: {layer_results['layers'][layer]}"
+ )
+
+ return TaskResult.from_dict(metadata, layer_results, model.metadata)
+
+
+class RpobBacPhylogeny(Task):
+ metadata = TaskMetadata(
+ id="rpob_bac_phylogeny",
+ display_name="RpoB Bacterial Phylogeny",
+ description="Evaluate on RpoB phylogeny distance correlation task for Bacterial sequences.",
+ type="eds",
+ modality=Modality.PROTEIN,
+ datasets=[
+ Dataset(
+ path="tattabio/rpob_bac_phylogeny_sequences",
+ revision="main",
+ ),
+ Dataset(
+ path="tattabio/rpob_bac_phylogeny_distances",
+ revision="main",
+ ),
+ ],
+ primary_metric_id="top_corr",
+ )
+
+ def run(self, model: BioSeqTransformer) -> TaskResult:
+ return run_eds_task(model, self.metadata)
+
+
+class RpobArchPhylogeny(Task):
+ metadata = TaskMetadata(
+ id="rpob_arch_phylogeny",
+ display_name="RpoB Archaeal Phylogeny",
+ description="Evaluate on RpoB phylogeny distance correlation task for Archaeal sequences.",
+ type="eds",
+ modality=Modality.PROTEIN,
+ datasets=[
+ Dataset(
+ path="tattabio/rpob_arch_phylogeny_sequences",
+ revision="main",
+ ),
+ Dataset(
+ path="tattabio/rpob_arch_phylogeny_distances",
+ revision="main",
+ ),
+ ],
+ primary_metric_id="top_corr",
+ )
+
+ def run(self, model: BioSeqTransformer) -> TaskResult:
+ return run_eds_task(model, self.metadata)
+
+
+class FeFePhylogeny(Task):
+ metadata = TaskMetadata(
+ id="fefe_phylogeny",
+ display_name="FeFeHydrogenase Phylogeny",
+ description="Evaluate on FeFeHydrogenase phylogeny distance correlation task.",
+ type="eds",
+ modality=Modality.PROTEIN,
+ datasets=[
+ Dataset(
+ path="tattabio/fefe_phylogeny_sequences",
+ revision="main",
+ ),
+ Dataset(
+ path="tattabio/fefe_phylogeny_distances",
+ revision="main",
+ ),
+ ],
+ primary_metric_id="top_corr",
+ )
+
+ def run(self, model: BioSeqTransformer) -> TaskResult:
+ return run_eds_task(model, self.metadata)
+
+
+class Bac16SPhylogeny(Task):
+ metadata = TaskMetadata(
+ id="bac_16S_phylogeny",
+ display_name="16S Bacterial Phylogeny",
+ description="Evaluate on 16S Bacterial phylogeny distance correlation task.",
+ type="eds",
+ modality=Modality.DNA,
+ datasets=[
+ Dataset(
+ path="tattabio/bac_16S_sequences",
+ revision="main",
+ ),
+ Dataset(
+ path="tattabio/bac_16S_distances",
+ revision="main",
+ ),
+ ],
+ primary_metric_id="top_corr",
+ )
+
+ def run(self, model: BioSeqTransformer) -> TaskResult:
+ return run_eds_task(model, self.metadata)
+
+
+class Arch16SPhylogeny(Task):
+ metadata = TaskMetadata(
+ id="arch_16S_phylogeny",
+ display_name="16S Archaeal Phylogeny",
+ description="Evaluate on 16S Archaeal phylogeny distance correlation task.",
+ type="eds",
+ modality=Modality.DNA,
+ datasets=[
+ Dataset(
+ path="tattabio/arch_16S_sequences",
+ revision="main",
+ ),
+ Dataset(
+ path="tattabio/arch_16S_distances",
+ revision="main",
+ ),
+ ],
+ primary_metric_id="top_corr",
+ )
+
+ def run(self, model: BioSeqTransformer) -> TaskResult:
+ return run_eds_task(model, self.metadata)
+
+
+class Euk18SPhylogeny(Task):
+ metadata = TaskMetadata(
+ id="euk_18S_phylogeny",
+ display_name="18S Eukaryotic Phylogeny",
+ description="Evaluate on 18S Eukaryotic phylogeny distance correlation task.",
+ type="eds",
+ modality=Modality.DNA,
+ datasets=[
+ Dataset(
+ path="tattabio/euk_18S_sequences",
+ revision="main",
+ ),
+ Dataset(
+ path="tattabio/euk_18S_distances",
+ revision="main",
+ ),
+ ],
+ primary_metric_id="top_corr",
+ )
+
+ def run(self, model: BioSeqTransformer) -> TaskResult:
+ return run_eds_task(model, self.metadata)
diff --git a/geb/tasks/pair_classification_tasks.py b/geb/tasks/pair_classification_tasks.py
new file mode 100644
index 0000000..316b201
--- /dev/null
+++ b/geb/tasks/pair_classification_tasks.py
@@ -0,0 +1,96 @@
+"""
+Pair classification tasks evaluating distances between functionally relevant gene pairs.
+For instance, distance thresholds distinguish between co-transcribed and non-co-transcribed gene pairs.
+"""
+
+import logging
+from collections import defaultdict
+
+from geb.evaluators import PairClassificationEvaluator
+from geb.modality import Modality
+from geb.models import BioSeqTransformer
+from geb.tasks.tasks import Dataset, Task, TaskMetadata, TaskResult
+
+from ..eval_utils import paired_dataset
+
+logger = logging.getLogger(__name__)
+
+
+def run_pair_classification_task(
+ model: BioSeqTransformer, metadata: TaskMetadata
+) -> TaskResult:
+ """Evaluate pair classification task. Utilizes the PairClassificationEvaluator."""
+ if len(metadata.datasets) != 1:
+ raise ValueError("Pair classification tasks require 1 dataset.")
+ ds = metadata.datasets[0].load()["train"]
+ embeds = model.encode(ds["Sequence"])
+ layer_results = defaultdict(dict)
+ for i, layer in enumerate(model.layers):
+ labels = ds["Label"]
+ embeds1, embeds2, labels = paired_dataset(labels, embeds[:, i])
+ evaluator = PairClassificationEvaluator(embeds1, embeds2, labels)
+ layer_results["layers"][layer] = evaluator()
+ logger.info(
+ f"Layer: {layer}, {metadata.display_name} classification results: {layer_results['layers'][layer]}"
+ )
+ return TaskResult.from_dict(metadata, layer_results, model.metadata)
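+
+
+# `paired_dataset` (from eval_utils, not shown in this diff) is assumed to turn
+# the labeled embeddings into (embeds1, embeds2, label) pairs, which
+# PairClassificationEvaluator then scores with distance-threshold metrics.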
+
+
+class EcoliOperon(Task):
+ metadata = TaskMetadata(
+ id="ecoli_operonic_pair",
+ display_name="E.coli Operonic Pair",
+ description="Evaluate on E.coli K-12 operonic pair classification task.",
+ type="pair_classification",
+ modality=Modality.PROTEIN,
+ datasets=[
+ Dataset(
+ path="tattabio/ecoli_operonic_pair",
+ revision="main",
+ )
+ ],
+ primary_metric_id="top_ap",
+ )
+
+ def run(self, model: BioSeqTransformer) -> TaskResult:
+ return run_pair_classification_task(model, self.metadata)
+
+
+class CyanoOperonPair(Task):
+ metadata = TaskMetadata(
+ id="cyano_operonic_pair",
+ display_name="Cyano Operonic Pair",
+ description="Evaluate on Cyano operonic pair classification task.",
+ type="pair_classification",
+ modality=Modality.PROTEIN,
+ datasets=[
+ Dataset(
+ path="tattabio/cyano_operonic_pair",
+ revision="main",
+ )
+ ],
+ primary_metric_id="top_ap",
+ )
+
+ def run(self, model: BioSeqTransformer) -> TaskResult:
+ return run_pair_classification_task(model, self.metadata)
+
+
+class VibrioOperonPair(Task):
+ metadata = TaskMetadata(
+ id="vibrio_operonic_pair",
+ display_name="Vibrio Operonic Pair",
+ description="Evaluate on Vibrio operonic pair classification task.",
+ type="pair_classification",
+ modality=Modality.PROTEIN,
+ datasets=[
+ Dataset(
+ path="tattabio/vibrio_operonic_pair",
+ revision="main",
+ )
+ ],
+ primary_metric_id="top_ap",
+ )
+
+ def run(self, model: BioSeqTransformer) -> TaskResult:
+ return run_pair_classification_task(model, self.metadata)
diff --git a/geb/tasks/retrieval_tasks.py b/geb/tasks/retrieval_tasks.py
new file mode 100644
index 0000000..132754e
--- /dev/null
+++ b/geb/tasks/retrieval_tasks.py
@@ -0,0 +1,98 @@
+"""
+Retrieval tasks find functionally relevant genes in a corpus of genes based on a query gene.
+Typically, the corpus is derived from a different phylogenetic group than the query genes.
+"""
+
+import logging
+from collections import defaultdict
+
+from geb.evaluators import RetrievalEvaluator
+from geb.modality import Modality
+from geb.models import BioSeqTransformer
+from geb.tasks.tasks import Dataset, Task, TaskMetadata, TaskResult
+
+logger = logging.getLogger(__name__)
+
+
+def run_retrieval_task(model: BioSeqTransformer, metadata: TaskMetadata) -> TaskResult:
+ """Evaluate retrieval task. Utilizes the Retrieval evaluator."""
+    if len(metadata.datasets) != 2:
+        raise ValueError(
+            "Retrieval tasks require 2 datasets: sequences (corpus/query splits) and qrels."
+        )
+    sequence_ds = metadata.datasets[0].load()
+    corpus_ds = sequence_ds["train"]
+    query_ds = sequence_ds["test"]
+    qrels = metadata.datasets[1].load()
+ corpus_embeds = model.encode(corpus_ds["Sequence"])
+ query_embeds = model.encode(query_ds["Sequence"])
+ qrels_dict = defaultdict(dict)
+
+ def qrels_dict_init(row):
+ qrels_dict[str(row["query_id"])][str(row["corpus_id"])] = int(row["fuzz_ratio"])
+
+ # Populate `qrels_dict` from the dataset.
+ # See https://github.com/cvangysel/pytrec_eval for qrels format.
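+    # The resulting dict looks like {"Q1": {"C7": 83, "C9": 55}, ...}: one
+    # integer relevance score per (query, corpus) pair, here taken from the
+    # dataset's `fuzz_ratio` column.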
+ qrels.map(qrels_dict_init)
+ qrels = qrels_dict
+ layer_results = defaultdict(dict)
+ for i, layer in enumerate(model.layers):
+ evaluator = RetrievalEvaluator(
+ corpus_embeds[:, i],
+ query_embeds[:, i],
+ corpus_ds["Entry"],
+ query_ds["Entry"],
+ qrels,
+ )
+ layer_results["layers"][layer] = evaluator()
+ logger.info(
+ f"Layer: {layer}, Retrieval results: {layer_results['layers'][layer]}"
+ )
+ return TaskResult.from_dict(metadata, layer_results, model.metadata)
+
+
+class ArchRetrieval(Task):
+ metadata = TaskMetadata(
+ id="arch_retrieval",
+ display_name="Arch Retrieval",
+ description="Retrieves bacterial proteins with similar swissprot annotations to a query archaeal protein",
+ type="retrieval",
+ modality=Modality.PROTEIN,
+ datasets=[
+ Dataset(
+ path="tattabio/arch_retrieval",
+ revision="main",
+ ),
+ Dataset(
+ path="tattabio/arch_retrieval_qrels",
+ description="Relevance between query and corpus proteins",
+ revision="main",
+ ),
+ ],
+ primary_metric_id="map_at_5",
+ )
+
+ def run(self, model: BioSeqTransformer) -> TaskResult:
+ return run_retrieval_task(model, self.metadata)
+
+
+class EukRetrieval(Task):
+ metadata = TaskMetadata(
+ id="euk_retrieval",
+ display_name="Euk Retrieval",
+ description="Retrieves bacterial proteins with similar swissprot annotations to a query eukaryotic protein",
+ type="retrieval",
+ modality=Modality.PROTEIN,
+ datasets=[
+ Dataset(
+ path="tattabio/euk_retrieval",
+ revision="main",
+ ),
+ Dataset(
+ path="tattabio/euk_retrieval_qrels",
+ description="Relevance between query and corpus proteins",
+ revision="main",
+ ),
+ ],
+ primary_metric_id="map_at_5",
+ )
+
+ def run(self, model: BioSeqTransformer) -> TaskResult:
+ return run_retrieval_task(model, self.metadata)
diff --git a/geb/tasks/tasks.py b/geb/tasks/tasks.py
new file mode 100644
index 0000000..6f7dc76
--- /dev/null
+++ b/geb/tasks/tasks.py
@@ -0,0 +1,136 @@
+"""Task functions for evaluation.
+# TODO: Add dataset revisions.
+"""
+
+import logging
+from abc import ABC, abstractmethod
+from typing import Any, Dict, List, Literal, Optional
+
+import datasets
+from pydantic import BaseModel, model_validator
+
+from ..modality import Modality
+from ..models import BioSeqTransformer
+
+logging.basicConfig(level=logging.INFO)
+
+TaskType = Literal[
+ "classification",
+ "pair_classification",
+ "clustering",
+ "eds",
+ "bigene_mining",
+ "retrieval",
+]
+
+
+class TaskMetric(BaseModel):
+ id: str
+ display_name: str
+ description: Optional[str] = None
+ value: float = 0.0
+
+
+class LayerResult(BaseModel):
+ layer_number: int
+ layer_display_name: str
+ metrics: List[TaskMetric]
+
+
+class TaskResult(BaseModel):
+ task: "TaskMetadata"
+ # TODO: Convert model to ModelMetadata
+ model: Dict[str, Any]
+ results: List[LayerResult]
+
+ @model_validator(mode="after")
+ def check_valid_primary_metric(self):
+ for result in self.results:
+ if all(
+ metric.id != self.task.primary_metric_id for metric in result.metrics
+ ):
+ raise ValueError(
+ f"Primary metric {self.task.primary_metric_id} not found in results.metrics"
+ )
+ return self
+
+ @staticmethod
+ def from_dict(
+ task_metadata: "TaskMetadata",
+ layer_results: Dict[str, Any],
+ model_metadata: Dict[str, Any],
+ ):
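+        # `layer_results` is expected to look like
+        # {"layers": {32: {"f1": 0.5, "accuracy": 0.5}, ...}},
+        # i.e. one metric dict per probed layer.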
+ return TaskResult(
+ task=task_metadata,
+ model=model_metadata,
+ results=list(
+ LayerResult(
+ layer_number=int(layer),
+ layer_display_name=str(layer),
+ metrics=[
+ TaskMetric(id=metric, display_name=metric, value=value)
+ for metric, value in metrics.items()
+ ],
+ )
+ for layer, metrics in layer_results["layers"].items()
+ ),
+ )
+
+
+class Dataset(BaseModel):
+ path: str
+ revision: str
+
+ def load(self) -> datasets.DatasetDict:
+ ds = datasets.load_dataset(self.path, revision=self.revision)
+ if not isinstance(ds, datasets.DatasetDict):
+ raise ValueError(
+ f"Dataset {self.path} is not a datasets.DatasetDict object."
+ )
+ return ds
+
+
+class TaskMetadata(BaseModel):
+ id: str
+ display_name: str
+ description: str
+ modality: Modality
+ type: TaskType
+ # List of datasets used by the task.
+    # Each entry holds the arguments passed to `datasets.load_dataset()`.
+ datasets: List[Dataset]
+ primary_metric_id: str
+
+
+class Task(ABC):
+ metadata: TaskMetadata
+
+ @abstractmethod
+    def run(self, model: BioSeqTransformer) -> TaskResult:
+        pass
+
+
+class noop(Task):
+ metadata = TaskMetadata(
+ id="noop",
+ display_name="NoOp Task",
+ description="This task is used for testing and does nothing.",
+ type="classification",
+ modality=Modality.PROTEIN,
+ datasets=[
+ Dataset(
+ path="",
+ revision="main",
+ )
+ ],
+ primary_metric_id="f1",
+ )
+
+ def run(self, model: BioSeqTransformer) -> TaskResult:
+ return TaskResult.from_dict(
+ self.metadata,
+ {"layers": {32: {"accuracy": 0.5, "f1": 0.5}}},
+ model.metadata,
+ )
diff --git a/plot_benchmarks.py b/plot_benchmarks.py
new file mode 100644
index 0000000..1d75ceb
--- /dev/null
+++ b/plot_benchmarks.py
@@ -0,0 +1,152 @@
+"""
+Given a directory of results, plot the benchmarks for each task as a bar chart and line chart.
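+
+Example:
+    python plot_benchmarks.py -d results -t ec_classification,mopb_clustering -o benchmarks.png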
+"""
+
+import argparse
+import os
+from typing import Optional
+
+import matplotlib.pyplot as plt
+import pandas as pd
+import seaborn as sns
+
+from geb.geb import get_all_tasks, get_output_folder, get_tasks_by_name
+from geb.tasks.tasks import TaskResult
+
+ALL_TASKS = [task.metadata.id for task in get_all_tasks()]
+
+
+def plot_benchmarks(
+ results_dir,
+ task_ids: Optional[list[str]] = None,
+ output="benchmarks.png",
+ model_substring=None,
+):
+ models = os.listdir(results_dir)
+ all_results = []
+ tasks = get_all_tasks() if task_ids is None else get_tasks_by_name(task_ids)
+ for model_name in models:
+ if model_substring is not None and all(
+ substr not in model_name for substr in model_substring
+ ):
+ continue
+
+ for task in tasks:
+ if task.metadata.display_name == "NoOp Task":
+ continue
+ filepath = get_output_folder(model_name, task, results_dir, create=False)
+ # if the file does not exist, skip
+ if not os.path.exists(filepath):
+ continue
+
+ with open(filepath) as f:
+ task_result = TaskResult.model_validate_json(f.read())
+ num_params = task_result.model["num_params"]
+ primary_metric_id = task_result.task.primary_metric_id
+ main_scores = [
+ metric.value
+ for layer_result in task_result.results
+ for metric in layer_result.metrics
+ if metric.id == primary_metric_id
+ ]
+ best_score = max(main_scores)
+ all_results.append(
+ {
+ "task": task.metadata.display_name,
+ "model": model_name,
+ "num_params": num_params,
+ "score": best_score,
+ }
+ )
+
+ results_df = pd.DataFrame(all_results)
+ # order the models by ascending number of parameters
+ results_df["num_params"] = results_df["num_params"].astype(int)
+ results_df = results_df.sort_values(by="num_params")
+ # number of tasks
+ n_tasks = len(set(results_df["task"]))
+
+ _, ax = plt.subplots(2, n_tasks, figsize=(5 * n_tasks, 10))
+
+ for i, task in enumerate(set(results_df["task"])):
+ if n_tasks > 1:
+ sns.barplot(
+ x="model",
+ y="score",
+ data=results_df[results_df["task"] == task],
+ ax=ax[0][i],
+ )
+ ax[0][i].set_title(task)
+ # rotate the x axis labels
+
+ for tick in ax[0][i].get_xticklabels():
+ tick.set_rotation(90)
+ else:
+ sns.barplot(
+ x="model",
+ y="score",
+ data=results_df[results_df["task"] == task],
+ ax=ax[0],
+ )
+ ax[0].set_title(task)
+ # rotate the x axis labels
+ for tick in ax[0].get_xticklabels():
+ tick.set_rotation(90)
+
+ # make a line graph with number of parameters on x axis for each task in the second row of figures
+ for i, task in enumerate(set(results_df["task"])):
+ if n_tasks > 1:
+ sns.lineplot(
+ x="num_params",
+ y="score",
+ data=results_df[results_df["task"] == task],
+ ax=ax[1][i],
+ )
+ ax[1][i].set_title(task)
+ ax[1][i].set_xlabel("Number of parameters")
+ else:
+ sns.lineplot(
+ x="num_params",
+ y="score",
+ data=results_df[results_df["task"] == task],
+ ax=ax[1],
+ )
+ ax[1].set_title(task)
+ ax[1].set_xlabel("Number of parameters")
+
+ plt.tight_layout()
+ plt.savefig(output)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "-d",
+ "--results_dir",
+ type=str,
+ default="results",
+ help="Directory containing the results of the benchmarking",
+ )
+ parser.add_argument(
+ "-t",
+ "--tasks",
+ type=lambda s: [item for item in s.split(",")],
+ default=None,
+ help=f"Comma separated list of tasks to plot. Choose from {ALL_TASKS} or do not specify to plot all tasks. ",
+ )
+ parser.add_argument(
+ "-o",
+ "--output",
+ type=str,
+ default="benchmarks.png",
+ help="Output file for the plot",
+ )
+ parser.add_argument(
+ "--model_substring",
+ type=lambda s: [item for item in s.split(",")],
+ default=None,
+ help="Comma separated list of model substrings. Only plot results for models containing this substring",
+ )
+ args = parser.parse_args()
+
+ plot_benchmarks(args.results_dir, args.tasks, args.output, args.model_substring)
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 0000000..47b60d3
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,113 @@
+[build-system]
+requires = ["setuptools>=42", "wheel"]
+build-backend = "setuptools.build_meta"
+
+[project]
+name = "dgeb"
+version = "0.0.0"
+description = "Diverse Genomic Embedding Benchmark"
+readme = "README.md"
+license = { file = "LICENSE" }
+keywords = ["scientific software", "genomic embeddings", "machine learning", "benchmark"]
+classifiers = [
+ "Development Status :: 2 - Pre-Alpha",
+ "Environment :: Console",
+ "Intended Audience :: Developers",
+ "Intended Audience :: Information Technology",
+ "Intended Audience :: Science/Research",
+ "License :: OSI Approved :: Apache Software License",
+ "Operating System :: OS Independent",
+ "Programming Language :: Python",
+]
+dependencies = [
+ "datasets>=2.20.0",
+ "matplotlib>=3.9.0",
+ "numpy>=2.0.0",
+ "pandas>=2.2.2",
+ "pydantic>=2.7.4",
+ "pytrec_eval>=0.5",
+ "rich>=13.7.1",
+ "scikit_learn>=1.5.0",
+ "scipy>=1.13.1",
+ "seaborn>=0.13.2",
+ "torch>=2.3.1",
+ "tqdm>=4.66.4",
+ "transformers>=4.41.2"
+]
+
+[project.urls]
+homepage = "https://github.com/TattaBio/DGEB"
+"Huggingface Organization" = "https://huggingface.co/tattabio"
+"Source Code" = "https://github.com/TattaBio/DGEB"
+
+[project.optional-dependencies]
+dev = ["ruff>=0.0.254", "pytest", "pytest-xdist"]
+
+[tool.setuptools.packages.find]
+exclude = ["tests", "results"]
+
+[tool.setuptools.package-data]
+"*" = ["*.json"]
+
+[tool.ruff]
+target-version = "py310"
+exclude = [
+ ".venv",
+ "build/"
+]
+line-length = 88
+indent-width = 4
+
+[tool.semantic_release]
+version_toml = ["pyproject.toml:project.version"]
+build_command = "python -m pip install build; python -m build"
+commit_message = "{version}\n\nAutomatically generated by python-semantic-release"
+logging_use_named_masks = false
+major_on_zero = true
+allow_zero_version = true
+no_git_verify = false
+tag_format = "v{version}"
+
+[tool.semantic_release.branches.main]
+match = "(main|master)"
+prerelease_token = "rc"
+prerelease = false
+
+[tool.semantic_release.changelog]
+template_dir = "templates"
+changelog_file = "CHANGELOG.md"
+exclude_commit_patterns = []
+
+[tool.semantic_release.changelog.environment]
+block_start_string = "{%"
+block_end_string = "%}"
+variable_start_string = "{{"
+variable_end_string = "}}"
+comment_start_string = "{#"
+comment_end_string = "#}"
+trim_blocks = false
+lstrip_blocks = false
+newline_sequence = "\n"
+keep_trailing_newline = false
+extensions = []
+autoescape = true
+
+[tool.semantic_release.commit_author]
+env = "GIT_COMMIT_AUTHOR"
+default = "semantic-release "
+
+[tool.semantic_release.commit_parser_options]
+allowed_tags = ["build", "chore", "ci", "docs", "feat", "fix", "perf", "style", "refactor", "test"]
+minor_tags = ["feat"]
+patch_tags = ["fix", "perf"]
+default_bump_level = 0
+
+[tool.semantic_release.remote]
+name = "origin"
+type = "github"
+ignore_token_for_push = false
+insecure = false
+
+[tool.semantic_release.publish]
+dist_glob_patterns = ["dist/*"]
+upload_to_vcs_release = true
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..6bf7651
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,13 @@
+datasets>=2.20.0
+matplotlib>=3.9.0
+numpy>=2.0.0
+pandas>=2.2.2
+pydantic>=2.7.4
+pytrec_eval>=0.5
+rich>=13.7.1
+scikit_learn>=1.5.0
+scipy>=1.13.1
+seaborn>=0.13.2
+torch>=2.3.1
+tqdm>=4.66.4
+transformers>=4.41.2
diff --git a/ruff.toml b/ruff.toml
new file mode 100644
index 0000000..5ff69db
--- /dev/null
+++ b/ruff.toml
@@ -0,0 +1,8 @@
+exclude = [
+ ".venv",
+ "build/",
+]
+# Same as Black.
+line-length = 88
+indent-width = 4
+
diff --git a/run_geb.py b/run_geb.py
new file mode 100644
index 0000000..ffd7767
--- /dev/null
+++ b/run_geb.py
@@ -0,0 +1,135 @@
+"""
+Main command to run genomic embedding benchmarks (GEB) on a model.
+Example command:
+python run_geb.py -m facebook/esm2_t6_8M_UR50D
+"""
+
+import argparse
+import logging
+import os
+import geb
+
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+ALL_TASK_NAMES = geb.get_all_task_names()
+ALL_MODEL_NAMES = geb.get_all_model_names()
+
+
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "-m",
+ "--model",
+ type=str,
+ default=None,
+ help=f"Model to evaluate. Choose from {ALL_MODEL_NAMES}",
+ )
+ parser.add_argument(
+ "-t",
+ "--tasks",
+ type=lambda s: [item for item in s.split(",")],
+ default=None,
+ help=f"Comma separated tasks to evaluate on. Choose from {ALL_TASK_NAMES} or do not specify to evaluate on all tasks",
+ )
+ parser.add_argument(
+ "-l",
+ "--layers",
+ type=str,
+ default=None,
+ help="Layer to evaluate. Comma separated list of integers or 'mid' and 'last'. Default is 'mid,last'",
+ )
+ parser.add_argument(
+ "--devices",
+ type=str,
+ default="0",
+ help="Comma separated list of GPU device ids to use. Default is 0 (if GPUs are detected).",
+ )
+ parser.add_argument(
+ "--output_folder",
+ type=str,
+ default=None,
+ help="Output directory for results. Will default to results/model_name if not set.",
+ )
+ parser.add_argument(
+ "-v", "--verbosity", type=int, default=2, help="Verbosity level"
+ )
+ parser.add_argument(
+ "-b", "--batch_size", type=int, default=64, help="Batch size for evaluation"
+ )
+ parser.add_argument(
+ "--max_seq_len",
+ type=int,
+ default=1024,
+ help="Maximum sequence length for model, default is 1024.",
+ )
+ parser.add_argument(
+ "--pool_type",
+ type=str,
+ default="mean",
+ help="Pooling type for model, choose from mean, max, cls, last. Default is mean.",
+ )
+
+ args = parser.parse_args()
+
+ # set logging based on verbosity level
+ if args.verbosity == 0:
+ logging.getLogger("geb").setLevel(logging.CRITICAL)
+ elif args.verbosity == 1:
+ logging.getLogger("geb").setLevel(logging.WARNING)
+ elif args.verbosity == 2:
+ logging.getLogger("geb").setLevel(logging.INFO)
+ elif args.verbosity == 3:
+ logging.getLogger("geb").setLevel(logging.DEBUG)
+
+ if args.model is None:
+ raise ValueError("Please specify a model using the -m or --model argument")
+
+    # Make sure that `devices` is a comma-separated list of integers.
+    try:
+        devices = [int(device) for device in args.devices.split(",")]
+    except ValueError:
+        raise ValueError("Devices must be a comma-separated list of integers.")
+
+ layers = args.layers
+ if layers:
+ if layers not in ["mid", "last"]:
+ # Layers should be list of integers.
+ try:
+ layers = [int(layer) for layer in layers.split(",")]
+ except ValueError:
+ raise ValueError("Layers must be a list of integers.")
+
+ model_name = args.model.split("/")[-1]
+    # `GEB.run` appends the model name and creates the folder, so only the
+    # base output folder is resolved here.
+    output_folder = args.output_folder or "results"
+    logger.info(f"Results will be saved to {os.path.join(output_folder, model_name)}")
+
+ # Load the model by name.
+ model = geb.get_model(
+ model_name=args.model,
+ layers=layers,
+ devices=devices,
+ max_seq_length=args.max_seq_len,
+ batch_size=args.batch_size,
+ pool_type=args.pool_type,
+ )
+
+ all_tasks_for_modality = geb.get_tasks_by_modality(model.modality)
+
+ if args.tasks:
+ task_list = geb.get_tasks_by_name(args.tasks)
+        if not all(task.metadata.modality == model.modality for task in task_list):
+            raise ValueError(
+                f"Tasks must match the model modality ({model.modality.value}). "
+                f"Choose from {[task.metadata.id for task in all_tasks_for_modality]}"
+            )
+ else:
+ task_list = all_tasks_for_modality
+ evaluation = geb.GEB(tasks=task_list)
+    _ = evaluation.run(model, output_folder=output_folder)
+
+
+if __name__ == "__main__":
+ main()