Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update NeptuneLogger #3165

Merged
merged 25 commits into from
Apr 15, 2024
Merged
Show file tree
Hide file tree
Changes from 20 commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
00dbf60
Update NeptuneLogger
AleksanderWWW Mar 25, 2024
05e89ca
better check symlinks
AleksanderWWW Mar 25, 2024
e4eb799
Update composer/loggers/neptune_logger.py
AleksanderWWW Mar 26, 2024
85dfcd9
use progress bar if possible
AleksanderWWW Apr 2, 2024
8e32faf
Merge branch 'dev' into neptune/update-logger
AleksanderWWW Apr 2, 2024
0e76288
simplify imports
AleksanderWWW Apr 2, 2024
45ea3d0
Merge pull request #2 from AleksanderWWW/neptune/update-logger
AleksanderWWW Apr 2, 2024
c8e8d77
update oom callback
AleksanderWWW Apr 2, 2024
96554fa
fix
AleksanderWWW Apr 2, 2024
e05900d
fix typing
AleksanderWWW Apr 2, 2024
0344fbb
Merge branch 'dev' into dev
AleksanderWWW Apr 3, 2024
ab1d630
Merge branch 'dev' into dev
AleksanderWWW Apr 4, 2024
c07ad62
Merge branch 'dev' into dev
AleksanderWWW Apr 10, 2024
3606052
Merge branch 'dev' into dev
mvpatel2000 Apr 10, 2024
f6e1b6b
code review
AleksanderWWW Apr 13, 2024
2b7a3f6
Merge branch 'dev' into dev
AleksanderWWW Apr 13, 2024
ce8df4f
maybe a fix
AleksanderWWW Apr 15, 2024
174ddb6
Apply suggestions from code review
AleksanderWWW Apr 15, 2024
6e4a9b7
Merge branch 'dev' into dev
AleksanderWWW Apr 15, 2024
eb7c8b6
format
AleksanderWWW Apr 15, 2024
4a476e2
Update composer/callbacks/oom_observer.py
mvpatel2000 Apr 15, 2024
d4ad2a6
Update tests/loggers/test_neptune_logger.py
mvpatel2000 Apr 15, 2024
a34e096
Update tests/loggers/test_neptune_logger.py
mvpatel2000 Apr 15, 2024
962139e
Update tests/loggers/test_neptune_logger.py
mvpatel2000 Apr 15, 2024
ef8ee89
Merge branch 'dev' into dev
mvpatel2000 Apr 15, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 39 additions & 21 deletions composer/callbacks/oom_observer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,20 @@
# SPDX-License-Identifier: Apache-2.0

"""Generate a memory snapshot during an OutOfMemory exception."""
from __future__ import annotations

import dataclasses
import logging
import os
import pickle
import warnings
from typing import Optional
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional

import torch.cuda
from packaging import version

from composer import State
from composer.core import Callback, State
from composer.loggers import Logger
from composer.utils import ensure_folder_is_empty, format_name_with_dist, format_name_with_dist_and_time, parse_uri
Expand All @@ -22,6 +25,29 @@
__all__ = ['OOMObserver']


@dataclass(frozen=True)
class SnapshotFileNameConfig:
"""Configuration for the file names of the memory snapshot visualizations."""
snapshot_file: str
trace_plot_file: str
segment_plot_file: str
segment_flamegraph_file: str
memory_flamegraph_file: str

@classmethod
def from_file_name(cls, filename: str) -> 'SnapshotFileNameConfig':
return cls(
snapshot_file=filename + '_snapshot.pickle',
trace_plot_file=filename + '_trace_plot.html',
segment_plot_file=filename + '_segment_plot.html',
segment_flamegraph_file=filename + '_segment_flamegraph.svg',
memory_flamegraph_file=filename + '_memory_flamegraph.svg',
)

def list_filenames(self) -> List[str]:
return [getattr(self, field.name) for field in dataclasses.fields(self)]


