Enable universal checkpoint for zero stage 1 #4516
Merged
14 commits
2a60f79 Enable uni_ckpt for z1 (tjruwase)
3dc989e Merge branch 'master' into olruwase/ds_2921 (tjruwase)
b13006b Remove logging fix to seperate PR. Relocate conversion script to avoi… (tjruwase)
64d8c0d Formatting fix (tjruwase)
f21a5de Merge branch 'master' into olruwase/ds_2921 (tjruwase)
f5c6b2d PR feedback (tjruwase)
51b3af8 Merge branch 'olruwase/ds_2921' of github.com:microsoft/DeepSpeed int… (tjruwase)
d737cbc Handle replicated params (tjruwase)
3b9a384 Merge branch 'master' into olruwase/ds_2921 (tjruwase)
d1cefd6 Detect bf16_optimizer (tjruwase)
507fee8 Merge branch 'master' into olruwase/ds_2921 (tjruwase)
f25ff5b Docs (tjruwase)
f638d92 Merge branch 'master' of github.com:microsoft/DeepSpeed into olruwase… (tjruwase)
a1c41e0 Fix docs (tjruwase)
@@ -0,0 +1,303 @@
#!/usr/bin/env python

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from functools import partial
import argparse
import glob
import itertools
import multiprocessing
import os
import re
import shutil
import torch
import tqdm
# from pprint import pprint

from deepspeed.checkpoint import DeepSpeedCheckpoint
from deepspeed.checkpoint import (
    OPTIMIZER_STATE_DICT,
    BASE_OPTIMIZER_STATE,
    SINGLE_PARTITION_OF_FP32_GROUPS,
    PARAM_SLICE_MAPPINGS,
    PARAM_SHAPES,
    PARAM,
    CAT_DIM,
    VOCAB_DIVISIBILITY_PADDING_TENSOR,
    ORIGINAL_VOCAB_SIZE,
    UNIVERSAL_CHECKPOINT_INFO,
    VOCABULARY_PARAMETER_PATTERNS,
    PIPELINE_REPLICATED_PARAMETER_PATTERNS,
    TP_REPLICATED_PARAMETER_PATTERNS,
    PARAMETER_TO_AVERAGE_PATTERNS,
    PARAMETER_WITH_ROW_PARALLELISM_PATTERNS,
)


def parse_arguments():
    parser = argparse.ArgumentParser()
    parser.add_argument('--input_folder', type=str, required=True, help='Input DeepSpeed Checkpoint folder')
    parser.add_argument('--output_folder', type=str, required=True, help='Output DeepSpeed checkpoint folder')
    parser.add_argument('--num_extract_workers',
                        default=4,
                        type=int,
                        help='How many parallel processes to extract zero shards')
    parser.add_argument(
        '--num_merge_workers',
        default=2,
        type=int,
        help=
        'How many parallel processes to merge tp slices (more memory intensive, use much fewer than --num_extract_workers))'
    )
    parser.add_argument('--keep_temp_folder',
                        action='store_true',
                        help='Preserve temporary folder of intermediate checkpoint slice files. Useful for debugging.')
    args = parser.parse_args()
    print(f'args = {args}')
    return args


def _create_checkpoint_paths(base_folder, iteration, tp_degree, pp_degree):
    path_list = []
    iter_folder = f'iter_{iteration:07d}'
    for i in range(0, tp_degree):
        path_list.append([])
        for j in range(0, pp_degree):
            rank_folder = f'mp_rank_{i:02d}' if pp_degree == 1 else f'mp_rank_{i:02d}_{j:03d}'
            ckpt_path = os.path.join(rank_folder, 'model_optim_rng.pt')
            path_list[i].append(os.path.join(base_folder, iter_folder, ckpt_path))

    return path_list


def _save_checkpoint(file_path, chkpt_sd):
    dir, _ = os.path.split(file_path)
    os.makedirs(dir, exist_ok=True)
    torch.save(chkpt_sd, file_path)


def extract_zero_shards(dir, ds_checkpoint, indices_3D):
    pp_index, tp_index, dp_index = indices_3D
    sd = ds_checkpoint.get_zero_checkpoint_state(pp_index=pp_index, tp_index=tp_index, dp_index=dp_index)

    # pprint(f"Processing {dp_index=} {pp_index=}, {tp_index=}")

    optim_sd = sd[OPTIMIZER_STATE_DICT]
    param_slice_mappings = optim_sd[PARAM_SLICE_MAPPINGS]
    universal_checkpoint_info = ds_checkpoint.get_checkpoint_info(UNIVERSAL_CHECKPOINT_INFO)
    pipeline_replicated_params = universal_checkpoint_info.get(PIPELINE_REPLICATED_PARAMETER_PATTERNS, [])
    # print(f'{pipeline_replicated_params=}')

    # dict
    state_groups = optim_sd[BASE_OPTIMIZER_STATE]["state"]
    # list
    fp32_groups = optim_sd[SINGLE_PARTITION_OF_FP32_GROUPS]
    param_groups_cnt = len(state_groups)

    for param_group_id in range(param_groups_cnt):

        flat_state = dict(
            exp_avg=state_groups[param_group_id]["exp_avg"],
            exp_avg_sq=state_groups[param_group_id]["exp_avg_sq"],
            fp32=fp32_groups[param_group_id],
        )

        for name, fragment_mapping in param_slice_mappings[param_group_id].items():
            if pp_index > 0 and any(re.match(pattern, name) for pattern in pipeline_replicated_params):
                # Skip tied weights that are replicated in first and last pp stages
                continue

            # pprint(f"dpt{dp_index}{pp_index}{tp_index} {param_group_id} {name} => {fragment_mapping.start}:{fragment_mapping.numel}")
            for state_key in flat_state.keys():
                dump_param_fragment(dir, tp_index, dp_index, state_key, flat_state[state_key], name,
                                    fragment_mapping.start, fragment_mapping.numel)


cnt = 0


def dump_param_fragment(dir, tp_index, dp_index, state_name, state_flat_tensor, param_name, offset, numel):

    global cnt  # temp hack

    param_base_path = os.path.join(dir, param_name, str(tp_index))
    os.makedirs(param_base_path, exist_ok=True)

    cnt += 1
    counter = f"{dp_index:0>2d}"

    path = os.path.join(param_base_path, f"{state_name}.{counter}")

    #print(f"{param_name}: {offset}: {numel} => {path}")

    t = state_flat_tensor.narrow(0, offset, numel).clone()
    _save_checkpoint(path, t)


def _merge_zero_shards(param_base_path, state, tp_degree, slice_shape):
    slices = []
    for tp_index in range(tp_degree):
        prefix_path = os.path.join(param_base_path, str(tp_index), f"{state}")
        paths = sorted(list(glob.glob(f"{prefix_path}.*")))
        shards = [torch.load(p) for p in paths]
        slice = torch.cat(shards, dim=0).reshape(slice_shape)
        slices.append(slice)

    return slices


def _get_vocab_divisibility_padding_tensor(universal_checkpoint_info, padded_vocab_tensor):
    original_vocab_size = universal_checkpoint_info.get(ORIGINAL_VOCAB_SIZE)
    if padded_vocab_tensor.shape[0] > original_vocab_size:
        return padded_vocab_tensor[-1]
    else:
        return torch.zeros(padded_vocab_tensor.shape[1])


def merge_tp_slices(ds_checkpoint, dir, slice_dir, tp_degree, name_and_shape):
    name, shape = name_and_shape
    slice_base_path = os.path.join(slice_dir, name)
    param_base_path = os.path.join(dir, name)

    universal_checkpoint_info = ds_checkpoint.get_checkpoint_info(UNIVERSAL_CHECKPOINT_INFO)
    replicated_parameters = universal_checkpoint_info.get(TP_REPLICATED_PARAMETER_PATTERNS, [])
    parameters_to_average = universal_checkpoint_info.get(PARAMETER_TO_AVERAGE_PATTERNS, [])
    parameters_with_row_parallelism = universal_checkpoint_info.get(PARAMETER_WITH_ROW_PARALLELISM_PATTERNS, [])
    vocabulary_parameters = universal_checkpoint_info.get(VOCABULARY_PARAMETER_PATTERNS, [])
    for state in ("fp32", "exp_avg", "exp_avg_sq"):
        slices = _merge_zero_shards(slice_base_path, state, tp_degree, shape)
        final_path = os.path.join(param_base_path, f"{state}.pt")

        #print(f"Expected shape: {shape}")
        #print(f"Fragment sizes:", list(frag.shape for frag in slices))
        ckpt_dict = {}
        if any(re.match(pattern, name) for pattern in replicated_parameters):
            if len(slices) > 1:
                assert all([slices[0].equal(other_slice) for other_slice in slices[1:]])
            param = slices[0]
            # print(f'replicate {name} using first slice')
        elif any(re.match(pattern, name) for pattern in parameters_to_average):
            param = sum(slices) / len(slices)
            # print(f'merge {name} using average')
        else:
            cat_dim = 1 if any(re.match(pattern, name) for pattern in parameters_with_row_parallelism) else 0
            # print(f"merge {name} with CAT DIM: {cat_dim}")
            param = torch.cat(slices, dim=cat_dim)
            ckpt_dict[CAT_DIM] = cat_dim

        if any(re.match(pattern, name) for pattern in vocabulary_parameters):
            #print(f"Before {param.shape=}")
            # strip padding
            #param = _strip_vocab_padding(ds_checkpoint, param)
            ckpt_dict[VOCAB_DIVISIBILITY_PADDING_TENSOR] = _get_vocab_divisibility_padding_tensor(
                universal_checkpoint_info, param)
            #print(f"After {param.shape=}")

        #print(f"Final shape: {param.shape}")
        ckpt_dict[PARAM] = param
        _save_checkpoint(final_path, ckpt_dict)


def _get_chunks(l, n):
    for i in range(0, len(l), n):
        yield l[i:i + n]


def _do_parallel_work(do_work, work_chunks, num_workers):
    pool = multiprocessing.Pool(num_workers)
    for batch in tqdm.tqdm(work_chunks):
        pool.map(do_work, batch)
    pool.close()
    pool.join()


def _extract_zero_shard_files(args, ds_checkpoint, temp_dir):
    _3d_range_list = list(
        itertools.product(range(ds_checkpoint.pp_degree), range(ds_checkpoint.tp_degree),
                          range(ds_checkpoint.dp_degree)))
    # pprint(f'{_3d_range_list=}')
    work_chunks = list(_get_chunks(_3d_range_list, args.num_extract_workers))
    # pprint(f'{work_chunks=}')

    # extract_zero_shards(temp_dir, ds_checkpoint, _3d_range_list[0])
    do_work = partial(extract_zero_shards, temp_dir, ds_checkpoint)
    _do_parallel_work(do_work, work_chunks, args.num_extract_workers)


def _merge_tp_slice_files(args, ds_checkpoint, slice_shapes, temp_dir):
    work_chunks = list(_get_chunks(list(slice_shapes.items()), args.num_merge_workers))
    #pprint(work_chunks)
    zero_output_folder = os.path.join(args.output_folder, "zero")
    do_work = partial(merge_tp_slices, ds_checkpoint, zero_output_folder, temp_dir, ds_checkpoint.tp_degree)
    _do_parallel_work(do_work, work_chunks, args.num_merge_workers)


def _save_optimizer_state(args, ds_checkpoint):
    sharded_states = [BASE_OPTIMIZER_STATE, PARAM_SLICE_MAPPINGS, SINGLE_PARTITION_OF_FP32_GROUPS]
    sd = ds_checkpoint.get_zero_checkpoint_state(pp_index=0, tp_index=0, dp_index=0)

    optim_sd = sd[OPTIMIZER_STATE_DICT]
    output_sd = {k: v for k, v in optim_sd.items() if k not in sharded_states}
    zero_output_folder = os.path.join(args.output_folder, "zero")
    output_file_path = os.path.join(zero_output_folder, f"optimizer_state.pt")
    _save_checkpoint(output_file_path, output_sd)


def _check_for_required_state(ds_checkpoint):
    universal_checkpoint_info = ds_checkpoint.get_checkpoint_info(UNIVERSAL_CHECKPOINT_INFO)
    assert universal_checkpoint_info is not None, f'Required {UNIVERSAL_CHECKPOINT_INFO} state is missing in checkpoint. Verify that client creates this state.'


def main():
    print(f'Convert DeepSpeed Checkpoint to Universal Checkpoint')

    args = parse_arguments()
    print(f'Converting DeepSpeed checkpoint in {args.input_folder} to Universal checkpoint in {args.output_folder}')

    ds_checkpoint = DeepSpeedCheckpoint(args.input_folder)
    _check_for_required_state(ds_checkpoint)

    iteration = ds_checkpoint.get_iteration()
    #_create_latest_file(args.output_folder, iteration)
    checkpoint_paths = _create_checkpoint_paths(args.output_folder, iteration, ds_checkpoint.tp_degree,
                                                ds_checkpoint.pp_degree)

    slice_shapes = []
    for mp_rank_file in ds_checkpoint.mp_rank_files:
        mp_sd = torch.load(mp_rank_file, map_location=torch.device('cpu'))
        slice_shapes += mp_sd[PARAM_SHAPES]

    # fix back to normal flat dict, merge duplicates for tp>1
    slice_shapes = dict((k, v) for d in slice_shapes for k, v in d.items())
    temp_dir = os.path.join(args.output_folder, 'tmp')

    print('*** 1. Extracting ZeRO fragments')
    _extract_zero_shard_files(args, ds_checkpoint, temp_dir)

    print('*** 2. Merging slices')
    _merge_tp_slice_files(args, ds_checkpoint, slice_shapes, temp_dir)

    print('*** 3. Saving common optimizer states')
    _save_optimizer_state(args, ds_checkpoint)

    if not args.keep_temp_folder:
        shutil.rmtree(temp_dir, ignore_errors=True)

    # Copy mp* files into output folder
    for f in glob.glob(os.path.join(args.input_folder, 'mp*')):
        shutil.copy2(f, args.output_folder)

    # Update latest to output folder
    checkpoint_root_folder, step_folder = os.path.split(args.output_folder)
    latest_file = os.path.join(checkpoint_root_folder, 'latest_universal')
    with open(latest_file, "w") as f:
        f.write(step_folder)

    print('*** Done!')


if __name__ == "__main__":
    main()
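
The merge step in `merge_tp_slices` decides how to combine tensor-parallel slices by matching the parameter name against regex pattern lists read from the checkpoint's universal checkpoint info. The following is a minimal, torch-free sketch of that dispatch order only. The function name `pick_merge_strategy`, the plain-string dictionary keys, and the example patterns are hypothetical and stand in for the constants and client-provided patterns used by the script.

```python
import re


def pick_merge_strategy(name, info):
    """Return how TP slices of `name` would be combined (same check order as the script)."""

    def matches(key):
        # re.match anchors at the start of the name, as in the script.
        return any(re.match(pattern, name) for pattern in info.get(key, []))

    if matches("tp_replicated"):
        return "take first slice"
    if matches("to_average"):
        return "average slices"
    if matches("row_parallel"):
        return "concatenate on dim 1"
    return "concatenate on dim 0"


# Hypothetical patterns; real ones come from the client-provided checkpoint info.
info = {
    "tp_replicated": [r".*layernorm\.weight"],
    "to_average": [r".*position_embeddings\.weight"],
    "row_parallel": [r".*dense_4h_to_h\.weight"],
}

for param in ("encoder.layernorm.weight", "mlp.dense_4h_to_h.weight", "attn.query.weight"):
    print(param, "->", pick_merge_strategy(param, info))
```

Because the checks run in a fixed order, a name that matches both the replicated and the row-parallel lists would be treated as replicated; a missing pattern list simply falls through to concatenation on dimension 0.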
Dear @tjruwase,
I'm looking at a scenario where the maximum dp_index is 127. In string (lexicographic) order, 127 sorts before 13. Line 144 (`paths = sorted(list(glob.glob(f"{prefix_path}.*")))` in `_merge_zero_shards`) sorts the shard file paths as strings, so could the shards be loaded and concatenated in the wrong order in this case?
I'd appreciate your insight on this.
Best regards,
Junfeng
A specific example is as follows:
./temp/model.transformer_encoder.layers.19.self_attn.in_proj_weight => fp32.100, torch.Size([2187604])
./temp/model.transformer_encoder.layers.19.self_attn.in_proj_weight => fp32.99, torch.Size([48144044])
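
The ordering in this example can be reproduced without any checkpoint files. The sketch below assumes shard files are named `<state>.<dp_index>` with the two-digit padding used in `dump_param_fragment`, and compares plain `sorted` with a sort on the integer suffix. `numeric_suffix` is a hypothetical helper for illustration, not part of the PR.

```python
# Shard names as produced by the two-digit zero padding in the script.
paths = [f"fp32.{dp_index:0>2d}" for dp_index in (98, 99, 100, 101)]

# Lexicographic order compares character by character, so "1" < "9"
# places the three-digit suffixes ahead of the two-digit ones.
print(sorted(paths))


def numeric_suffix(path):
    """Hypothetical sort key: the integer after the last dot."""
    return int(path.rsplit(".", 1)[1])


# Sorting by the parsed integer restores dp_index order.
print(sorted(paths, key=numeric_suffix))
```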
@rgtjf, thanks for reporting this. Could you open a new ticket for it and share more details there? I can see that line 130 (`counter = f"{dp_index:0>2d}"`) may not generalize to larger scales. Thanks!
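
To make the scaling limit concrete: two-digit padding keeps string order and numeric order aligned only while dp_index stays below 100. The sketch below shows one possible write-side remedy, padding to the width of the largest index. It is an illustration under that assumption, not necessarily the fix adopted in the follow-up ticket; `shard_suffix` and its `dp_degree` argument are hypothetical.

```python
def shard_suffix(dp_index, dp_degree):
    """Zero-pad dp_index to the width of the largest index (hypothetical helper)."""
    width = len(str(dp_degree - 1))
    return f"{dp_index:0{width}d}"


# Suffixes for 128 data-parallel ranks, with the original and the wider padding.
two_digit = [f"{i:0>2d}" for i in range(128)]
padded = [shard_suffix(i, 128) for i in range(128)]

# True only if string sorting preserves the numeric order of the suffixes.
print(sorted(two_digit) == two_digit)
print(sorted(padded) == padded)
```

With fixed-width suffixes the existing `sorted` call would need no change, at the cost of passing the data-parallel degree to the code that writes the shards.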
@tjruwase I've opened a ticket to track the issue: #5283. If any details are missing or more information is needed, please let me know.