Skip to content

Commit

Permalink
merged dist group shutdown fix
Browse files (browse the repository at this point in the history)
  • Loading branch information
jbieniusiewi committed Oct 16, 2024
2 parents 20525df + bef59a1 commit a4574b8
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 4 deletions.
10 changes: 6 additions & 4 deletions .github/workflows/unit_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ jobs:
test_type: ['fault_tolerance', 'straggler', 'ptl_resiliency']
container:
image: ${{ matrix.container }}
env:
MKL_SERVICE_FORCE_INTEL: 1 # Fix for "MKL_THREADING_LAYER=INTEL is incompatible with libgomp.so.1 library."
steps:
- name: Checkout code
uses: actions/checkout@v4
Expand All @@ -71,20 +73,20 @@ jobs:
path: ./dist/
- name: Set up environment
run: |
pip install pytest lightning pytest-rerunfailures
pip install pytest lightning
PY_VER_NODOT=$(python -c"import sysconfig; print(sysconfig.get_config_var('py_version_nodot'))")
pip install ./dist/nvidia_resiliency_ext-*-cp${PY_VER_NODOT}-*.whl
- name: Run unit tests
shell: bash
run: |
if [[ "${{ matrix.test_type }}" == "straggler" ]]; then
pytest --reruns 3 --maxfail=1 -s -vvv -m "not gpu" ./tests/straggler/unit/
pytest -s -vvv -m "not gpu" ./tests/straggler/unit/
exit 0
elif [[ "${{ matrix.test_type }}" == "ptl_resiliency" ]]; then
pytest --reruns 3 --maxfail=1 -s -vvv -m "not gpu" ./tests/ptl_resiliency/unit/
pytest -s -vvv -m "not gpu" ./tests/ptl_resiliency/unit/
exit 0
elif [[ "${{ matrix.test_type }}" == "fault_tolerance" ]]; then
pytest --reruns 3 --maxfail=1 -s -vvv -m "not gpu" ./tests/fault_tolerance/unit/
pytest -s -vvv -m "not gpu" ./tests/fault_tolerance/unit/
exit 0
else
echo "Unknown test type: ${{ matrix.test_type }}"
Expand Down
5 changes: 5 additions & 0 deletions tests/fault_tolerance/unit/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# limitations under the License.

import contextlib
import gc
import os
import socket
import sys
Expand Down Expand Up @@ -113,6 +114,10 @@ def distributed_worker(

worker_fn(**kwargs)

# `destroy_process_group` was observed to hang in CI;
# run a garbage collection and a barrier first to mitigate the issue.
gc.collect()
torch.distributed.barrier()
torch.distributed.destroy_process_group()

sys.exit(0)
Expand Down
5 changes: 5 additions & 0 deletions tests/straggler/unit/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# limitations under the License.

import contextlib
import gc
import os
import socket
import sys
Expand Down Expand Up @@ -113,6 +114,10 @@ def distributed_worker(

worker_fn(**kwargs)

# `destroy_process_group` was observed to hang in CI;
# run a garbage collection and a barrier first to mitigate the issue.
gc.collect()
torch.distributed.barrier()
torch.distributed.destroy_process_group()

sys.exit(0)
Expand Down

0 comments on commit a4574b8

Please sign in to comment.