class OOMObserver(Callback):
"""Generate visualizations of the state of allocated memory during an OutOfMemory exception.

Expand Down Expand Up @@ -94,6 +120,8 @@ def __init__(
self._enabled = False
warnings.warn('OOMObserver is supported after PyTorch 2.1.0. Disabling OOMObserver callback.')

self.filename_config: Optional[SnapshotFileNameConfig] = None

def init(self, state: State, logger: Logger) -> None:
if not self._enabled:
return
Expand All @@ -117,17 +145,12 @@ def oom_observer(device: int, alloc: int, device_alloc: int, device_free: int):

assert self.filename
assert self.folder_name, 'folder_name must be set in init'
filename = os.path.join(
self.folder_name,
filename = Path(self.folder_name) / Path(
format_name_with_dist_and_time(self.filename, run_name=state.run_name, timestamp=state.timestamp),
)

try:
snapshot_file = filename + '_snapshot.pickle'
trace_plot_file = filename + '_trace_plot.html'
segment_plot_file = filename + '_segment_plot.html'
segment_flamegraph_file = filename + '_segment_flamegraph.svg'
memory_flamegraph_file = filename + '_memory_flamegraph.svg'
self.filename_config = SnapshotFileNameConfig.from_file_name(str(filename))
log.info(f'Dumping OOMObserver visualizations')

snapshot = torch.cuda.memory._snapshot()
Expand All @@ -136,31 +159,26 @@ def oom_observer(device: int, alloc: int, device_alloc: int, device_free: int):
log.info(f'No allocation is recorded in memory snapshot)')
return

with open(snapshot_file, 'wb') as fd:
with open(self.filename_config.snapshot_file, 'wb') as fd:
pickle.dump(snapshot, fd)

with open(trace_plot_file, 'w+') as fd:
with open(self.filename_config.trace_plot_file, 'w+') as fd:
fd.write(torch.cuda._memory_viz.trace_plot(snapshot)) # type: ignore

with open(segment_plot_file, 'w+') as fd:
with open(self.filename_config.segment_plot_file, 'w+') as fd:
fd.write(torch.cuda._memory_viz.segment_plot(snapshot)) # type: ignore

with open(segment_flamegraph_file, 'w+') as fd:
with open(self.filename_config.segment_flamegraph_file, 'w+') as fd:
fd.write(torch.cuda._memory_viz.segments(snapshot)) # type: ignore

with open(memory_flamegraph_file, 'w+') as fd:
with open(self.filename_config.memory_flamegraph_file, 'w+') as fd:
fd.write(torch.cuda._memory_viz.memory(snapshot)) # type: ignore

log.info(f'Saved memory visualizations to local files with prefix = {filename} during OOM')

if self.remote_path_in_bucket is not None:
for f in [
snapshot_file,
trace_plot_file,
segment_plot_file,
segment_flamegraph_file,
memory_flamegraph_file,
]:

mvpatel2000 marked this conversation as resolved.
Show resolved Hide resolved
for f in self.filename_config.list_filenames():
base_file_name = os.path.basename(f)
remote_file_name = os.path.join(self.remote_path_in_bucket, base_file_name)
remote_file_name = remote_file_name.lstrip('/') # remove leading slashes
Expand Down
132 changes: 87 additions & 45 deletions composer/loggers/neptune_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,24 +9,30 @@
import pathlib
import warnings
from functools import partial
from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence, Set, Union
from importlib.metadata import version
from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Sequence, Set, Union

import numpy as np
import torch
from packaging.version import Version

from composer._version import __version__
from composer.loggers import LoggerDestination
from composer.utils import MissingConditionalImportError, dist
from composer.utils import MissingConditionalImportError, VersionedDeprecationWarning, dist

if TYPE_CHECKING:
from composer import Logger
from composer.core import State

NEPTUNE_MODE_TYPE = Literal['async', 'sync', 'offline', 'read-only', 'debug']
NEPTUNE_VERSION_WITH_PROGRESS_BAR = Version('1.9.0')


class NeptuneLogger(LoggerDestination):
"""Log to `neptune.ai <https://neptune.ai/>`_.

