Skip to content

Commit

Permalink
Merge branch 'master' into readme/pip-badge
Browse files Browse the repository at this point in the history
  • Loading branch information
Borda authored Aug 18, 2023
2 parents b87af9d + 2ca1571 commit d13cd8a
Show file tree
Hide file tree
Showing 5 changed files with 55 additions and 4 deletions.
2 changes: 1 addition & 1 deletion docs/source-pytorch/upgrade/sections/1_8_regular.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
- `PR12804`_

* - used ``Trainer.reset_train_val_dataloaders()``
- call ``Trainer.reset_train_dataloaders()`` and ``Trainer.reset_val_dataloaders()`` separately
- call ``Trainer.fit_loop.setup_data()``
- `PR12184`_

* - imported ``pl.core.lightning``
Expand Down
25 changes: 25 additions & 0 deletions src/lightning/fabric/strategies/launchers/multiprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import os
from dataclasses import dataclass
from multiprocessing.queues import SimpleQueue
from textwrap import dedent
from typing import Any, Callable, Dict, Literal, Optional, TYPE_CHECKING

import torch
Expand Down Expand Up @@ -91,6 +92,8 @@ def launch(self, function: Callable, *args: Any, **kwargs: Any) -> Any:
"""
if self._start_method in ("fork", "forkserver"):
_check_bad_cuda_fork()
if self._start_method == "spawn":
_check_missing_main_guard()

# The default cluster environment in Lightning chooses a random free port number
# This needs to be done in the main process here before starting processes to ensure each rank will connect
Expand Down Expand Up @@ -216,3 +219,25 @@ def unshare(module: Module) -> Module:
return module

return apply_to_collection(data, function=unshare, dtype=Module)


def _check_missing_main_guard() -> None:
"""Raises an exception if the ``__name__ == "__main__"`` guard is missing."""
if not getattr(mp.current_process(), "_inheriting", False):
return
message = dedent(
"""
Launching multiple processes with the 'spawn' start method requires that your script guards the main
function with an `if __name__ == \"__main__\"` clause. For example:
def main():
# Put your code here
...
if __name__ == "__main__":
main()
Alternatively, you can run with `strategy="ddp"` to avoid this error.
"""
)
raise RuntimeError(message)
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,11 @@
from torch import Tensor

import lightning.pytorch as pl
from lightning.fabric.strategies.launchers.multiprocessing import _check_bad_cuda_fork, _disable_module_memory_sharing
from lightning.fabric.strategies.launchers.multiprocessing import (
_check_bad_cuda_fork,
_check_missing_main_guard,
_disable_module_memory_sharing,
)
from lightning.fabric.utilities import move_data_to_device
from lightning.fabric.utilities.seed import _collect_rng_states, _set_rng_states
from lightning.fabric.utilities.types import _PATH
Expand Down Expand Up @@ -99,6 +103,8 @@ def launch(self, function: Callable, *args: Any, trainer: Optional["pl.Trainer"]
"""
if self._start_method in ("fork", "forkserver"):
_check_bad_cuda_fork()
if self._start_method == "spawn":
_check_missing_main_guard()

# The default cluster environment in Lightning chooses a random free port number
# This needs to be done in the main process here before starting processes to ensure each rank will connect
Expand Down
15 changes: 13 additions & 2 deletions tests/tests_fabric/strategies/launchers/test_multiprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ def test_forking_on_unsupported_platform(_):

@pytest.mark.parametrize("start_method", ["spawn", pytest.param("fork", marks=RunIf(standalone=True))])
@mock.patch("lightning.fabric.strategies.launchers.multiprocessing.mp")
def test_start_method(mp_mock, start_method):
@mock.patch("lightning.fabric.strategies.launchers.multiprocessing._check_missing_main_guard")
def test_start_method(_, mp_mock, start_method):
mp_mock.get_all_start_methods.return_value = [start_method]
launcher = _MultiProcessingLauncher(strategy=Mock(), start_method=start_method)
launcher.launch(function=Mock())
Expand All @@ -51,7 +52,8 @@ def test_start_method(mp_mock, start_method):

@pytest.mark.parametrize("start_method", ["spawn", pytest.param("fork", marks=RunIf(standalone=True))])
@mock.patch("lightning.fabric.strategies.launchers.multiprocessing.mp")
def test_restore_globals(mp_mock, start_method):
@mock.patch("lightning.fabric.strategies.launchers.multiprocessing._check_missing_main_guard")
def test_restore_globals(_, mp_mock, start_method):
"""Test that we pass the global state snapshot to the worker function only if we are starting with 'spawn'."""
mp_mock.get_all_start_methods.return_value = [start_method]
launcher = _MultiProcessingLauncher(strategy=Mock(), start_method=start_method)
Expand Down Expand Up @@ -94,3 +96,12 @@ def test_check_for_bad_cuda_fork(mp_mock, _, start_method):
launcher = _MultiProcessingLauncher(strategy=Mock(), start_method=start_method)
with pytest.raises(RuntimeError, match="Lightning can't create new processes if CUDA is already initialized"):
launcher.launch(function=Mock())


def test_check_for_missing_main_guard():
    """Spawn launches must abort with a helpful error when the main-module guard is absent."""
    launcher = _MultiProcessingLauncher(strategy=Mock(), start_method="spawn")
    # pretend that main is importing itself
    fake_process = Mock(_inheriting=True)
    target = "lightning.fabric.strategies.launchers.multiprocessing.mp.current_process"
    with mock.patch(target, return_value=fake_process):
        with pytest.raises(RuntimeError, match="requires that your script guards the main"):
            launcher.launch(function=Mock())
Original file line number Diff line number Diff line change
Expand Up @@ -203,3 +203,12 @@ def test_memory_sharing_disabled():

trainer = Trainer(accelerator="cpu", devices=2, strategy="ddp_spawn", max_steps=0)
trainer.fit(model)


def test_check_for_missing_main_guard():
    """Spawn launches must abort with a helpful error when the main-module guard is absent."""
    launcher = _MultiProcessingLauncher(strategy=Mock(), start_method="spawn")
    # pretend that main is importing itself
    fake_process = Mock(_inheriting=True)
    target = "lightning.pytorch.strategies.launchers.multiprocessing.mp.current_process"
    with mock.patch(target, return_value=fake_process):
        with pytest.raises(RuntimeError, match="requires that your script guards the main"):
            launcher.launch(function=Mock())

0 comments on commit d13cd8a

Please sign in to comment.