add zero optimizer parallel (#593)
* add zero optimizer parallel

* code check

* add some comments

* update

* add some info

* zero helper

* bug fix

* reconstruct

* comm fusion

* ema update

* update

* update

* update

* update

* checkpoint merging

* fix bug

* fix bug

* fix bug

---------

Co-authored-by: zhaoting <zhaoting23@huawei.com>
CaitinZhao and zhaoting authored Sep 11, 2024
1 parent aa1c32d commit 5831703
Showing 7 changed files with 971 additions and 2 deletions.
13 changes: 13 additions & 0 deletions mindone/models/modules/parallel/__init__.py
@@ -0,0 +1,13 @@
from mindspore import nn

from .conv import Conv1d, Conv2d, Conv3d
from .dense import Dense

# {Original MindSpore Cell: New Cell in ZeRO3}
PARALLEL_MODULES = {
    nn.Conv1d: Conv1d,
    nn.Conv2d: Conv2d,
    nn.Conv3d: Conv3d,
    nn.Dense: Dense,
}
__all__ = ["Conv1d", "Conv2d", "Conv3d", "Dense"]
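
A minimal usage sketch (not part of this commit; the commit's own wrapping helper is not shown in this excerpt) of how the PARALLEL_MODULES mapping above might be applied: walk a network and swap every supported cell for its ZeRO-3 wrapper. The function name wrap_cells_for_zero3 is hypothetical, and zero_stage=3 additionally requires an initialized communication group.

# Hypothetical helper illustrating the intended use of PARALLEL_MODULES; not the commit's own code.
from mindspore import nn
from mindspore.communication.management import GlobalComm

from mindone.models.modules.parallel import PARALLEL_MODULES


def wrap_cells_for_zero3(net: nn.Cell, zero_stage: int = 3, op_group: str = GlobalComm.WORLD_COMM_GROUP):
    """Recursively replace supported cells with their ZeRO-3 parallel wrappers."""
    for name, child in net.name_cells().items():
        new_cls = PARALLEL_MODULES.get(type(child), None)
        if new_cls is not None:
            # The wrapper keeps the original cell as `self.net` and shards its weight/bias when zero_stage == 3.
            net.insert_child_to_cell(name, new_cls(child, zero_stage=zero_stage, op_group=op_group))
        else:
            wrap_cells_for_zero3(child, zero_stage, op_group)
    return net
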
73 changes: 73 additions & 0 deletions mindone/models/modules/parallel/conv.py
@@ -0,0 +1,73 @@
from mindspore import nn, ops
from mindspore.communication import get_group_size, get_rank
from mindspore.communication.management import GlobalComm
from mindspore.context import ParallelMode
from mindspore.parallel._utils import _get_parallel_mode

from .param_wrapper import ZeroParamWrapper


class _Conv(nn.Cell):
    def __init__(self, net, zero_stage: int = 0, op_group: str = GlobalComm.WORLD_COMM_GROUP, cell_type=None):
        super(_Conv, self).__init__(auto_prefix=False)
        self.net = net
        self.set_param_wrapper(zero_stage, op_group, cell_type)

    def set_param_wrapper(self, zero_stage, op_group, cell_type=None):
        self.param_wrapper_w = nn.Identity()
        self.param_wrapper_b = nn.Identity()
        if zero_stage == 3:
            # Init parallel settings
            is_parallel = _get_parallel_mode() == ParallelMode.DATA_PARALLEL
            op_group_size = get_group_size(op_group) if is_parallel else 1
            op_rank_id = get_rank(op_group) if is_parallel else 0
            self.param_wrapper_w = ZeroParamWrapper(self.net.weight, zero_stage, op_group, cell_type)
            split_op = ops.Split(0, op_group_size)
            if self.param_wrapper_w.need_rewrite:
                self.net.weight.assign_value(split_op(self.net.weight)[op_rank_id])
            if self.net.has_bias:
                self.param_wrapper_b = ZeroParamWrapper(self.net.bias, zero_stage, op_group, cell_type)
                if self.param_wrapper_b.need_rewrite:
                    self.net.bias.assign_value(split_op(self.net.bias)[op_rank_id])


class Conv1d(_Conv):
    def construct(self, x):
        x = self.net.expand_dims(x, 2)
        output = self.net.conv2d(x, self.param_wrapper_w(self.net.weight))
        if self.net.has_bias:
            output = self.net.bias_add(output, self.param_wrapper_b(self.net.bias))

        output = self.net.squeeze(output)
        return output


class Conv2d(_Conv):
    def construct(self, x):
        output = self.net.conv2d(x, self.param_wrapper_w(self.net.weight))
        if self.net.has_bias:
            output = self.net.bias_add(output, self.param_wrapper_b(self.net.bias))
        return output


class Conv3d(_Conv):
    def construct(self, x):
        weight = self.param_wrapper_w(self.net.weight)
        bias = self.param_wrapper_b(self.net.bias)
        if self.net.group == 1:
            out = self.net.conv3d(x, weight)
            if self.net.has_bias:
                out = self.net.bias_add(out, bias)
        else:
            features = self.net.split_1(x)
            weights = self.net.split_0(weight)
            outputs = ()
            for i in range(self.net.group):
                output = self.net.conv3d(features[i], weights[i])
                outputs = outputs + (output,)
            out = self.net.concat(outputs)
            if self.net.bias is not None:
                new_shape = [1 for _ in range(out.ndim)]
                new_shape[1] = self.net.out_channels
                out = out + bias.reshape(new_shape)
        return out
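
The splitting in set_param_wrapper is a plain dim-0 shard: ops.Split(0, group_size) cuts the weight along its first axis so each rank stores only out_channels / group_size of it. A shapes-only, single-process illustration (the 4-rank group is assumed; no distributed launch is needed here):

# Shapes-only illustration of the dim-0 weight split performed above.
import numpy as np

import mindspore as ms
from mindspore import ops

weight = ms.Tensor(np.zeros((8, 3, 3, 3)), ms.float32)  # e.g. a Conv2d weight with out_channels=8
group_size, rank_id = 4, 1                               # hypothetical 4-rank group, viewed from rank 1

shard = ops.Split(0, group_size)(weight)[rank_id]
print(shard.shape)  # (2, 3, 3, 3): each rank keeps 1/4 of the output channels
# This only works when weight.shape[0] is divisible by group_size, which is
# exactly the condition ZeroParamWrapper.check_rewrite enforces.
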
45 changes: 45 additions & 0 deletions mindone/models/modules/parallel/dense.py
@@ -0,0 +1,45 @@
from mindspore import nn, ops
from mindspore.communication import get_group_size, get_rank
from mindspore.communication.management import GlobalComm
from mindspore.context import ParallelMode
from mindspore.parallel._utils import _get_parallel_mode

from .param_wrapper import ZeroParamWrapper


class Dense(nn.Cell):
    def __init__(self, net, zero_stage: int = 0, op_group: str = GlobalComm.WORLD_COMM_GROUP, cell_type=None):
        super(Dense, self).__init__(auto_prefix=False)
        self.net = net
        self.set_param_wrapper(zero_stage, op_group, cell_type)

    def set_param_wrapper(self, zero_stage, op_group, cell_type=None):
        self.param_wrapper_w = nn.Identity()
        self.param_wrapper_b = nn.Identity()
        if zero_stage == 3:
            # Init parallel settings
            is_parallel = _get_parallel_mode() == ParallelMode.DATA_PARALLEL
            op_group_size = get_group_size(op_group) if is_parallel else 1
            op_rank_id = get_rank(op_group) if is_parallel else 0
            self.param_wrapper_w = ZeroParamWrapper(self.net.weight, zero_stage, op_group, cell_type)
            split_op = ops.Split(0, op_group_size)
            if self.param_wrapper_w.need_rewrite:
                self.net.weight.assign_value(split_op(self.net.weight)[op_rank_id])
            if self.net.has_bias:
                self.param_wrapper_b = ZeroParamWrapper(self.net.bias, zero_stage, op_group, cell_type)
                if self.param_wrapper_b.need_rewrite:
                    self.net.bias.assign_value(split_op(self.net.bias)[op_rank_id])

    def construct(self, x):
        x_shape = x.shape
        if len(x_shape) != 2:
            x = x.reshape(-1, x_shape[-1])
        x = self.net.matmul(x, self.param_wrapper_w(self.net.weight))
        if self.net.has_bias:
            x = self.net.bias_add(x, self.param_wrapper_b(self.net.bias))
        if self.net.activation_flag:
            x = self.net.activation(x)
        if len(x_shape) != 2:
            out_shape = x_shape[:-1] + (x.shape[-1],)
            x = x.reshape(out_shape)
        return x
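
A hedged usage sketch of the wrapper above: under zero_stage=3, constructing Dense(net, ...) shards the underlying nn.Dense weight and bias in place, and the forward pass gathers them back on the fly, so the wrapped cell behaves as a drop-in replacement. This assumes a job launched with msrun/mpirun and data-parallel mode; it is not runnable as a single process.

# Usage sketch (assumes a distributed launch, e.g. msrun, and DATA_PARALLEL mode).
import numpy as np

import mindspore as ms
from mindspore import nn
from mindspore.communication import init
from mindspore.context import ParallelMode

from mindone.models.modules.parallel import Dense

init()  # sets up GlobalComm.WORLD_COMM_GROUP
ms.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL)

net = nn.Dense(128, 256)
net_zero3 = Dense(net, zero_stage=3)  # weight (256, 128) and bias (256,) are split along dim 0 here

x = ms.Tensor(np.random.randn(4, 128), ms.float32)
y = net_zero3(x)  # AllGather rebuilds the full weight just for the matmul
print(y.shape)  # (4, 256), identical to the unwrapped nn.Dense
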
57 changes: 57 additions & 0 deletions mindone/models/modules/parallel/param_wrapper.py
@@ -0,0 +1,57 @@
import mindspore as ms
from mindspore import nn, ops
from mindspore.communication import get_group_size
from mindspore.communication.management import GlobalComm
from mindspore.context import ParallelMode
from mindspore.parallel._utils import _get_parallel_mode


class ZeroParamWrapper(nn.Cell):
    """
    A cell that inserts communication operators before and after parameters when `zero_stage == 3`.
    """

    def __init__(
        self, param: ms.Parameter, zero_stage: int = 0, op_group: str = GlobalComm.WORLD_COMM_GROUP, cell_type=None
    ):
        super().__init__(auto_prefix=False)
        self.op_group = op_group
        self.zero_stage = zero_stage
        self.cell_type = cell_type
        if zero_stage != 3:
            raise ValueError(f"ZeroParamWrapper does not support zero_stage {zero_stage}.")

        # Init parallel settings
        self.is_parallel = _get_parallel_mode() == ParallelMode.DATA_PARALLEL
        self.op_group_size = get_group_size(self.op_group) if self.is_parallel else 1
        self.allgather = ops.Identity()
        self.reduce_scatter = None
        self.dtype = param.dtype
        self.allreduce = ops.AllReduce(group=self.op_group, op=ops.ReduceOp.SUM)

        self.need_rewrite = self.check_rewrite(param)
        if self.need_rewrite:
            self.op_allgather = ops.AllGather(group=self.op_group)
            self.op_reduce_scatter = ops.ReduceScatter(group=self.op_group, op=ops.ReduceOp.SUM)

    def check_rewrite(self, param):
        """Check whether the parameter needs to be split."""
        need_rewrite = self.is_parallel
        B = param.shape[0]
        if not param.parallel_optimizer or B < self.op_group_size or B % self.op_group_size != 0:
            need_rewrite = False
        return need_rewrite

    def construct(self, param):
        if self.need_rewrite:
            if self.cell_type is not None:
                param = param.to(self.cell_type)
            return self.op_allgather(param)
        return param

    def bprop(self, param, out, dout):
        if self.need_rewrite:
            r = self.op_reduce_scatter(dout.to(self.dtype)) / self.op_group_size
            return (r,)
        dout = self.allreduce(dout.to(self.dtype)) / self.op_group_size
        return (dout,)
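
The bprop above mirrors the forward AllGather: ops.ReduceScatter sums the full-parameter gradient across the group and hands each rank only its own slice, and the division by op_group_size turns that sum into a mean. A single-process NumPy illustration of that arithmetic (the two "ranks" and their gradients are made up):

# NumPy-only illustration of what the bprop computes; no communication involved.
import numpy as np

group_size = 2
grad_rank0 = np.arange(32, dtype=np.float32).reshape(8, 4)  # grad of the gathered weight on rank 0
grad_rank1 = np.ones((8, 4), dtype=np.float32)              # grad of the gathered weight on rank 1

summed = grad_rank0 + grad_rank1               # ReduceScatter's sum across the group
shards = np.split(summed, group_size, axis=0)  # ...scattered so each rank keeps one slice
rank0_update = shards[0] / group_size          # matches op_reduce_scatter(dout) / op_group_size
print(rank0_update.shape)  # (4, 4): gradients, like weights, stay shard-sized on each rank
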
20 changes: 18 additions & 2 deletions mindone/trainers/train_step.py
@@ -1,4 +1,5 @@
"""Train step wrapper supporting setting drop overflow update, ema etc"""

from packaging import version

import mindspore as ms
@@ -39,6 +40,7 @@ class TrainOneStepWrapper(nn.TrainOneStepWithLossScaleCell):
        scale_sense (Union[Tensor, Cell]): If this value is a Cell, it will be called
            to update loss scale. If this value is a Tensor, the loss scale can be modified by `set_sense_scale`,
            the shape should be :math:`()` or :math:`(1,)`.
        zero_helper (class): Zero Redundancy Optimizer (ZeRO) build helper. Default: None.
    Returns:
        Tuple of 3 Tensor, the loss, overflow flag and current loss scale value.
@@ -60,6 +62,7 @@ def __init__(
        clip_grad=False,
        clip_norm=1.0,
        verbose=False,
        zero_helper=None,
    ):
        super().__init__(network, optimizer, scale_sense)
        self.ema = ema
@@ -85,6 +88,14 @@ def __init__(
        self.map = ops.Map()
        self.partial = ops.Partial()

        # zero init
        self.zero_helper = zero_helper
        self.zero_stage = zero_helper.zero_stage if zero_helper is not None else 0
        self.run_optimizer = zero_helper.run_optimizer if zero_helper is not None else self.optimizer
        self.grad_reducer = self.grad_reducer if self.zero_stage == 0 else nn.Identity()
        if self.zero_stage != 0:
            self.zero_helper.split_params()

    def construct(self, *inputs):
        # compute loss
        weights = self.weights
@@ -104,6 +115,11 @@ def construct(self, *inputs):

        # 1. compute gradients (of the up-scaled loss w.r.t. the model weights)
        grads = self.grad(self.network, weights)(*inputs, scaling_sens_filled)

        # Gradient communication
        if self.zero_helper is not None:
            grads = self.zero_helper.cal_gradients(grads)

        if self.accum_steps == 1:
            grads = self.grad_reducer(grads)
            scaling_sens = ops.depend(scaling_sens, grads)
@@ -140,7 +156,7 @@ def construct(self, *inputs):
if self.clip_grad:
grads = ops.clip_by_global_norm(grads, self.clip_norm)
# 7. optimize
loss = F.depend(loss, self.optimizer(grads))
loss = F.depend(loss, self.run_optimizer(grads))

# clear gradient accumulation states
loss = F.depend(loss, self.hyper_map(F.partial(_grad_clear_op), self.accumulated_grads))
@@ -157,7 +173,7 @@ def construct(self, *inputs):
if self.clip_grad:
grads = ops.clip_by_global_norm(grads, self.clip_norm)
# 7. optimize
loss = F.depend(loss, self.optimizer(grads))
loss = F.depend(loss, self.run_optimizer(grads))

# 8.ema
if self.ema is not None:
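
From this diff alone, TrainOneStepWrapper touches only four members of zero_helper: zero_stage, run_optimizer, split_params(), and cal_gradients(grads). A hedged duck-typing sketch of that expected interface (the commit's real helper class is not shown in this excerpt, so everything beyond those four names is illustrative):

# Interface sketch inferred from the wrapper's usage of `zero_helper`; placeholder bodies only.
from mindspore import nn


class MinimalZeroHelper:
    """Anything passed as `zero_helper` needs at least these four members."""

    def __init__(self, optimizer: nn.Optimizer, zero_stage: int = 3):
        self.zero_stage = zero_stage    # read at wrapper init to decide sharding and reduction
        self.run_optimizer = optimizer  # called in place of self.optimizer(grads) at step 7

    def split_params(self):
        """Shard parameters/optimizer states across the data-parallel group (called once at init)."""

    def cal_gradients(self, grads):
        """Gradient communication (e.g. reduce-scatter under ZeRO-3) before accumulation, clipping and update."""
        return grads

With such a helper, TrainOneStepWrapper(network, optimizer, scale_sense, zero_helper=MinimalZeroHelper(optimizer)) would route gradient communication through cal_gradients and the parameter update through run_optimizer, matching the two hunks above.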