For more, see the [Neptune-Composer integration guide](https://docs.neptune.ai/integrations/composer/).
For instructions, see the
`integration guide <https://docs.neptune.ai/integrations/mosaicml_composer/>`_.

Args:
project (str, optional): The name of your Neptune project,
Expand All @@ -36,16 +42,15 @@ class NeptuneLogger(LoggerDestination):
You can leave out this argument if you save your token to the
``NEPTUNE_API_TOKEN`` environment variable (recommended).
You can find your API token in the user menu of the Neptune web app.
rank_zero_only (bool, optional): Whether to log only on the rank-zero process.
(default: ``True``).
upload_artifacts (bool, optional): Whether the logger should upload artifacts to Neptune.
rank_zero_only (bool): Whether to log only on the rank-zero process (default: ``True``).
upload_artifacts (bool, optional): Deprecated. See ``upload_checkpoints``.
upload_checkpoints (bool): Whether the logger should upload checkpoints to Neptune
(default: ``False``).
base_namespace (str, optional): The name of the base namespace to log the metadata to.
(default: "training").
base_namespace (str, optional): The name of the base namespace where the metadata
is logged (default: "training").
neptune_kwargs (Dict[str, Any], optional): Any additional keyword arguments to the
``neptune.init_run()`` function. For options, see the
`Run API reference <https://docs.neptune.ai/api/neptune/#init_run>`_ in the
Neptune docs.
`Run API reference <https://docs.neptune.ai/api/neptune/#init_run>`_.
"""
metric_namespace = 'metrics'
hyperparam_namespace = 'hyperparameters'
Expand All @@ -58,8 +63,10 @@ def __init__(
project: Optional[str] = None,
api_token: Optional[str] = None,
rank_zero_only: bool = True,
upload_artifacts: bool = False,
upload_artifacts: Optional[bool] = None,
upload_checkpoints: bool = False,
base_namespace: str = 'training',
mode: Optional[NEPTUNE_MODE_TYPE] = None,
**neptune_kwargs,
) -> None:
try:
Expand All @@ -74,7 +81,8 @@ def __init__(
verify_type('project', project, (str, type(None)))
verify_type('api_token', api_token, (str, type(None)))
verify_type('rank_zero_only', rank_zero_only, bool)
verify_type('upload_artifacts', upload_artifacts, bool)
verify_type('upload_artifacts', upload_artifacts, (bool, type(None)))
verify_type('upload_checkpoints', upload_checkpoints, bool)
verify_type('base_namespace', base_namespace, str)

if not base_namespace:
Expand All @@ -83,15 +91,19 @@ def __init__(
self._project = project
self._api_token = api_token
self._rank_zero_only = rank_zero_only
self._upload_artifacts = upload_artifacts

if upload_artifacts is not None:
_warn_about_deprecated_upload_artifacts()
self._upload_checkpoints = upload_artifacts
else:
self._upload_checkpoints = upload_checkpoints

self._base_namespace = base_namespace
self._neptune_kwargs = neptune_kwargs

mode = self._neptune_kwargs.pop('mode', 'async')

self._enabled = (not rank_zero_only) or dist.get_global_rank() == 0

self._mode = mode if self._enabled else 'debug'
self._mode: Optional[NEPTUNE_MODE_TYPE] = mode if self._enabled else 'debug'

self._neptune_run = None
self._base_handler = None
Expand All @@ -104,17 +116,8 @@ def __init__(
def neptune_run(self):
"""Gets the Neptune run object from a NeptuneLogger instance.

You can log additional metadata to the run by accessing a path inside the run and assigning metadata to it
with "=" or [Neptune logging methods](https://docs.neptune.ai/logging/methods/).

Example:
from composer import Trainer
from composer.loggers import NeptuneLogger
neptune_logger = NeptuneLogger()
trainer = Trainer(loggers=neptune_logger, ...)
trainer.fit()
neptune_logger.neptune_run["some_metric"] = 1
trainer.close()
To log additional metadata to the run, access a path inside the run and assign metadata
with ``=`` or other `Neptune logging methods <https://docs.neptune.ai/logging/methods/>`_.
"""
from neptune import Run

Expand All @@ -131,19 +134,10 @@ def neptune_run(self):
def base_handler(self):
"""Gets a handler for the base logging namespace.

Use the handler to log extra metadata to the run and organize it under the base namespace (default: "training").
You can operate on it like a run object: Access a path inside the handler and assign metadata to it with "=" or
other [Neptune logging methods](https://docs.neptune.ai/logging/methods/).

Example:
from composer import Trainer
from composer.loggers import NeptuneLogger
neptune_logger = NeptuneLogger()
trainer = Trainer(loggers=neptune_logger, ...)
trainer.fit()
neptune_logger.base_handler["some_metric"] = 1
trainer.close()
Result: The value `1` is organized under "training/some_metric" inside the run.
Use the handler to log extra metadata to the run and organize it under the base namespace
(default: "training"). You can operate on it like a run object: Access a path inside the
handler and assign metadata to it with ``=`` or other
`Neptune logging methods <https://docs.neptune.ai/logging/methods/>`_.
"""
return self.neptune_run[self._base_namespace]

Expand Down Expand Up @@ -213,7 +207,7 @@ def log_traces(self, traces: Dict[str, Any]):

def can_upload_files(self) -> bool:
"""Whether the logger supports uploading files."""
return self._enabled and self._upload_artifacts
return self._enabled and self._upload_checkpoints

def upload_file(
self,
Expand All @@ -226,6 +220,9 @@ def upload_file(
if not self.can_upload_files():
return

if file_path.is_symlink() or file_path.suffix.lower() == '.symlink':
return # skip symlinks

neptune_path = f'{self._base_namespace}/{remote_file_name}'
if self.neptune_run.exists(neptune_path) and not overwrite:

Expand All @@ -236,7 +233,11 @@ def upload_file(
return

del state # unused
self.base_handler[remote_file_name].upload(str(file_path))

from neptune.types import File

with open(str(file_path), 'rb') as fp:
self.base_handler[remote_file_name] = File.from_stream(fp, extension=file_path.suffix)

def download_file(
self,
Expand All @@ -245,7 +246,6 @@ def download_file(
overwrite: bool = False,
progress_bar: bool = True,
):
del progress_bar # not supported

if not self._enabled:
return
Expand All @@ -266,7 +266,11 @@ def download_file(
if not self.neptune_run.exists(file_path):
raise FileNotFoundError(f'File {file_path} not found')

self.base_handler[remote_file_name].download(destination=destination)
if _is_progress_bar_enabled():
self.base_handler[remote_file_name].download(destination=destination, progress_bar=progress_bar)
else:
del progress_bar
self.base_handler[remote_file_name].download(destination=destination)

def log_images(
self,
Expand Down Expand Up @@ -312,4 +316,42 @@ def _validate_image(img: Union[np.ndarray, torch.Tensor], channels_last: bool) -
if not channels_last:
img_numpy = np.moveaxis(img_numpy, 0, -1)

return img_numpy
return _validate_image_value_range(img_numpy)


def _validate_image_value_range(img: np.ndarray) -> np.ndarray:
array_min = img.min()
array_max = img.max()

if (array_min >= 0 and 1 < array_max <= 255) or (array_min >= 0 and array_max <= 1):
return img

from neptune.common.warnings import NeptuneWarning, warn_once

warn_once(
'Image value range is not in the expected range of [0.0, 1.0] or [0, 255]. '
'This might be due to the presence of `transforms.Normalize` in the data pipeline. '
'Logged images may not display correctly in Neptune.',
exception=NeptuneWarning,
)

return _scale_image_to_0_255(img, array_min, array_max)


def _scale_image_to_0_255(img: np.ndarray, array_min: Union[int, float], array_max: Union[int, float]) -> np.ndarray:
scaled_image = 255 * (img - array_min) / (array_max - array_min)
return scaled_image.astype(np.uint8)


def _warn_about_deprecated_upload_artifacts() -> None:
warnings.warn(
VersionedDeprecationWarning(
'The \'upload_artifacts\' parameter is deprecated and will be removed in the next version. '
'Use the \'upload_checkpoints\' parameter instead.',
remove_version='0.23',
),
)


def _is_progress_bar_enabled() -> bool:
return Version(version('neptune')) >= NEPTUNE_VERSION_WITH_PROGRESS_BAR
3 changes: 3 additions & 0 deletions docs/source/doctest_fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,9 @@
# Disable wandb
os.environ['WANDB_MODE'] = 'disabled'

# Disable neptune
os.environ['NEPTUNE_MODE'] = 'debug'

# Change the cwd to be the tempfile, so we don't pollute the documentation source folder
tmpdir = tempfile.mkdtemp()
cwd = os.path.abspath('.')
Expand Down
Loading
Loading