Modifications to resumption shared memory allowing load_state_dict multiple times. (#593)

* shared mem modification on load state dict

* shared mem modification because it doesn't exist...

* write to shared mem

* cleaning up shm

* shm logs

* create new shm

* get shm

* shm cleanup and deletion...

* buffer decoding

* correct deletion?

* possibly correct creation

* logging

* prev shm

* trying new approach

* trying newer approach

* actually getting the size of shm

* actually making dat shm correctly

* logging

* shm creation....hmm

* shm creation, no second create

* shm only on local leader

* shm fixed size

* write null byte

* writing whole string

* new shm tests

* logging

* linting

* changed error message

* using mmap PAGESIZE now

* added clarifying comments

* corrected test

* made test even larger

* state dict corrupted message
snarayan21 authored Feb 7, 2024
1 parent bc72659 commit 2ef0729
Showing 2 changed files with 131 additions and 2 deletions.
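The core pattern this commit adopts: allocate one fixed-size shared memory block and rewrite its contents on every `load_state_dict` call, rather than sizing the block to each payload. A minimal standalone sketch of that pattern (illustrative names and values, not the StreamingDataset implementation):

import json
import mmap
from multiprocessing.shared_memory import SharedMemory

def write_state(shm: SharedMemory, obj: dict) -> None:
    # Serialize obj into the fixed-size block, null-terminated so readers can
    # find the end of the payload.
    encoded = (json.dumps(obj, sort_keys=True) + '\0').encode('utf-8')
    if len(encoded) > shm.size:
        raise ValueError(f'State dict needs {len(encoded)} bytes but only '
                         f'{shm.size} are allocated.')
    shm.buf[:len(encoded)] = encoded

# Allocate once at PAGESIZE; repeated writes reuse the same block, so no resizing.
shm = SharedMemory(name='example_resume_state', create=True, size=mmap.PAGESIZE)
write_state(shm, {'epoch': 0, 'sample_in_epoch': 0})
write_state(shm, {'epoch': 1, 'sample_in_epoch': 512})  # second write just overwrites
shm.close()
shm.unlink()

Because each write is null-terminated, a shorter payload written over a longer one still reads back correctly: readers stop at the first null byte and never see the stale bytes behind it.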
27 changes: 25 additions & 2 deletions streaming/base/dataset.py
@@ -5,6 +5,7 @@

import json
import logging
import mmap
import os
import sys
import warnings
@@ -804,12 +805,34 @@ def load_state_dict(self, obj: Dict[str, Any]) -> None:
        Args:
            obj (Dict[str, Any]): The state.
        """
        # Use a fixed-size shared memory block (mmap.PAGESIZE bytes) so that
        # `load_state_dict` can be called multiple times without needing to resize the
        # shared memory block. Resizing a shared memory block is not possible, and
        # closing the block and replacing it with a new one proved unreliable.

        name = _get_path(self._shm_prefix_int, RESUME)
-       data = json.dumps(obj, sort_keys=True).encode('utf-8')
+       data = json.dumps(obj, sort_keys=True)

        len_needed = len(data) + 1  # +1 for the trailing null byte written below
        # Note: mmap.PAGESIZE is at least 4096 bytes across systems. For reference, see:
        # https://en.wikipedia.org/wiki/Page_(computer_memory)#Multiple_page_sizes
        if len_needed > mmap.PAGESIZE:
            raise ValueError(
                f'The StreamingDataset state dict for resumption is currently allocated '
                f'{mmap.PAGESIZE} bytes, which is insufficient to store the state dict '
                f'being loaded, which needs {len_needed} bytes. Please increase the bytes '
                f'allocated to the state dict by changing the SharedMemory size parameter '
                f'set in this function. The state dict may also be corrupted. The state '
                f'dict is: {data}.')
        # Some platforms allocate shared memory in multiples of the platform's memory page
        # size, so the exact size of the returned shared memory block may be larger than
        # what was requested.
-       self._resume_shm = SharedMemory(name=name, size=len(data))
+       self._resume_shm = SharedMemory(name=name, size=mmap.PAGESIZE)
        # Write a null byte at the end of the data so that the state dict is read back
        # correctly in `_resume`.
        data += '\0'
        data = data.encode('utf-8')
        self._resume_shm.buf[:len(data)] = data

    def resample_streams(
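On the read side, the null terminator is what makes the fixed-size block safe to decode: a reader must stop at the first null byte instead of decoding all mmap.PAGESIZE bytes. The `_resume` method that consumes this block is not shown in this diff; a hedged sketch of what such a read-back can look like:

import json
from multiprocessing.shared_memory import SharedMemory

def read_state(name: str) -> dict:
    # Attach to an existing block, copy the bytes out, and keep only the
    # payload before the first null byte.
    shm = SharedMemory(name=name)
    try:
        payload = bytes(shm.buf).split(b'\0', 1)[0]
        return json.loads(payload.decode('utf-8'))
    finally:
        shm.close()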
106 changes: 106 additions & 0 deletions tests/test_shared.py
@@ -5,8 +5,10 @@

import pytest

from streaming.base import StreamingDataset
from streaming.base.shared import get_shm_prefix
from streaming.base.world import World
from tests.common.utils import convert_to_mds


@pytest.mark.usefixtures('local_remote_dir')
@@ -41,3 +43,107 @@ def test_same_local_remote_none(local_remote_dir: Tuple[str, str]):
    local, _ = local_remote_dir
    _, _ = get_shm_prefix(streams_local=[local], streams_remote=[None], world=World())
    _, _ = get_shm_prefix(streams_local=[local], streams_remote=[None], world=World())


@pytest.mark.parametrize('from_beginning', [True, False])
@pytest.mark.usefixtures('local_remote_dir')
def test_load_get_state_dict_once(local_remote_dir: Tuple[str, str], from_beginning: bool):
    local, remote = local_remote_dir
    convert_to_mds(out_root=remote,
                   dataset_name='sequencedataset',
                   num_samples=117,
                   size_limit=1 << 8)
    dataset = StreamingDataset(local=local, remote=remote)

    # Get the current dataset state dict.
    old_state_dict = dataset.state_dict(0, from_beginning)
    assert old_state_dict is not None

    state_keys = list(old_state_dict.keys())

    # Change the state dict and load it back into the dataset.
    new_state_dict = old_state_dict.copy()
    for key in state_keys:
        new_state_dict[key] += 1
    dataset.load_state_dict(new_state_dict)

    new_loaded_state_dict = dataset.state_dict(0, from_beginning)
    assert new_loaded_state_dict is not None
    if from_beginning:
        for key in state_keys:
            if key == 'sample_in_epoch':
                # If `from_beginning` is True, we expect `sample_in_epoch` to be 0.
                assert new_loaded_state_dict[key] == 0
            else:
                # All other fields in the retrieved and loaded state dicts should match.
                assert new_loaded_state_dict[key] == new_state_dict[key]
    else:
        # If `from_beginning` is False, the retrieved and loaded state dicts should
        # match completely.
        assert new_loaded_state_dict == new_state_dict

    for key in state_keys:
        if key == 'sample_in_epoch' and from_beginning:
            # If `from_beginning` is True, `sample_in_epoch` should remain 0, matching
            # the old state dict.
            assert new_loaded_state_dict[key] == old_state_dict[key]
        else:
            assert new_loaded_state_dict[key] == old_state_dict[key] + 1


@pytest.mark.parametrize('iterations', [10])
@pytest.mark.usefixtures('local_remote_dir')
def test_load_get_state_dict_multiple(local_remote_dir: Tuple[str, str], iterations: int):
    local, remote = local_remote_dir
    convert_to_mds(out_root=remote,
                   dataset_name='sequencedataset',
                   num_samples=117,
                   size_limit=1 << 8)
    dataset = StreamingDataset(local=local, remote=remote)

    # Get the current dataset state dict.
    old_state_dict = dataset.state_dict(0, False)
    assert old_state_dict is not None

    state_keys = list(old_state_dict.keys())

    for _ in range(iterations):
        # Change the state dict and load it back into the dataset.
        new_state_dict = old_state_dict.copy()
        for key in state_keys:
            # If the epoch from the loaded state dict is negative (e.g. -1), make the
            # new epoch positive; otherwise the dataset would treat the resumption
            # state as stale and ignore it.
            if key == 'epoch' and new_state_dict[key] < 0:
                new_state_dict[key] *= -5
            else:
                new_state_dict[key] *= 5

        dataset.load_state_dict(new_state_dict)
        new_loaded_state_dict = dataset.state_dict(0, False)

        assert new_loaded_state_dict is not None
        assert new_loaded_state_dict == new_state_dict
        for key in state_keys:
            # Check that the epoch has been correctly updated, in case it was negative.
            if key == 'epoch' and old_state_dict[key] < 0:
                assert new_loaded_state_dict[key] == old_state_dict[key] * -5
            else:
                assert new_loaded_state_dict[key] == old_state_dict[key] * 5

        old_state_dict = new_loaded_state_dict


@pytest.mark.usefixtures('local_remote_dir')
def test_state_dict_too_large(local_remote_dir: Tuple[str, str]):
    local, remote = local_remote_dir
    convert_to_mds(out_root=remote,
                   dataset_name='sequencedataset',
                   num_samples=117,
                   size_limit=1 << 8)
    dataset = StreamingDataset(local=local, remote=remote)

    # Make a state dict that is too large to fit in the allocated shared memory.
    import mmap
    key = 'a' * mmap.PAGESIZE
    big_state_dict = {key: 1}

    with pytest.raises(ValueError, match='The StreamingDataset state dict'):
        dataset.load_state_dict(big_state_dict)
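
The behavior these tests pin down, from user code's point of view (a sketch; the local and remote paths are hypothetical):

from streaming.base import StreamingDataset

dataset = StreamingDataset(local='/tmp/cache', remote='s3://bucket/dataset')  # hypothetical paths

state = dataset.state_dict(0, False)
dataset.load_state_dict(state)
state['epoch'] += 1
dataset.load_state_dict(state)  # previously could fail; now overwrites the fixed-size block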
