Skip to content

Commit

Permalink
Add a Comet.ml trainer callback
Browse files Browse the repository at this point in the history
  • Loading branch information
epwalsh committed Oct 9, 2024
1 parent 9ba0e63 commit 873d4c4
Show file tree
Hide file tree
Showing 3 changed files with 153 additions and 1 deletion.
5 changes: 4 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,11 @@ beaker = [
wandb = [
"wandb",
]
comet = [
"comet_ml",
]
all = [
"ai2-olmo-core[dev,beaker,wandb]",
"ai2-olmo-core[dev,beaker,wandb,comet]",
]

[tool.setuptools]
Expand Down
2 changes: 2 additions & 0 deletions src/olmo_core/train/callbacks/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from .callback import Callback, CallbackConfig
from .checkpointer import CheckpointerCallback, CheckpointRemovalStrategy
from .comet import CometCallback
from .config_saver import ConfigSaverCallback
from .console_logger import ConsoleLoggerCallback
from .evaluator_callback import EvaluatorCallback, LMEvaluatorCallbackConfig
Expand All @@ -18,6 +19,7 @@
"CallbackConfig",
"CheckpointerCallback",
"CheckpointRemovalStrategy",
"CometCallback",
"ConfigSaverCallback",
"ConsoleLoggerCallback",
"EvaluatorCallback",
Expand Down
147 changes: 147 additions & 0 deletions src/olmo_core/train/callbacks/comet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
import logging
import os
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Dict, List, Optional

from olmo_core.distributed.utils import get_rank
from olmo_core.exceptions import OLMoEnvironmentError

from .callback import Callback

if TYPE_CHECKING:
from comet_ml import Experiment

# Module-level logger, named after this module.
log = logging.getLogger(__name__)

# Name of the environment variable from which the Comet.ml API key is read.
COMET_API_KEY_ENV_VAR = "COMET_API_KEY"


@dataclass
class CometCallback(Callback):
    """
    Logs metrics to Comet.ml from rank 0.

    .. important::
        Requires the ``comet_ml`` package and the environment variable ``COMET_API_KEY``.

    .. note::
        This callback logs metrics from every single step to Comet.ml, regardless of the value
        of :data:`Trainer.metrics_collect_interval <olmo_core.train.Trainer.metrics_collect_interval>`.
    """

    enabled: bool = True
    """
    Set to false to disable this callback.
    """

    name: Optional[str] = None
    """
    The name to give the Comet.ml experiment.
    """

    project: Optional[str] = None
    """
    The Comet.ml project to use.
    """

    workspace: Optional[str] = None
    """
    The name of the Comet.ml workspace to use.
    """

    tags: Optional[List[str]] = None
    """
    Tags to assign the experiment.
    """

    cancel_tags: Optional[List[str]] = field(
        default_factory=lambda: ["cancel", "canceled", "cancelled"]
    )
    """
    If you add any of these tags to an experiment on Comet.ml, the run will cancel itself.
    Matching is case-insensitive.
    Defaults to ``["cancel", "canceled", "cancelled"]``.
    """

    cancel_check_interval: Optional[int] = None
    """
    Check for cancel tags every this many steps. Defaults to
    :data:`olmo_core.train.Trainer.cancel_check_interval`.
    """

    failure_tag: str = "failed"
    """
    The tag to assign to failed experiments.
    """

    # Deliberately left un-annotated so that it stays a plain class attribute
    # rather than becoming a dataclass field / ``__init__`` parameter.
    _exp = None
    _finalized: bool = False

    @property
    def exp(self) -> "Experiment":
        """
        The Comet.ml experiment. This is ``None`` until :meth:`pre_train` has created it
        (which only happens on rank 0 when the callback is enabled).
        """
        return self._exp  # type: ignore

    @exp.setter
    def exp(self, exp: "Experiment"):
        self._exp = exp

    @property
    def finalized(self) -> bool:
        """
        Whether the experiment has already been ended by :meth:`finalize`.
        """
        return self._finalized

    def finalize(self):
        """
        End the Comet.ml experiment, at most once.

        This is a no-op when no experiment has been created yet, so it is safe to call
        from error handlers that may run before :meth:`pre_train` has completed.
        """
        if self.finalized or self._exp is None:
            return
        self._exp.end()
        self._finalized = True

    def pre_train(self):
        """
        Create the Comet.ml experiment on rank 0 and apply the configured name and tags.

        :raises OLMoEnvironmentError: If the ``COMET_API_KEY`` environment variable is not set.
        """
        if not self.enabled or get_rank() != 0:
            return

        # Imported lazily so that ``comet_ml`` is only required when the callback is used.
        import comet_ml as comet

        api_key = os.environ.get(COMET_API_KEY_ENV_VAR)
        if api_key is None:
            raise OLMoEnvironmentError(f"missing env var '{COMET_API_KEY_ENV_VAR}'")

        self.exp = comet.Experiment(
            api_key=api_key,
            project_name=self.project,
            workspace=self.workspace,
        )

        if self.name is not None:
            self.exp.set_name(self.name)

        if self.tags:
            self.exp.add_tags(self.tags)

    def log_metrics(self, step: int, metrics: Dict[str, float]):
        """
        Forward the metrics for ``step`` to the Comet.ml experiment (rank 0 only).
        """
        if self.enabled and get_rank() == 0:
            self.exp.log_metrics(metrics, step=step)

    def post_step(self):
        """
        Every ``cancel_check_interval`` steps, schedule a check for cancel tags on the
        trainer's thread pool (rank 0 only).
        """
        if not self.enabled or get_rank() != 0:
            return

        # Fall back to the trainer-wide interval when none was configured on the callback.
        cancel_check_interval = self.cancel_check_interval or self.trainer.cancel_check_interval
        if self.step % cancel_check_interval == 0:
            self.trainer.thread_pool.submit(self.check_if_canceled)

    def post_train(self):
        """
        End the experiment after a successful run (rank 0 only).
        """
        if self.enabled and get_rank() == 0:
            log.info("Finalizing successful Comet.ml experiment...")
            self.finalize()

    def on_error(self, exc: BaseException):
        """
        Tag the experiment with :data:`failure_tag` and end it (rank 0 only).

        Does nothing if the experiment was never created, e.g. when the failure happened
        before or inside :meth:`pre_train`; otherwise this handler would raise an
        ``AttributeError`` of its own.
        """
        del exc
        if self.enabled and get_rank() == 0 and self._exp is not None:
            log.warning("Finalizing failed Comet.ml experiment...")
            self._exp.add_tag(self.failure_tag)
            self.finalize()

    def check_if_canceled(self):
        """
        Cancel the run if the experiment carries any of the :data:`cancel_tags`.

        Tags are compared case-insensitively. Failures to fetch the tags are logged and
        otherwise ignored.
        """
        if not self.enabled or self.finalized or not self.cancel_tags:
            return

        exp = self._exp
        if exp is None:
            # Nothing to poll: the experiment has not been created.
            return

        try:
            tags = exp.get_tags()
        except Exception as exc:
            log.warning(f"Failed to pull tags for Comet.ml experiment:\n{exc}")
            return

        # Normalize the configured tags as well, so that e.g. ``cancel_tags=["STOP"]``
        # can actually match.
        cancel_tags = {tag.lower() for tag in self.cancel_tags}
        if any(tag.lower() in cancel_tags for tag in tags):
            self.trainer.cancel_run("canceled from Comet.ml tag")

0 comments on commit 873d4c4

Please sign in to comment.