Option to disable the parallelization of the embedding with TP (#191)
* [WIP] add option to disable embedding parallelization

* Issues fixed for BERT

* Remove commented code

* Add support for other models

* Fix RobertaParallelizer

* Update the training args

* Disabling the feature for now

* Disabling the feature for now

* Fix utils
michaelbenayoun authored Aug 25, 2023
1 parent 87b2c33 commit 9b6a061
Showing 11 changed files with 135 additions and 42 deletions.
17 changes: 14 additions & 3 deletions optimum/neuron/accelerate/accelerator.py
@@ -30,6 +30,8 @@
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

from optimum.neuron.utils.patching import ModelPatcher

from ...utils import logging
from ..distributed import Parallelizer, ParallelizersManager
from ..distributed.utils import ZeroRedundancyOptimizerCompatibleWithTensorParallelism
@@ -109,7 +111,7 @@ def __init__(self, *args, tp_plugin: Optional[TensorParallelismPlugin] = None, z
tp_size = 1
else:
tp_size = int(use_neuronx_distributed_tp)
tp_plugin = TensorParallelismPlugin(tensor_parallel_size=tp_size)
tp_plugin = TensorParallelismPlugin(tensor_parallel_size=tp_size, parallelize_embeddings=True)
self._model_cpu_parameters_to_xla = {}

if tp_plugin.should_parallelize:
@@ -305,8 +307,17 @@ def _prepare_model_for_tp(
model.to(torch.bfloat16)
else:
model.to(torch.float32)
parallel_layers.move_model_to_device(model, self.device)
model.tie_weights()

def _tie_or_clone_weights_for_tp(self, output_embeddings, input_embeddings):
"""Tie or clone module weights depending of whether we are using TorchScript or not"""
output_embeddings.weight = input_embeddings.weight
if hasattr(output_embeddings, "out_features") and hasattr(input_embeddings, "num_embeddings"):
output_embeddings.out_features = input_embeddings.num_embeddings

with ModelPatcher(patching_specs=[(model, "_tie_or_clone_weights", _tie_or_clone_weights_for_tp)]):
model.tie_weights()
parallel_layers.move_model_to_device(model, self.device)
model.tie_weights()
self._model_cpu_parameters_to_xla[id(model)] = dict(zip(cpu_ids, model.parameters()))
device_placement = False
return super().prepare_model(model, device_placement=device_placement, evaluation_mode=evaluation_mode)
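The ModelPatcher block above matters because, once the output projection is sharded, plain weight tying is not enough: tie_weights() must also keep the head's out_features aligned with the embedding's num_embeddings. As a simplified sketch of the patching pattern only (not the actual ModelPatcher implementation, whose API may differ):

    from contextlib import contextmanager

    @contextmanager
    def swap_attribute(obj, name, replacement):
        # Temporarily replace obj.<name> and restore the original on exit.
        original = getattr(obj, name)
        setattr(obj, name, replacement)
        try:
            yield obj
        finally:
            setattr(obj, name, original)

    # Hypothetical usage mirroring the diff:
    # with swap_attribute(model, "_tie_or_clone_weights", _tie_or_clone_weights_for_tp):
    #     model.tie_weights()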
2 changes: 2 additions & 0 deletions optimum/neuron/accelerate/utils/dataclasses.py
@@ -141,6 +141,7 @@ def load_optimizer(self, accelerator, optimizer, model, input_dir, optimizer_ind
@dataclass
class TensorParallelismPlugin:
tensor_parallel_size: int = 1
parallelize_embeddings: bool = True

def __post_init__(self):
if self.tensor_parallel_size < 1:
@@ -162,6 +163,7 @@ def parallelize_model(
model,
orig_to_parallel=orig_to_parallel,
device=device,
parallelize_embeddings=self.parallelize_embeddings,
)
if return_orig_to_parallel:
return parallelized_model, orig_to_parallel
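A minimal sketch of how the new field is meant to travel from the plugin into the parallelizer; the call site and keyword names outside the diff are assumed:

    # Assumed wiring: the accelerator builds the plugin, and parallelize_model now
    # forwards parallelize_embeddings down to the model-specific Parallelizer.
    tp_plugin = TensorParallelismPlugin(tensor_parallel_size=2, parallelize_embeddings=False)
    if tp_plugin.should_parallelize:
        parallel_model = tp_plugin.parallelize_model(model, device=device)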
20 changes: 13 additions & 7 deletions optimum/neuron/distributed/base.py
@@ -94,6 +94,7 @@ def _parallelize(
model: "PreTrainedModel",
orig_to_parallel: Optional[Dict[int, "torch.nn.Parameter"]] = None,
device: Optional["torch.device"] = None,
parallelize_embeddings: bool = True,
) -> "PreTrainedModel":
"""
Parallelizes the model by transforming regular layers into their parallel counterparts.
@@ -107,6 +108,9 @@
It might be deprecated soon.
device (`Optional[torch.device]`, defaults to `None`):
The device where the new parallel layers should be put.
parallelize_embeddings (`bool`, defaults to `True`):
Whether or not the embeddings should be parallelized.
This can be disabled in the case when the TP size does not divide the vocabulary size.
Returns:
`PreTrainedModel`: The parallelized model.
@@ -118,6 +122,7 @@ def parallelize(
model: "PreTrainedModel",
orig_to_parallel: Optional[Dict[int, "torch.nn.Parameter"]] = None,
device: Optional["torch.device"] = None,
parallelize_embeddings: bool = True,
) -> "PreTrainedModel":
"""
Parallelizes the model by transforming regular layers into their parallel counterparts using
@@ -134,11 +139,16 @@
It might be deprecated soon.
device (`Optional[torch.device]`, defaults to `None`):
The device where the new parallel layers should be put.
parallelize_embeddings (`bool`, defaults to `True`):
Whether or not the embeddings should be parallelized.
This can be disabled in the case when the TP size does not divide the vocabulary size.
Returns:
`PreTrainedModel`: The parallelized model.
"""
model = cls._parallelize(model, orig_to_parallel=orig_to_parallel, device=device)
model = cls._parallelize(
model, orig_to_parallel=orig_to_parallel, device=device, parallelize_embeddings=parallelize_embeddings
)
weight_map = getattr(model, "_weight_map", {})
with torch.no_grad():
modules_to_initialize = []
@@ -152,12 +162,8 @@
"attached to the model to load the proper weights from file."
)
split = name.rsplit(".", maxsplit=1)
if len(split) == 1:
module = model
attribute_name = split[0]
else:
module = model.get_submodule(split[0])
attribute_name = split[1]
module = model.get_submodule(split[0])
attribute_name = split[1]
try:
weight_info = WeightInformation(weight_map[name], name, device=device)
setattr(module, attribute_name, torch.nn.Parameter(load_tensor_for_weight(weight_info)))
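Putting the new argument together at the Parallelizer level, a hedged usage sketch (the manager lookup method name is assumed):

    from transformers import BertForMaskedLM
    from optimum.neuron.distributed import ParallelizersManager

    model = BertForMaskedLM.from_pretrained("bert-base-uncased")
    parallelizer_class = ParallelizersManager.parallelizer_for_model(model)  # method name assumed
    # Leave the word embeddings (and tied LM head) unsharded, e.g. when the TP size
    # does not divide the vocabulary size.
    model = parallelizer_class.parallelize(model, parallelize_embeddings=False)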
8 changes: 6 additions & 2 deletions optimum/neuron/distributed/decoder_models.py
@@ -51,8 +51,10 @@ def _parallelize(
model: "PreTrainedModel",
orig_to_parallel: Optional[Dict[int, "torch.nn.Parameter"]],
device: Optional["torch.device"] = None,
parallelize_embeddings: bool = True,
) -> "PreTrainedModel":
model = GPTNeoParallelEmbedding.transform(model, model, device=device)
if parallelize_embeddings:
model = GPTNeoParallelEmbedding.transform(model, model, device=device)
for block in model.transformer.h:
block.attn.attention = GPTNeoParallelSelfAttention.transform(
model,
@@ -130,8 +132,10 @@ def _parallelize(
model: "PreTrainedModel",
orig_to_parallel: Optional[Dict[int, "torch.nn.Parameter"]],
device: Optional["torch.device"] = None,
parallelize_embeddings: bool = True,
) -> "PreTrainedModel":
model = LlamaParallelEmbedding.transform(model, model, device=device)
if parallelize_embeddings:
model = LlamaParallelEmbedding.transform(model, model, device=device)
for layer in model.model.layers:
layer.self_attn = LlamaParallelSelfAttention.transform(model, layer.self_attn, device=device)
layer.mlp = LLamaParallelMLP.transform(model, layer.mlp, device=device)
4 changes: 3 additions & 1 deletion optimum/neuron/distributed/encoder_decoder_models.py
@@ -133,8 +133,10 @@ def _parallelize(
model: "PreTrainedModel",
orig_to_parallel: Optional[Dict[int, "torch.nn.Parameter"]],
device: Optional["torch.device"] = None,
parallelize_embeddings: bool = True,
) -> "PreTrainedModel":
model = T5ParallelEmbedding.transform(model, model, device=device)
if parallelize_embeddings:
model = T5ParallelEmbedding.transform(model, model, device=device)
if model.encoder.embed_tokens is not None:
model.encoder.embed_tokens = model.shared
if model.decoder.embed_tokens is not None:
25 changes: 24 additions & 1 deletion optimum/neuron/distributed/encoder_models.py
@@ -17,14 +17,23 @@
from typing import TYPE_CHECKING, Dict, Optional

from .base import Parallelizer
from .parallel_layers import ParallelSelfAttention, ParallelSelfOutput
from .parallel_layers import ParallelEmbedding, ParallelSelfAttention, ParallelSelfOutput


if TYPE_CHECKING:
import torch
from transformers import PreTrainedModel


class BertParallelEmbedding(ParallelEmbedding):
EMBEDDING_NAME = "bert.embeddings.word_embeddings"
LM_HEAD_NAME = {
"BertForPreTraining": "cls.predictions.decoder",
"BertLMHeadModel": "cls.predictions.decoder",
"BertForMaskedLM": "cls.predictions.decoder",
}


class BertParallelSelfAttention(ParallelSelfAttention):
ALL_HEAD_SIZE_NAME = "all_head_size"

@@ -40,7 +49,10 @@ def _parallelize(
model: "PreTrainedModel",
orig_to_parallel: Optional[Dict[int, "torch.nn.Parameter"]],
device: Optional["torch.device"] = None,
parallelize_embeddings: bool = True,
) -> "PreTrainedModel":
if parallelize_embeddings:
model = BertParallelEmbedding.transform(model, model, device=device)
for layer in model.bert.encoder.layer:
layer.attention.self = BertParallelSelfAttention.transform(
model,
Expand All @@ -57,6 +69,14 @@ def _parallelize(
return model


class RobertaParallelEmbedding(ParallelEmbedding):
EMBEDDING_NAME = "roberta.embeddings.word_embeddings"
LM_HEAD_NAME = {
"RobertaForCausalLM": "lm_head.decoder",
"RobertaForMaskedLM": "lm_head.decoder",
}


class RobertaParallelSelfAttention(BertParallelSelfAttention):
pass

@@ -72,7 +92,10 @@ def _parallelize(
model: "PreTrainedModel",
orig_to_parallel: Optional[Dict[int, "torch.nn.Parameter"]],
device: Optional["torch.device"] = None,
parallelize_embeddings: bool = True,
) -> "PreTrainedModel":
if parallelize_embeddings:
model = RobertaParallelEmbedding.transform(model, model, device=device)
for layer in model.roberta.encoder.layer:
layer.attention.self = RobertaParallelSelfAttention.transform(
model,
50 changes: 36 additions & 14 deletions optimum/neuron/distributed/parallel_layers.py
@@ -16,9 +16,9 @@

from abc import ABC, abstractclassmethod
from pathlib import Path
from typing import TYPE_CHECKING, Dict, Optional, Tuple
from typing import TYPE_CHECKING, Dict, Optional, Tuple, Union

from ...utils import NormalizedConfigManager
from ...utils import NormalizedConfigManager, logging
from ..utils import is_neuronx_distributed_available
from .utils import WeightInformation, embedding_to_parallel_embedding, linear_to_parallel_linear

@@ -31,6 +31,9 @@
from transformers import PretrainedConfig, PreTrainedModel


logger = logging.get_logger()


class ParallelLayer(ABC):
@classmethod
def _get_module_and_attribute_name(
@@ -106,12 +109,13 @@ class ParallelEmbedding(ParallelLayer):
Attributes:
EMBEDDING_NAME (`str`, defaults to `"embedding"`):
The qualified name of the embedding layer.
LM_HEAD_NAME (`Optional[str]`, defaults to `None`):
The qualified name of the LM head tied to the embedding layer (if any).
LM_HEAD_NAME (`Optional[Union[str, Dict[str, str]]]`, defaults to `None`):
The qualified name of the LM head tied to the embedding layer (if any). It can be also a dictionary mapping
a class name to LM head qualified name.
"""

EMBEDDING_NAME: str = "embedding"
LM_HEAD_NAME: Optional[str] = None
LM_HEAD_NAME: Optional[Union[str, Dict[str, str]]] = None

@classmethod
def transform(
@@ -122,10 +126,16 @@ def transform(
device: Optional["torch.device"] = None,
) -> "torch.nn.Module":
if cls.LM_HEAD_NAME is not None:
parent_lm_head_module, parent_lm_head_attribute_name = cls._get_module_and_attribute_name(
layer, cls.LM_HEAD_NAME
)
model_has_lm_head = hasattr(parent_lm_head_module, parent_lm_head_attribute_name)
if isinstance(cls.LM_HEAD_NAME, dict):
lm_head_name = cls.LM_HEAD_NAME.get(model.__class__.__name__, None)
else:
lm_head_name = cls.LM_HEAD_NAME
model_has_lm_head = False
if lm_head_name is not None:
parent_lm_head_module, parent_lm_head_attribute_name = cls._get_module_and_attribute_name(
layer, lm_head_name
)
model_has_lm_head = hasattr(parent_lm_head_module, parent_lm_head_attribute_name)
else:
model_has_lm_head = False

@@ -147,11 +157,11 @@
)
if model_has_lm_head:
if layer_qualified_name:
lm_head_weight_name = f"{layer_qualified_name}.{cls.LM_HEAD_NAME}.weight"
lm_head_bias_weight_name = f"{layer_qualified_name}.{cls.LM_HEAD_NAME}.bias"
lm_head_weight_name = f"{layer_qualified_name}.{lm_head_name}.weight"
lm_head_bias_weight_name = f"{layer_qualified_name}.{lm_head_name}.bias"
else:
lm_head_weight_name = f"{cls.LM_HEAD_NAME}.weight"
lm_head_bias_weight_name = f"{cls.LM_HEAD_NAME}.bias"
lm_head_weight_name = f"{lm_head_name}.weight"
lm_head_bias_weight_name = f"{lm_head_name}.bias"
if lm_head_weight_name in weight_map:
lm_head_weight_info = WeightInformation(
weight_map[lm_head_weight_name], lm_head_weight_name, device=device
@@ -161,9 +171,21 @@
weight_map[lm_head_bias_weight_name], lm_head_bias_weight_name, device=device
)

embedding_layer = layer.get_submodule(cls.EMBEDDING_NAME)
tp_size = parallel_state.get_tensor_model_parallel_size()
if embedding_layer.num_embeddings % tp_size != 0:
import torch_xla.core.xla_model as xm

if xm.get_ordinal() == 0:
logger.warning(
f"Embedding parallelization for TP was skipped because the tensor parallel size ({tp_size}) does not "
f"divide the number of embeddings ({embedding_layer.num_embeddings})"
)
return layer

parallel_layers = embedding_to_parallel_embedding(
layer.get_submodule(cls.EMBEDDING_NAME),
lm_head_layer=layer.get_submodule(cls.LM_HEAD_NAME) if model_has_lm_head else None,
lm_head_layer=layer.get_submodule(lm_head_name) if model_has_lm_head else None,
embedding_weight_info=embedding_weight_info,
lm_head_weight_info=lm_head_weight_info,
lm_head_bias_weight_info=lm_head_bias_weight_info,
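To make the new divisibility guard concrete, the arithmetic it checks can be sketched with illustrative numbers (not taken from the diff):

    # e.g. a RoBERTa-sized vocabulary (50265 entries) with 8-way tensor parallelism
    vocab_size, tp_size = 50265, 8
    print(vocab_size % tp_size)  # 1 -> rows cannot be sharded evenly, so transform()
                                 # logs a warning on rank 0 and returns the layer
                                 # unchanged instead of parallelizing the embedding.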
12 changes: 6 additions & 6 deletions optimum/neuron/distributed/utils.py
@@ -182,7 +182,7 @@ def embedding_to_parallel_embedding(
None,
),
)
parallel_embedding_layer.weight.data = weight_data
parallel_embedding_layer.weight.copy_(weight_data)
else:
parallel_embedding_layer.weight.copy_(
embedding_layer.weight[tp_rank * row_size : (tp_rank + 1) * row_size, :]
@@ -280,6 +280,7 @@ def linear_to_parallel_linear(
kwargs["device"] = device

parallel_linear_layer = parallel_linear_class(linear_layer.in_features, linear_layer.out_features, **kwargs)

tp_rank = get_tensor_model_parallel_rank()
row_size, col_size = parallel_linear_layer.weight.shape

@@ -295,7 +296,7 @@
(tp_rank * col_size, (tp_rank + 1) * col_size),
),
)
parallel_linear_layer.weight.data = weight_data
parallel_linear_layer.weight.copy_(weight_data)
elif linear_layer.weight.device != torch.device("meta"):
parallel_linear_layer.weight.copy_(
linear_layer.weight[:, tp_rank * col_size : (tp_rank + 1) * col_size]
@@ -306,7 +307,7 @@
if linear_layer.bias is not None:
if linear_layer_bias_weight_info is not None:
bias_weight_data = load_tensor_for_weight(linear_layer_bias_weight_info)
parallel_linear_layer.bias.data = bias_weight_data
parallel_linear_layer.bias.copy_(bias_weight_data)
else:
parallel_linear_layer.bias.copy_(linear_layer.bias)

@@ -323,7 +324,7 @@
None,
),
)
parallel_linear_layer.weight.data = weight_data
parallel_linear_layer.weight.copy_(weight_data)

elif linear_layer.weight.device != torch.device("meta"):
parallel_linear_layer.weight.copy_(
@@ -347,8 +348,7 @@
linear_layer_bias_weight_info,
tensor_slices=tensor_slices,
)
parallel_linear_layer.bias.data = bias_weight_data

parallel_linear_layer.bias.copy_(bias_weight_data)
else:
if gather_output:
parallel_linear_layer.bias.copy_(linear_layer.bias)
9 changes: 8 additions & 1 deletion optimum/neuron/training_args.py
@@ -55,11 +55,18 @@ class NeuronTrainingArgumentsMixin:
tensor_parallel_size: int = field(
default=1, metadata={"help": "The number of replicas the model will be sharded on."}
)
disable_embedding_parallelization: bool = field(
default=True,
metadata={"help": "Whether or not the embedding parallelization when doing TP should be disabled."},
)

def __post_init__(self):
# Patches accelerate.utils.imports.is_tpu_available to match `is_torch_xla_available`
patch_accelerate_is_tpu_available()

if not self.disable_embedding_parallelization:
raise NotImplementedError("Disabling the parallelization of the embeddings is not fully supported yet.")

if self.fsdp != "":
# Disabling FSDP until next release because it is still very experimental and not validated.
raise RuntimeError("FSDP is not supported yet.")
@@ -86,7 +93,7 @@ def __post_init__(self):
"The minimal required Transformers version to perform XLA FSDP is "
f"{TRANSFORMERS_MIN_VERSION_FOR_XLA_FSDP} but {transformers.__version__} is installed."
)
self.tp_plugin = TensorParallelismPlugin(self.tensor_parallel_size)
self.tp_plugin = TensorParallelismPlugin(self.tensor_parallel_size, not self.disable_embedding_parallelization)
super().__post_init__()

# Needed only to specialize the warning message for FSDP.
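Finally, a short sketch of the user-facing knob, assuming the mixin is consumed by a NeuronTrainingArguments class exported from optimum.neuron:

    from optimum.neuron import NeuronTrainingArguments  # import path assumed

    args = NeuronTrainingArguments(output_dir="out", tensor_parallel_size=2)
    # disable_embedding_parallelization defaults to True, so the plugin is built with
    # parallelize_embeddings=False.
    assert args.tp_plugin.parallelize_embeddings is False
    # Setting disable_embedding_parallelization=False raises NotImplementedError for now,
    # consistent with the "Disabling the feature for now" entries in the commit message.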