diff --git a/README.md b/README.md index 548c650..1ce5228 100644 --- a/README.md +++ b/README.md @@ -76,7 +76,9 @@ import dgeb model = dgeb.get_model("facebook/esm2_t6_8M_UR50D") tasks = dgeb.get_tasks_by_modality(dgeb.Modality.PROTEIN) evaluation = dgeb.DGEB(tasks=tasks) -evaluation.run(model, output_folder="results") +# Writes results to `output_folder`, and returns a list of TaskResult. +# You can disable writing to json by setting `output_folder=None`. +results = evaluation.run(model, output_folder="results") ``` ### Using a custom model diff --git a/dgeb/dgeb.py b/dgeb/dgeb.py index 6dfd940..1453a0e 100644 --- a/dgeb/dgeb.py +++ b/dgeb/dgeb.py @@ -2,14 +2,14 @@ import os import traceback from itertools import chain -from typing import Any, List +from typing import Any, List, Optional from rich.console import Console from .eval_utils import set_all_seeds from .modality import Modality from .models import BioSeqTransformer -from .tasks.tasks import Task +from .tasks.tasks import Task, TaskResult logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) @@ -36,8 +36,8 @@ def print_selected_tasks(self): def run( self, model, # type encoder - output_folder: str = "results", - ): + output_folder: Optional[str] = "results", + ) -> List[TaskResult]: """Run the evaluation pipeline on the selected tasks. Args: @@ -66,10 +66,10 @@ def run( continue results.append(result) - - save_path = get_output_folder(model.hf_name, task, output_folder) - with open(save_path, "w") as f_out: - f_out.write(result.model_dump_json(indent=2)) + if output_folder: + save_path = get_output_folder(model.hf_name, task, output_folder) + with open(save_path, "w") as f_out: + f_out.write(result.model_dump_json(indent=2)) return results