Enable universal checkpoint for zero stage 1 #4516

Merged · 14 commits · Oct 25, 2023
295 changes: 295 additions & 0 deletions deepspeed/checkpoint/ds_to_universal.py
@@ -0,0 +1,295 @@
#!/usr/bin/env python

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from functools import partial
import argparse
import glob
import itertools
import multiprocessing
import os
import re
import shutil
import torch
import tqdm
# from pprint import pprint

from deepspeed.checkpoint import DeepSpeedCheckpoint
from deepspeed.checkpoint import (
OPTIMIZER_STATE_DICT,
BASE_OPTIMIZER_STATE,
SINGLE_PARTITION_OF_FP32_GROUPS,
PARAM_SLICE_MAPPINGS,
PARAM_SHAPES,
PARAM,
CAT_DIM,
VOCAB_DIVISIBILITY_PADDING_TENSOR,
ORIGINAL_VOCAB_SIZE,
UNIVERSAL_CHECKPOINT_INFO,
VOCABULARY_PARAMETER_PATTERNS,
PIPELINE_REPLICATED_PARAMETER_PATTERNS,
PARAMETER_TO_AVERAGE_PATTERNS,
PARAMETER_WITH_ROW_PARALLELISM_PATTERNS,
)


def parse_arguments():
parser = argparse.ArgumentParser()
parser.add_argument('--input_folder', type=str, required=True, help='Input DeepSpeed Checkpoint folder')
parser.add_argument('--output_folder', type=str, required=True, help='Output DeepSpeed checkpoint folder')
parser.add_argument('--num_extract_workers',
default=4,
type=int,
help='How many parallel processes to extract zero shards')
parser.add_argument(
'--num_merge_workers',
default=2,
type=int,
help=
'How many parallel processes to merge tp slices (more memory intensive, use much fewer than --num_extract_workers)'
)
parser.add_argument('--keep_temp_folder',
action='store_true',
help='Preserve temporary folder of intermediate checkpoint slice files. Useful for debugging.')
args = parser.parse_args()
print(f'args = {args}')
return args


def _create_checkpoint_paths(base_folder, iteration, tp_degree, pp_degree):
path_list = []
iter_folder = f'iter_{iteration:07d}'
for i in range(0, tp_degree):
path_list.append([])
for j in range(0, pp_degree):
rank_folder = f'mp_rank_{i:02d}' if pp_degree == 1 else f'mp_rank_{i:02d}_{j:03d}'
ckpt_path = os.path.join(rank_folder, 'model_optim_rng.pt')
path_list[i].append(os.path.join(base_folder, iter_folder, ckpt_path))

return path_list


def _save_checkpoint(file_path, chkpt_sd):
dir, _ = os.path.split(file_path)
os.makedirs(dir, exist_ok=True)
torch.save(chkpt_sd, file_path)


def extract_zero_shards(dir, ds_checkpoint, indices_3D):
pp_index, tp_index, dp_index = indices_3D
sd = ds_checkpoint.get_zero_checkpoint_state(pp_index=pp_index, tp_index=tp_index, dp_index=dp_index)

# pprint(f"Processing {dp_index=} {pp_index=}, {tp_index=}")

optim_sd = sd[OPTIMIZER_STATE_DICT]
param_slice_mappings = optim_sd[PARAM_SLICE_MAPPINGS]
universal_checkpoint_info = ds_checkpoint.get_checkpoint_info(UNIVERSAL_CHECKPOINT_INFO)
pipeline_replicated_params = universal_checkpoint_info.get(PIPELINE_REPLICATED_PARAMETER_PATTERNS, [])
# print(f'{pipeline_replicated_params=}')

# dict
state_groups = optim_sd[BASE_OPTIMIZER_STATE]["state"]
# list
fp32_groups = optim_sd[SINGLE_PARTITION_OF_FP32_GROUPS]
param_groups_cnt = len(state_groups)

for param_group_id in range(param_groups_cnt):

flat_state = dict(
exp_avg=state_groups[param_group_id]["exp_avg"],
exp_avg_sq=state_groups[param_group_id]["exp_avg_sq"],
fp32=fp32_groups[param_group_id],
)

for name, fragment_mapping in param_slice_mappings[param_group_id].items():
if pp_index > 0 and any(re.match(pattern, name) for pattern in pipeline_replicated_params):
# Skip tied weights that are replicated in first and last pp stages
continue

# pprint(f"dpt{dp_index}{pp_index}{tp_index} {param_group_id} {name} => {fragment_mapping.start}:{fragment_mapping.numel}")
for state_key in flat_state.keys():
dump_param_fragment(dir, tp_index, dp_index, state_key, flat_state[state_key], name,
fragment_mapping.start, fragment_mapping.numel)


cnt = 0


def dump_param_fragment(dir, tp_index, dp_index, state_name, state_flat_tensor, param_name, offset, numel):

global cnt # temp hack

param_base_path = os.path.join(dir, param_name, str(tp_index))
os.makedirs(param_base_path, exist_ok=True)

cnt += 1
counter = f"{dp_index:0>2d}"

rgtjf (Junfeng) commented:

Dear @tjruwase,

I'm currently looking at a run where the maximum dp_index is 127. Sorted as strings, "127" comes before "13", which raises a question about the tensor sorting at line 144 of this code: could the shards end up merged in the wrong order because of this lexicographic ordering?

I'd appreciate your insight on this.

Best regards,
Junfeng

rgtjf commented:

A specific example is as follows:

./temp/model.transformer_encoder.layers.19.self_attn.in_proj_weight => fp32.100, torch.Size([2187604])
./temp/model.transformer_encoder.layers.19.self_attn.in_proj_weight => fp32.99, torch.Size([48144044])

tjruwase (PR author) commented:

@rgtjf, thanks for sharing this issue. Do you mind creating a new ticket for it? I can see that line 130 may not generalize to larger scales. It would be great if you could share more details in a new ticket. Thanks!

rgtjf commented:

@tjruwase I've opened a ticket to track the issue, #5283. If any details are missing or more information would help, please let me know.
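
The ordering concern raised in this thread is easy to reproduce outside DeepSpeed. The snippet below is illustrative only and not part of this PR; the filenames follow the f"{dp_index:0>2d}" suffix pattern used by dump_param_fragment above, and the plain string sort mirrors the sorted(glob.glob(...)) call in _merge_zero_shards further down.

import re

# Hypothetical shard suffixes for dp_index values 9, 13, 99, 100, 127.
paths = [f"fp32.{dp_index:0>2d}" for dp_index in (9, 13, 99, 100, 127)]

# Lexicographic sort compares character by character, so "100" and "127" land before "13".
print(sorted(paths))
# ['fp32.09', 'fp32.100', 'fp32.127', 'fp32.13', 'fp32.99']

# One possible fix (not what this PR does): sort on the numeric value of the trailing index.
def shard_index(path):
    return int(re.search(r"\.(\d+)$", path).group(1))

print(sorted(paths, key=shard_index))
# ['fp32.09', 'fp32.13', 'fp32.99', 'fp32.100', 'fp32.127']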


path = os.path.join(param_base_path, f"{state_name}.{counter}")

#print(f"{param_name}: {offset}: {numel} => {path}")

t = state_flat_tensor.narrow(0, offset, numel).clone()
_save_checkpoint(path, t)


def _merge_zero_shards(param_base_path, state, tp_degree, slice_shape):
slices = []
for tp_index in range(tp_degree):
prefix_path = os.path.join(param_base_path, str(tp_index), f"{state}")
paths = sorted(list(glob.glob(f"{prefix_path}.*")))
shards = [torch.load(p) for p in paths]
slice = torch.cat(shards, dim=0).reshape(slice_shape)
slices.append(slice)

return slices


def _get_vocab_divisibility_padding_tensor(ds_checkpoint, padded_vocab_tensor):
checkpoint_info = ds_checkpoint.get_checkpoint_info()
if padded_vocab_tensor.shape[0] > checkpoint_info[ORIGINAL_VOCAB_SIZE]:
return padded_vocab_tensor[-1]
else:
return torch.zeros(padded_vocab_tensor.shape[1])


def merge_tp_slices(ds_checkpoint, dir, slice_dir, tp_degree, name_and_shape):
name, shape = name_and_shape
slice_base_path = os.path.join(slice_dir, name)
param_base_path = os.path.join(dir, name)

universal_checkpoint_info = ds_checkpoint.get_checkpoint_info(UNIVERSAL_CHECKPOINT_INFO)
parameters_to_average = universal_checkpoint_info.get(PARAMETER_TO_AVERAGE_PATTERNS, [])
parameters_with_row_parallelism = universal_checkpoint_info.get(PARAMETER_WITH_ROW_PARALLELISM_PATTERNS, [])
vocabulary_parameters = universal_checkpoint_info.get(VOCABULARY_PARAMETER_PATTERNS, [])
for state in ("fp32", "exp_avg", "exp_avg_sq"):
slices = _merge_zero_shards(slice_base_path, state, tp_degree, shape)
final_path = os.path.join(param_base_path, f"{state}.pt")

#print(f"Expected shape: {shape}")
#print(f"Fragment sizes:", list(frag.shape for frag in slices))
ckpt_dict = {}
if any(re.match(pattern, name) for pattern in parameters_to_average):
param = sum(slices) / len(slices)
# print(f'merge {name} using average')
else:
cat_dim = 1 if any(re.match(pattern, name) for pattern in parameters_with_row_parallelism) else 0
# print(f"merge {name} with CAT DIM: {cat_dim}")
param = torch.cat(slices, dim=cat_dim)
ckpt_dict[CAT_DIM] = cat_dim

if any(re.match(pattern, name) for pattern in vocabulary_parameters):
#print(f"Before {param.shape=}")
# strip padding
#param = _strip_vocab_padding(ds_checkpoint, param)
ckpt_dict[VOCAB_DIVISIBILITY_PADDING_TENSOR] = _get_vocab_divisibility_padding_tensor(ds_checkpoint, param)
#print(f"After {param.shape=}")

#print(f"Final shape: {param.shape}")
ckpt_dict[PARAM] = param
_save_checkpoint(final_path, ckpt_dict)


def _get_chunks(l, n):
for i in range(0, len(l), n):
yield l[i:i + n]


def _do_parallel_work(do_work, work_chunks, num_workers):
pool = multiprocessing.Pool(num_workers)
for batch in tqdm.tqdm(work_chunks):
pool.map(do_work, batch)
pool.close()
pool.join()


def _extract_zero_shard_files(args, ds_checkpoint, temp_dir):
_3d_range_list = list(
itertools.product(range(ds_checkpoint.pp_degree), range(ds_checkpoint.tp_degree),
range(ds_checkpoint.dp_degree)))
# pprint(f'{_3d_range_list=}')
work_chunks = list(_get_chunks(_3d_range_list, args.num_extract_workers))
# pprint(f'{work_chunks=}')

# extract_zero_shards(temp_dir, ds_checkpoint, _3d_range_list[0])
do_work = partial(extract_zero_shards, temp_dir, ds_checkpoint)
_do_parallel_work(do_work, work_chunks, args.num_extract_workers)


def _merge_tp_slice_files(args, ds_checkpoint, slice_shapes, temp_dir):
work_chunks = list(_get_chunks(list(slice_shapes.items()), args.num_merge_workers))
#pprint(work_chunks)
zero_output_folder = os.path.join(args.output_folder, "zero")
do_work = partial(merge_tp_slices, ds_checkpoint, zero_output_folder, temp_dir, ds_checkpoint.tp_degree)
_do_parallel_work(do_work, work_chunks, args.num_merge_workers)


def _save_optimizer_state(args, ds_checkpoint):
sharded_states = [BASE_OPTIMIZER_STATE, PARAM_SLICE_MAPPINGS, SINGLE_PARTITION_OF_FP32_GROUPS]
sd = ds_checkpoint.get_zero_checkpoint_state(pp_index=0, tp_index=0, dp_index=0)

optim_sd = sd[OPTIMIZER_STATE_DICT]
output_sd = {k: v for k, v in optim_sd.items() if k not in sharded_states}
zero_output_folder = os.path.join(args.output_folder, "zero")
output_file_path = os.path.join(zero_output_folder, f"optimizer_state.pt")
_save_checkpoint(output_file_path, output_sd)


def _check_for_required_state(ds_checkpoint):
universal_checkpoint_info = ds_checkpoint.get_checkpoint_info(UNIVERSAL_CHECKPOINT_INFO)
assert universal_checkpoint_info is not None, f'Required {UNIVERSAL_CHECKPOINT_INFO} state is missing in checkpoint. Verify that client creates this state.'


def main():
print(f'Convert DeepSpeed Checkpoint to Universal Checkpoint')

args = parse_arguments()
print(f'Converting DeepSpeed checkpoint in {args.input_folder} to Universal checkpoint in {args.output_folder}')

ds_checkpoint = DeepSpeedCheckpoint(args.input_folder)
_check_for_required_state(ds_checkpoint)

iteration = ds_checkpoint.get_iteration()
#_create_latest_file(args.output_folder, iteration)
checkpoint_paths = _create_checkpoint_paths(args.output_folder, iteration, ds_checkpoint.tp_degree,
ds_checkpoint.pp_degree)

slice_shapes = []
for mp_rank_file in ds_checkpoint.mp_rank_files:
mp_sd = torch.load(mp_rank_file, map_location=torch.device('cpu'))
slice_shapes += mp_sd[PARAM_SHAPES]

# fix back to normal flat dict, merge duplicates for tp>1
slice_shapes = dict((k, v) for d in slice_shapes for k, v in d.items())
temp_dir = os.path.join(args.output_folder, 'tmp')

print('*** 1. Extracting ZeRO fragments')
_extract_zero_shard_files(args, ds_checkpoint, temp_dir)

print('*** 2. Merging slices')
_merge_tp_slice_files(args, ds_checkpoint, slice_shapes, temp_dir)

print('*** 3. Saving common optimizer states')
_save_optimizer_state(args, ds_checkpoint)

if not args.keep_temp_folder:
shutil.rmtree(temp_dir, ignore_errors=True)

# Copy mp* files into output folder
for f in glob.glob(os.path.join(args.input_folder, 'mp*')):
shutil.copy2(f, args.output_folder)

# Update latest to output folder
checkpoint_root_folder, step_folder = os.path.split(args.output_folder)
latest_file = os.path.join(checkpoint_root_folder, 'latest_universal')
with open(latest_file, "w") as f:
f.write(step_folder)

print('*** Done!')


if __name__ == "__main__":
main()
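
As a quick orientation to the output written above: merge_tp_slices stores one file per optimizer state under <output_folder>/zero/<param_name>/ ("fp32.pt", "exp_avg.pt", "exp_avg_sq.pt"), each a small dict holding the merged tensor under PARAM plus optional metadata such as CAT_DIM. A minimal sketch of inspecting one such file (the parameter name and folder are placeholders, not taken from this PR):

import torch
from deepspeed.checkpoint import PARAM, CAT_DIM  # same constants the converter imports above

# Hypothetical path inside a converted checkpoint.
ckpt = torch.load("output_folder/zero/word_embeddings.weight/fp32.pt", map_location="cpu")
print(ckpt[PARAM].shape)                    # full, unsharded fp32 tensor for this parameter
print(ckpt.get(CAT_DIM, "no concat dim"))   # present only for tensor-parallel concatenated params
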
3 changes: 3 additions & 0 deletions deepspeed/runtime/engine.py
@@ -2712,6 +2712,9 @@ def load_checkpoint(self,
if self._optimizer_has_ckpt_event_epilogue():
self.optimizer.checkpoint_event_epilogue()

if self.load_universal_checkpoint():
self.optimizer.update_lp_params()

return load_path, client_states

def _load_checkpoint(self,
25 changes: 14 additions & 11 deletions deepspeed/runtime/pipe/module.py
@@ -76,11 +76,11 @@ def build(self, log=False):

class TiedLayerSpec(LayerSpec):

def __init__(self, key, typename, *module_args, forward_fn=None, tied_weight_attr='weight', **module_kwargs):
def __init__(self, key, typename, *module_args, forward_fn=None, tied_weight_attr=['weight'], **module_kwargs):
super().__init__(typename, *module_args, **module_kwargs)
self.key = key
self.forward_fn = forward_fn
self.tied_weight_attr = tied_weight_attr
self.tied_weight_attr = [tied_weight_attr] if type(tied_weight_attr) == str else tied_weight_attr


class PipelineModule(nn.Module):
@@ -423,23 +423,26 @@ def _partition_layers(self, method='uniform'):
def allreduce_tied_weight_gradients(self):
'''All reduce the gradients of the tied weights between tied stages'''
for key, comm in self.tied_comms.items():
weight = getattr(self.tied_modules[key], comm['weight_attr'])
dist.all_reduce(weight.grad, group=comm['group'])
for attr_name in comm['weight_attr']:
weight = getattr(self.tied_modules[key], attr_name)
dist.all_reduce(weight.grad, group=comm['group'])

def get_tied_weights_and_groups(self):
weight_group_list = []
for key, comm in self.tied_comms.items():
weight = getattr(self.tied_modules[key], comm['weight_attr'])
weight_group_list.append((weight, comm['group']))
for attr_name in comm['weight_attr']:
weight = getattr(self.tied_modules[key], attr_name)
weight_group_list.append((weight, comm['group']))
return weight_group_list

def _synchronize_tied_weights(self):
for key, comm in self.tied_comms.items():
dist.broadcast(
getattr(comm['module'], comm['weight_attr']),
src=min(comm['ranks']),
group=comm['group'],
)
for attr_name in comm['weight_attr']:
dist.broadcast(
getattr(comm['module'], attr_name),
src=min(comm['ranks']),
group=comm['group'],
)

def _index_tied_modules(self):
''' Build communication structures for tied modules. '''
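
For context on the pipe/module.py change above: tied_weight_attr now accepts either a single attribute name or a list of names, and the TiedLayerSpec constructor normalizes a plain string into a one-element list. A minimal sketch of declaring specs with the updated signature, assuming DeepSpeed is installed; the 'embed' key, module types, and sizes are placeholders, not taken from this PR:

import torch.nn as nn
from deepspeed.runtime.pipe.module import LayerSpec, TiedLayerSpec

layers = [
    # String form still works and is normalized to ['weight'] internally.
    TiedLayerSpec('embed', nn.Embedding, 50000, 1024, tied_weight_attr='weight'),
    LayerSpec(nn.Linear, 1024, 1024),
    # List form, enabled by this PR, can name several tied attributes at once.
    TiedLayerSpec('embed', nn.Embedding, 50000, 1024, tied_weight_attr=['weight']),
]

Building an actual PipelineModule from these specs additionally requires an initialized distributed environment, which is omitted here.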