Skip to content

Commit

Permalink
Optional output folder
Browse files Browse the repository at this point in the history
  • Loading branch information
scopello committed Sep 3, 2024
1 parent a1125ca commit 694f9ae
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 9 deletions.
4 changes: 3 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,9 @@ import dgeb
model = dgeb.get_model("facebook/esm2_t6_8M_UR50D")
tasks = dgeb.get_tasks_by_modality(dgeb.Modality.PROTEIN)
evaluation = dgeb.DGEB(tasks=tasks)
evaluation.run(model, output_folder="results")
# Writes results to `output_folder`, and returns a list of TaskResult.
# You can disable writing to json by setting `output_folder=None`.
results = evaluation.run(model, output_folder="results")
```

### Using a custom model
Expand Down
16 changes: 8 additions & 8 deletions dgeb/dgeb.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@
import os
import traceback
from itertools import chain
from typing import Any, List
from typing import Any, List, Optional

from rich.console import Console

from .eval_utils import set_all_seeds
from .modality import Modality
from .models import BioSeqTransformer
from .tasks.tasks import Task
from .tasks.tasks import Task, TaskResult

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
Expand All @@ -36,8 +36,8 @@ def print_selected_tasks(self):
def run(
self,
model, # type encoder
output_folder: str = "results",
):
output_folder: Optional[str] = "results",
) -> List[TaskResult]:
"""Run the evaluation pipeline on the selected tasks.
Args:
Expand Down Expand Up @@ -66,10 +66,10 @@ def run(
continue

results.append(result)

save_path = get_output_folder(model.hf_name, task, output_folder)
with open(save_path, "w") as f_out:
f_out.write(result.model_dump_json(indent=2))
if output_folder:
save_path = get_output_folder(model.hf_name, task, output_folder)
with open(save_path, "w") as f_out:
f_out.write(result.model_dump_json(indent=2))
return results


Expand Down

0 comments on commit 694f9ae

Please sign in to comment.