From 6d529d587c82635291dbdeb836af08f32482a16b Mon Sep 17 00:00:00 2001 From: Mike Cheung Date: Tue, 15 Oct 2024 13:05:06 +0800 Subject: [PATCH 01/14] add Llama3 model --- examples/llama3/llama/__init__.py | 0 examples/llama3/llama/models/__init__.py | 1 + examples/llama3/llama/models/activation.py | 28 ++ .../llama3/llama/models/llama/__init__.py | 1 + examples/llama3/llama/models/llama/layer.py | 244 ++++++++++++++++++ examples/llama3/llama/models/llama/network.py | 155 +++++++++++ examples/llama3/test.py | 29 +++ examples/llama3/tools/convert.py | 66 +++++ examples/llama3/tools/fileio/__init__.py | 1 + examples/llama3/tools/fileio/safetensors.py | 21 ++ 10 files changed, 546 insertions(+) create mode 100644 examples/llama3/llama/__init__.py create mode 100644 examples/llama3/llama/models/__init__.py create mode 100644 examples/llama3/llama/models/activation.py create mode 100644 examples/llama3/llama/models/llama/__init__.py create mode 100644 examples/llama3/llama/models/llama/layer.py create mode 100644 examples/llama3/llama/models/llama/network.py create mode 100644 examples/llama3/test.py create mode 100644 examples/llama3/tools/convert.py create mode 100644 examples/llama3/tools/fileio/__init__.py create mode 100644 examples/llama3/tools/fileio/safetensors.py diff --git a/examples/llama3/llama/__init__.py b/examples/llama3/llama/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/examples/llama3/llama/models/__init__.py b/examples/llama3/llama/models/__init__.py new file mode 100644 index 0000000000..a1829d24c5 --- /dev/null +++ b/examples/llama3/llama/models/__init__.py @@ -0,0 +1 @@ +from .llama import LlamaModel diff --git a/examples/llama3/llama/models/activation.py b/examples/llama3/llama/models/activation.py new file mode 100644 index 0000000000..22a4a66112 --- /dev/null +++ b/examples/llama3/llama/models/activation.py @@ -0,0 +1,28 @@ +import logging +from collections import OrderedDict + +import mindspore.nn as nn +import mindspore.ops as ops +from mindspore import Tensor + +logger = logging.getLogger(__name__) + + +class QuickGELU(nn.Cell): + def construct(self, x: Tensor): + return x * ops.sigmoid(1.702 * x) + + +class ClassInstantier(OrderedDict): + def __getitem__(self, key): + content = super().__getitem__(key) + cls, kwargs = content if isinstance(content, tuple) else (content, {}) + return cls(**kwargs) + + +ACT2CLS = { + "quick_gelu": QuickGELU, + "gelu": nn.GELU, + "silu": nn.SiLU, +} +ACT2FN = ClassInstantier(ACT2CLS) diff --git a/examples/llama3/llama/models/llama/__init__.py b/examples/llama3/llama/models/llama/__init__.py new file mode 100644 index 0000000000..73c46f222f --- /dev/null +++ b/examples/llama3/llama/models/llama/__init__.py @@ -0,0 +1 @@ +from .network import LlamaModel diff --git a/examples/llama3/llama/models/llama/layer.py b/examples/llama3/llama/models/llama/layer.py new file mode 100644 index 0000000000..9c3d7da7f7 --- /dev/null +++ b/examples/llama3/llama/models/llama/layer.py @@ -0,0 +1,244 @@ +import logging +from typing import Optional, Tuple + +import mindspore as ms +import mindspore.nn as nn +import mindspore.ops as ops +from mindspore import Parameter, Tensor +from mindspore.ops.operations.nn_ops import FlashAttentionScore + +from ..activation import ACT2FN + +logger = logging.getLogger(__name__) + + +class LlamaRMSNorm(nn.Cell): + def __init__(self, hidden_size: int, eps: float = 1e-6, dtype: ms.Type = ms.float32) -> None: + super().__init__() + self.weight = Parameter(ops.ones(hidden_size, dtype=dtype)) + 
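# Illustrative sketch (not from the patch): a NumPy reference for the normalization
# that LlamaRMSNorm implements just below -- upcast to float32, scale each token by
# 1/sqrt(mean(x^2) + eps), apply the learned weight, cast back to the input dtype.
# The eps value and demo shapes are assumptions.
import numpy as np

def rmsnorm_reference(x, weight, eps=1e-6):
    x32 = x.astype(np.float32)
    variance = np.mean(x32 * x32, axis=-1, keepdims=True)
    return (weight * (x32 / np.sqrt(variance + eps))).astype(x.dtype)

x = np.random.randn(2, 8, 16).astype(np.float16)
print(rmsnorm_reference(x, np.ones(16, dtype=np.float32)).shape)  # (2, 8, 16)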
self.variance_epsilon = eps + + def construct(self, hidden_states: Tensor) -> Tensor: + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(ms.float32) + variance = ops.pow(hidden_states, 2) + variance = ops.mean(variance, axis=-1, keep_dims=True) + hidden_states = hidden_states * ops.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + +class LlamaRotaryEmbedding(nn.Cell): + def __init__(self, dim: int, max_position_embeddings: int = 2048, base: float = 10000.0) -> None: + super().__init__() + + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.inv_freq = 1.0 / (base ** (ops.arange(0, self.dim, 2, dtype=ms.float32) / self.dim)) + + def construct(self, x: Tensor, position_ids: Tensor) -> Tensor: + # x: [bs, num_attention_heads, seq_len, head_size] + inv_freq_expanded = ops.broadcast_to(self.inv_freq[None, :, None], (position_ids.shape[0], -1, 1)) + position_ids_expanded = position_ids[:, None, :].to(ms.float32) + # Force float32 since bfloat16 loses precision on long contexts + # See https://github.com/huggingface/transformers/pull/29285 + freqs = ops.matmul(inv_freq_expanded.to(ms.float32), position_ids_expanded.to(ms.float32)) + freqs = ops.transpose(freqs, (0, 2, 1)) + emb = ops.concat((freqs, freqs), axis=-1) + cos = ops.cos(emb) + sin = ops.sin(emb) + output = ops.stack([cos.to(x.dtype), sin.to(x.dtype)]) + return output + + +def rotate_half(x: Tensor) -> Tensor: + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return ops.concat((-x2, x1), axis=-1) + + +def apply_rotary_pos_emb( + q: Tensor, k: Tensor, cos: Tensor, sin: Tensor, unsqueeze_dim: int = 1 +) -> Tuple[Tensor, Tensor]: + cos = ops.unsqueeze(cos, unsqueeze_dim) + sin = ops.unsqueeze(sin, unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +class LlamaMLP(nn.Cell): + def __init__( + self, + intermediate_size: int = 14336, + hidden_size: int = 4096, + hidden_act: str = "silu", + dtype: ms.Type = ms.float32, + ) -> None: + super().__init__() + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.gate_proj = nn.Dense(self.hidden_size, self.intermediate_size, has_bias=False, dtype=dtype) + self.up_proj = nn.Dense(self.hidden_size, self.intermediate_size, has_bias=False, dtype=dtype) + self.down_proj = nn.Dense(self.intermediate_size, self.hidden_size, has_bias=False, dtype=dtype) + self.act_fn = ACT2FN[hidden_act] + + def construct(self, hidden_state: Tensor) -> Tensor: + return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state)) + + +def repeat_kv(hidden_states: Tensor, n_rep: int) -> Tensor: + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :] + hidden_states = ops.broadcast_to(hidden_states, (batch, num_key_value_heads, n_rep, slen, head_dim)) + hidden_states = ops.reshape(hidden_states, (batch, num_key_value_heads * n_rep, slen, head_dim)) + return hidden_states + + +class LlamaAttention(nn.Cell): + def __init__( + self, + hidden_size: int = 4096, + num_attention_heads: int = 32, + num_key_value_heads: int = 8, + attention_dropout: float = 0.0, + dtype: ms.Type = ms.float32, + ) -> None: + super().__init__() + + self.attention_dropout = attention_dropout + self.hidden_size = hidden_size + self.num_heads = 
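# Illustrative sketch (not from the patch): the rotary position embedding used by
# LlamaRotaryEmbedding / apply_rotary_pos_emb above, re-derived in NumPy. Dimension i
# is paired with dimension i + dim//2 and each pair is rotated by a position- and
# frequency-dependent angle, so the per-token norm is unchanged. The base, head_dim
# and sequence length below are demo assumptions.
import numpy as np

def rope_reference(x, pos, base=10000.0):
    dim = x.shape[-1]
    inv_freq = 1.0 / (base ** (np.arange(0, dim, 2) / dim))   # (dim/2,)
    freqs = np.outer(pos, inv_freq)                           # (seq, dim/2)
    emb = np.concatenate([freqs, freqs], axis=-1)             # (seq, dim)
    cos, sin = np.cos(emb), np.sin(emb)
    x1, x2 = x[..., : dim // 2], x[..., dim // 2 :]
    rotated = np.concatenate([-x2, x1], axis=-1)              # rotate_half
    return x * cos + rotated * sin

q = np.random.randn(5, 8)                                     # (seq_len, head_dim)
q_rot = rope_reference(q, np.arange(5))
print(np.allclose(np.linalg.norm(q, axis=-1), np.linalg.norm(q_rot, axis=-1)))  # True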
num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." + ) + self.q_proj = nn.Dense(self.hidden_size, self.num_heads * self.head_dim, has_bias=False, dtype=dtype) + self.k_proj = nn.Dense(self.hidden_size, self.num_key_value_heads * self.head_dim, has_bias=False, dtype=dtype) + self.v_proj = nn.Dense(self.hidden_size, self.num_key_value_heads * self.head_dim, has_bias=False, dtype=dtype) + self.o_proj = nn.Dense(self.num_heads * self.head_dim, self.hidden_size, has_bias=False, dtype=dtype) + + def construct( + self, + hidden_states: Tensor, + position_embeddings: Tensor, + attention_mask: Optional[Tensor] = None, + ) -> Tensor: + bsz, q_len, _ = hidden_states.shape + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = ops.reshape(query_states, (bsz, q_len, self.num_heads, self.head_dim)) + query_states = ops.transpose(query_states, (0, 2, 1, 3)) + + key_states = ops.reshape(key_states, (bsz, q_len, self.num_key_value_heads, self.head_dim)) + key_states = ops.transpose(key_states, (0, 2, 1, 3)) + + value_states = ops.reshape(value_states, (bsz, q_len, self.num_key_value_heads, self.head_dim)) + value_states = ops.transpose(value_states, (0, 2, 1, 3)) + + cos, sin = position_embeddings[0], position_embeddings[1] + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + key_states = ops.transpose(key_states, (0, 1, 3, 2)) + attn_weights = ops.matmul(query_states, key_states) / ms.numpy.sqrt(self.head_dim) + + # upcast attention to fp32 + attn_weights = attn_weights.to(ms.float32) + if attention_mask is not None: + attn_weights = ops.masked_fill(attn_weights, ~attention_mask, -ms.numpy.inf) + + attn_weights = ops.softmax(attn_weights, axis=-1).to(query_states.dtype) + attn_weights = ops.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = ops.matmul(attn_weights, value_states) + + attn_output = ops.transpose(attn_output, (0, 2, 1, 3)) + attn_output = ops.reshape(attn_output, (bsz, q_len, -1)) + attn_output = self.o_proj(attn_output) + + return attn_output + + +class LlamaFlashAttention(nn.Cell): + def __init__( + self, + hidden_size: int = 4096, + num_attention_heads: int = 32, + num_key_value_heads: int = 8, + attention_dropout: float = 0.0, + dtype: ms.Type = ms.float32, + ) -> None: + super().__init__() + + self.attention_dropout = attention_dropout + self.hidden_size = hidden_size + self.num_heads = num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." 
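# Illustrative sketch (not from the patch): what repeat_kv (defined earlier in this
# file) does for grouped-query attention. With 32 query heads and 8 key/value heads,
# each KV head is tiled num_heads // num_key_value_heads = 4 times so the attention
# matmul shapes line up. The tensor sizes below are demo assumptions.
import numpy as np

def repeat_kv_reference(x, n_rep):
    b, kv_heads, s, d = x.shape
    x = np.broadcast_to(x[:, :, None, :, :], (b, kv_heads, n_rep, s, d))
    return x.reshape(b, kv_heads * n_rep, s, d)

k = np.random.randn(1, 8, 16, 128)               # (batch, kv_heads, seq, head_dim)
k_rep = repeat_kv_reference(k, 32 // 8)
print(k_rep.shape)                                # (1, 32, 16, 128)
print(np.array_equal(k_rep[0, 0], k_rep[0, 3]))   # True: query heads 0..3 share KV head 0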
+ ) + self.q_proj = nn.Dense(self.hidden_size, self.num_heads * self.head_dim, has_bias=False, dtype=dtype) + self.k_proj = nn.Dense(self.hidden_size, self.num_key_value_heads * self.head_dim, has_bias=False, dtype=dtype) + self.v_proj = nn.Dense(self.hidden_size, self.num_key_value_heads * self.head_dim, has_bias=False, dtype=dtype) + self.o_proj = nn.Dense(self.num_heads * self.head_dim, self.hidden_size, has_bias=False, dtype=dtype) + + self.flash_attention = FlashAttentionScore( + self.num_heads, keep_prob=1 - self.attention_dropout, scale_value=self.head_dim**-0.5, input_layout="BSND" + ) + + def construct( + self, + hidden_states: Tensor, + position_embeddings: Tensor, + attention_mask: Optional[Tensor] = None, + ) -> Tensor: + bsz, q_len, _ = hidden_states.shape + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = ops.reshape(query_states, (bsz, q_len, self.num_heads, self.head_dim)) + query_states = ops.transpose(query_states, (0, 2, 1, 3)) + + key_states = ops.reshape(key_states, (bsz, q_len, self.num_key_value_heads, self.head_dim)) + key_states = ops.transpose(key_states, (0, 2, 1, 3)) + + value_states = ops.reshape(value_states, (bsz, q_len, self.num_key_value_heads, self.head_dim)) + value_states = ops.transpose(value_states, (0, 2, 1, 3)) + + cos, sin = position_embeddings[0], position_embeddings[1] + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + # Reshape to the expected shape and dtype for Flash Attention + query_states = ops.transpose(query_states, (0, 2, 1, 3)) + key_states = ops.transpose(key_states, (0, 2, 1, 3)) + value_states = ops.transpose(value_states, (0, 2, 1, 3)) + + _, _, _, attn_output = self.flash_attention( + query_states, key_states, value_states, None, None, None, attention_mask + ) + attn_output = ops.reshape(attn_output, (bsz, q_len, -1)) + attn_output = self.o_proj(attn_output) + + return attn_output diff --git a/examples/llama3/llama/models/llama/network.py b/examples/llama3/llama/models/llama/network.py new file mode 100644 index 0000000000..03a1d7ceb9 --- /dev/null +++ b/examples/llama3/llama/models/llama/network.py @@ -0,0 +1,155 @@ +from typing import Literal, Optional + +import mindspore as ms +import mindspore.nn as nn +import mindspore.ops as ops +from mindspore import Tensor +from mindspore.common.initializer import Normal, Zero, initializer + +from .layer import LlamaAttention, LlamaFlashAttention, LlamaMLP, LlamaRMSNorm, LlamaRotaryEmbedding + +Llama_ATTENTION_CLASSES = { + "eager": LlamaAttention, + "flash_attention": LlamaFlashAttention, +} + + +class LlamaDecoderLayer(nn.Cell): + def __init__( + self, + hidden_size: int = 4096, + intermediate_size: int = 14336, + num_attention_heads: int = 32, + num_key_value_heads: int = 8, + rms_norm_eps: float = 1e-5, + attention_dropout: float = 0.0, + hidden_act: str = "silu", + attn_implementation: Literal["eager", "flash_attention"] = "eager", + dtype: ms.Type = ms.float32, + ) -> None: + super().__init__() + self.hidden_size = hidden_size + + self.self_attn = Llama_ATTENTION_CLASSES[attn_implementation]( + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + num_key_value_heads=num_key_value_heads, + attention_dropout=attention_dropout, + dtype=dtype, + ) + + self.mlp = LlamaMLP( + 
intermediate_size=intermediate_size, hidden_size=hidden_size, hidden_act=hidden_act, dtype=dtype + ) + self.input_layernorm = LlamaRMSNorm(hidden_size, eps=rms_norm_eps, dtype=dtype) + self.post_attention_layernorm = LlamaRMSNorm(hidden_size, eps=rms_norm_eps, dtype=dtype) + + def construct( + self, + hidden_states: Tensor, + position_embeddings: Tensor, + attention_mask: Optional[Tensor] = None, + ) -> Tensor: + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states = self.self_attn( + hidden_states, + position_embeddings, + attention_mask=attention_mask, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states + + +class LlamaModel(nn.Cell): + def __init__( + self, + hidden_size: int = 4096, + intermediate_size: int = 14336, + max_position_embeddings: int = 32768, + num_attention_heads: int = 32, + num_hidden_layers: int = 32, + num_key_value_heads: int = 8, + rms_norm_eps: float = 1e-5, + rope_theta: float = 1000000.0, + attention_dropout: float = 0.0, + hidden_act: str = "silu", + initializer_range: float = 0.02, + attn_implementation: Literal["eager", "flash_attention"] = "eager", + gradient_checkpointing: bool = False, + dtype: ms.Type = ms.float32, + ) -> None: + super().__init__() + self.layers = nn.CellList( + [ + LlamaDecoderLayer( + hidden_size=hidden_size, + intermediate_size=intermediate_size, + num_attention_heads=num_attention_heads, + num_key_value_heads=num_key_value_heads, + rms_norm_eps=rms_norm_eps, + attention_dropout=attention_dropout, + hidden_act=hidden_act, + attn_implementation=attn_implementation, + dtype=dtype, + ) + for _ in range(num_hidden_layers) + ] + ) + self.norm = LlamaRMSNorm(hidden_size, eps=rms_norm_eps, dtype=dtype) + self.rotary_emb = LlamaRotaryEmbedding( + hidden_size // num_attention_heads, + max_position_embeddings=max_position_embeddings, + base=rope_theta, + ) + + # post-init + self.initializer_range = initializer_range + self.init_weights() + + # recompute + if gradient_checkpointing: + self.layers.recompute() + + def init_weights(self): + def _init_weights(module): + std = self.initializer_range + if isinstance(module, nn.Dense): + module.weight.set_data(initializer(Normal(std, 0.0), module.weight.shape, module.weight.dtype)) + if module.bias is not None: + module.bias.set_data(initializer(Zero(), module.bias.shape, module.bias.dtype)) + + self.apply(_init_weights) + + def construct( + self, + inputs_embeds: Tensor, + attention_mask: Optional[Tensor] = None, + ) -> Tensor: + position_ids = ops.arange(0, inputs_embeds.shape[1], dtype=ms.int64) + position_ids = ops.unsqueeze(position_ids, 0) + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(inputs_embeds, position_ids) + + hidden_states = inputs_embeds + + for decoder_layer in self.layers: + hidden_states = decoder_layer( + hidden_states, + position_embeddings, + attention_mask=attention_mask, + ) + + hidden_states = self.norm(hidden_states) + return hidden_states diff --git a/examples/llama3/test.py b/examples/llama3/test.py new file mode 100644 index 0000000000..325ca6c012 --- /dev/null +++ b/examples/llama3/test.py @@ -0,0 +1,29 @@ +import numpy as np +from llama.models import LlamaModel + +import mindspore as ms +import mindspore.nn as nn + + +def 
count_params(model: nn.Cell) -> int: + total_params = sum([param.size for param in model.get_parameters()]) + return total_params + + +def main(): + ms.set_context(mode=ms.GRAPH_MODE) + network = LlamaModel(attn_implementation="flash_attention", dtype=ms.bfloat16) + ms.load_checkpoint("model.ckpt", network) + + params = count_params(network) + print(f"Parameter number: {params:,}") + + inputs = ms.Tensor(np.ones((4, 256, 4096)), dtype=ms.bfloat16) + outputs = network(inputs) + + print(outputs.shape) + print(outputs) + + +if __name__ == "__main__": + main() diff --git a/examples/llama3/tools/convert.py b/examples/llama3/tools/convert.py new file mode 100644 index 0000000000..889dfc898d --- /dev/null +++ b/examples/llama3/tools/convert.py @@ -0,0 +1,66 @@ +#!/usr/bin/env python +import argparse +import glob +import logging +import os +from typing import Dict + +from fileio import load_safetensors +from tqdm import tqdm + +import mindspore as ms +from mindspore import Tensor + +logger = logging.getLogger(__name__) + + +def load(root_path: str, force_fp32: bool = False) -> Dict[str, Tensor]: + # TODO: this method may cause OOM on computer with low memory + # use a better solution later + pattern = os.path.join(root_path, "*.safetensors") + filelist = sorted(glob.glob(pattern)) + + filenames = [os.path.basename(x) for x in filelist] + logger.info(f"Files need to be converted: `{filenames}`") + + params: Dict[str, Tensor] = dict() + for x in tqdm(filelist, desc="Loading the safetensors"): + params_chunk = load_safetensors(x, force_fp32=force_fp32) + if params.keys().isdisjoint(params_chunk.keys()): + params.update(params_chunk) + else: + same_keys = set(params.keys()).intersection(params_chunk.keys()) + raise RuntimeError(f"Duplicated keys found: `{same_keys}`.") + return params + + +def convert(params: Dict[str, Tensor]) -> Dict[str, Tensor]: + # compatibility between MS and PyTorch naming and formating + return params + + +def save(ckpt: Dict[str, Tensor], output: str) -> None: + output = os.path.abspath(output) + logger.info(f"Saving to {output}...") + ms.save_checkpoint(ckpt, output) + logger.info(f"Saving to {output}...Done!") + + +def main(): + parser = argparse.ArgumentParser(description="Convert LLama checkpoints into Mindspore Format") + parser.add_argument("src", help="Directory storing the safetensors") + parser.add_argument("-o", "--output", default="models/llama.ckpt", help="Name of the output Mindspore checkpoint") + parser.add_argument("--force_fp32", action="store_true", help="Force to save the ckpt in fp32 format.") + + args = parser.parse_args() + + params = load(args.src, force_fp32=args.force_fp32) + params = convert(params) + save(params, args.output) + + +if __name__ == "__main__": + fmt = "%(asctime)s %(levelname)s: %(message)s" + datefmt = "[%Y-%m-%d %H:%M:%S]" + logging.basicConfig(level=logging.INFO, format=fmt, datefmt=datefmt) + main() diff --git a/examples/llama3/tools/fileio/__init__.py b/examples/llama3/tools/fileio/__init__.py new file mode 100644 index 0000000000..e64328c9f6 --- /dev/null +++ b/examples/llama3/tools/fileio/__init__.py @@ -0,0 +1 @@ +from .safetensors import load_safetensors diff --git a/examples/llama3/tools/fileio/safetensors.py b/examples/llama3/tools/fileio/safetensors.py new file mode 100644 index 0000000000..8641cbf326 --- /dev/null +++ b/examples/llama3/tools/fileio/safetensors.py @@ -0,0 +1,21 @@ +import os +from typing import Dict, Union + +import numpy as np +import safetensors.numpy + +from mindspore import Parameter, Tensor + + +def 
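# Illustrative sketch (not from the patch): the intended workflow for tools/convert.py.
# The normal entry point is the CLI, e.g.
#   python tools/convert.py /path/to/safetensors -o models/llama.ckpt [--force_fp32]
# but the same helpers can be driven programmatically. The source directory below is a
# placeholder, and the import assumes the script is run from examples/llama3/tools.
from convert import convert, load, save

params = load("/path/to/safetensors", force_fp32=False)  # merges all *.safetensors shards
params = convert(params)                                 # name/format mapping hook (currently a pass-through)
save(params, "models/llama.ckpt")                        # writes a MindSpore checkpoint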
load_safetensors(filename: Union[str, os.PathLike], force_fp32: bool = False) -> Dict[str, Tensor]: + flat = safetensors.numpy.load_file(filename) + output = _np2ms(flat, force_fp32=force_fp32) + return output + + +def _np2ms(np_dict: Dict[str, np.ndarray], force_fp32: bool = False) -> Dict[str, Tensor]: + for k, v in np_dict.items(): + if force_fp32: + v = v.astype(np.float32) + np_dict[k] = Parameter(v, name=k) + return np_dict From 5be79a321e3d997c21bcceeb4c7cee5cd0944b32 Mon Sep 17 00:00:00 2001 From: Mike Cheung Date: Tue, 15 Oct 2024 17:53:36 +0800 Subject: [PATCH 02/14] adapting to movie gen --- examples/llama3/llama/models/__init__.py | 2 +- .../llama3/llama/models/llama/__init__.py | 2 +- examples/llama3/llama/models/llama/layer.py | 160 +++++++----------- examples/llama3/llama/models/llama/network.py | 114 ++++++++++--- examples/llama3/test.py | 10 +- 5 files changed, 156 insertions(+), 132 deletions(-) diff --git a/examples/llama3/llama/models/__init__.py b/examples/llama3/llama/models/__init__.py index a1829d24c5..97d9daa6cd 100644 --- a/examples/llama3/llama/models/__init__.py +++ b/examples/llama3/llama/models/__init__.py @@ -1 +1 @@ -from .llama import LlamaModel +from .llama import LlamaModel, llama3_8B diff --git a/examples/llama3/llama/models/llama/__init__.py b/examples/llama3/llama/models/llama/__init__.py index 73c46f222f..45247a36e7 100644 --- a/examples/llama3/llama/models/llama/__init__.py +++ b/examples/llama3/llama/models/llama/__init__.py @@ -1 +1 @@ -from .network import LlamaModel +from .network import LlamaModel, llama3_8B diff --git a/examples/llama3/llama/models/llama/layer.py b/examples/llama3/llama/models/llama/layer.py index 9c3d7da7f7..c21edbb4b2 100644 --- a/examples/llama3/llama/models/llama/layer.py +++ b/examples/llama3/llama/models/llama/layer.py @@ -27,46 +27,6 @@ def construct(self, hidden_states: Tensor) -> Tensor: return self.weight * hidden_states.to(input_dtype) -class LlamaRotaryEmbedding(nn.Cell): - def __init__(self, dim: int, max_position_embeddings: int = 2048, base: float = 10000.0) -> None: - super().__init__() - - self.dim = dim - self.max_position_embeddings = max_position_embeddings - self.inv_freq = 1.0 / (base ** (ops.arange(0, self.dim, 2, dtype=ms.float32) / self.dim)) - - def construct(self, x: Tensor, position_ids: Tensor) -> Tensor: - # x: [bs, num_attention_heads, seq_len, head_size] - inv_freq_expanded = ops.broadcast_to(self.inv_freq[None, :, None], (position_ids.shape[0], -1, 1)) - position_ids_expanded = position_ids[:, None, :].to(ms.float32) - # Force float32 since bfloat16 loses precision on long contexts - # See https://github.com/huggingface/transformers/pull/29285 - freqs = ops.matmul(inv_freq_expanded.to(ms.float32), position_ids_expanded.to(ms.float32)) - freqs = ops.transpose(freqs, (0, 2, 1)) - emb = ops.concat((freqs, freqs), axis=-1) - cos = ops.cos(emb) - sin = ops.sin(emb) - output = ops.stack([cos.to(x.dtype), sin.to(x.dtype)]) - return output - - -def rotate_half(x: Tensor) -> Tensor: - """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return ops.concat((-x2, x1), axis=-1) - - -def apply_rotary_pos_emb( - q: Tensor, k: Tensor, cos: Tensor, sin: Tensor, unsqueeze_dim: int = 1 -) -> Tuple[Tensor, Tensor]: - cos = ops.unsqueeze(cos, unsqueeze_dim) - sin = ops.unsqueeze(sin, unsqueeze_dim) - q_embed = (q * cos) + (rotate_half(q) * sin) - k_embed = (k * cos) + (rotate_half(k) * sin) - return q_embed, k_embed - - class LlamaMLP(nn.Cell): def 
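# Illustrative sketch (not from the patch): what load_safetensors() in
# tools/fileio/safetensors.py works with -- safetensors tensors loaded as NumPy arrays
# and re-wrapped as MindSpore Parameters keyed by name, optionally upcast to float32.
# The tiny file written here is a throwaway demo artifact.
import numpy as np
import safetensors.numpy

safetensors.numpy.save_file({"layer.weight": np.ones((2, 2), dtype=np.float16)}, "tiny.safetensors")
print(safetensors.numpy.load_file("tiny.safetensors")["layer.weight"].dtype)  # float16
# load_safetensors("tiny.safetensors", force_fp32=True) would yield a float32 Parameter instead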
__init__( self, @@ -88,9 +48,9 @@ def construct(self, hidden_state: Tensor) -> Tensor: def repeat_kv(hidden_states: Tensor, n_rep: int) -> Tensor: - batch, num_key_value_heads, slen, head_dim = hidden_states.shape if n_rep == 1: return hidden_states + batch, num_key_value_heads, slen, head_dim = hidden_states.shape hidden_states = hidden_states[:, :, None, :, :] hidden_states = ops.broadcast_to(hidden_states, (batch, num_key_value_heads, n_rep, slen, head_dim)) hidden_states = ops.reshape(hidden_states, (batch, num_key_value_heads * n_rep, slen, head_dim)) @@ -104,6 +64,7 @@ def __init__( num_attention_heads: int = 32, num_key_value_heads: int = 8, attention_dropout: float = 0.0, + attention_bias: bool = False, dtype: ms.Type = ms.float32, ) -> None: super().__init__() @@ -120,35 +81,34 @@ def __init__( f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" f" and `num_heads`: {self.num_heads})." ) - self.q_proj = nn.Dense(self.hidden_size, self.num_heads * self.head_dim, has_bias=False, dtype=dtype) - self.k_proj = nn.Dense(self.hidden_size, self.num_key_value_heads * self.head_dim, has_bias=False, dtype=dtype) - self.v_proj = nn.Dense(self.hidden_size, self.num_key_value_heads * self.head_dim, has_bias=False, dtype=dtype) - self.o_proj = nn.Dense(self.num_heads * self.head_dim, self.hidden_size, has_bias=False, dtype=dtype) + self.q_proj = nn.Dense(self.hidden_size, self.num_heads * self.head_dim, has_bias=attention_bias, dtype=dtype) + self.k_proj = nn.Dense( + self.hidden_size, self.num_key_value_heads * self.head_dim, has_bias=attention_bias, dtype=dtype + ) + self.v_proj = nn.Dense( + self.hidden_size, self.num_key_value_heads * self.head_dim, has_bias=attention_bias, dtype=dtype + ) + self.o_proj = nn.Dense(self.num_heads * self.head_dim, self.hidden_size, has_bias=attention_bias, dtype=dtype) - def construct( - self, - hidden_states: Tensor, - position_embeddings: Tensor, - attention_mask: Optional[Tensor] = None, - ) -> Tensor: + def construct(self, hidden_states: Tensor, encoder_hidden_states: Optional[Tensor] = None) -> Tensor: bsz, q_len, _ = hidden_states.shape + kv_hidden_states = hidden_states if encoder_hidden_states is None else encoder_hidden_states + _, kv_len, _ = kv_hidden_states.shape + query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) + key_states = self.k_proj(kv_hidden_states) + value_states = self.v_proj(kv_hidden_states) query_states = ops.reshape(query_states, (bsz, q_len, self.num_heads, self.head_dim)) query_states = ops.transpose(query_states, (0, 2, 1, 3)) - key_states = ops.reshape(key_states, (bsz, q_len, self.num_key_value_heads, self.head_dim)) + key_states = ops.reshape(key_states, (bsz, kv_len, self.num_key_value_heads, self.head_dim)) key_states = ops.transpose(key_states, (0, 2, 1, 3)) - value_states = ops.reshape(value_states, (bsz, q_len, self.num_key_value_heads, self.head_dim)) + value_states = ops.reshape(value_states, (bsz, kv_len, self.num_key_value_heads, self.head_dim)) value_states = ops.transpose(value_states, (0, 2, 1, 3)) - cos, sin = position_embeddings[0], position_embeddings[1] - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) @@ -157,9 +117,6 @@ def construct( # upcast attention to fp32 attn_weights = attn_weights.to(ms.float32) - if attention_mask is not None: - 
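# Illustrative sketch (not from the patch): after this change LlamaAttention doubles as
# self- and cross-attention -- passing encoder_hidden_states switches the K/V source
# while the query length (and output shape) stays that of hidden_states. Assumes
# examples/llama3 is on PYTHONPATH; the small sizes are demo assumptions.
import numpy as np
import mindspore as ms
from llama.models.llama.layer import LlamaAttention

attn = LlamaAttention(hidden_size=64, num_attention_heads=4, num_key_value_heads=2)
x = ms.Tensor(np.random.randn(1, 10, 64), dtype=ms.float32)   # latent tokens (queries)
ctx = ms.Tensor(np.random.randn(1, 7, 64), dtype=ms.float32)  # text tokens (keys/values)
print(attn(x).shape)                                  # (1, 10, 64)  self-attention
print(attn(x, encoder_hidden_states=ctx).shape)       # (1, 10, 64)  cross-attention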
attn_weights = ops.masked_fill(attn_weights, ~attention_mask, -ms.numpy.inf) - attn_weights = ops.softmax(attn_weights, axis=-1).to(query_states.dtype) attn_weights = ops.dropout(attn_weights, p=self.attention_dropout, training=self.training) attn_output = ops.matmul(attn_weights, value_states) @@ -171,62 +128,47 @@ def construct( return attn_output -class LlamaFlashAttention(nn.Cell): +class LlamaFlashAttention(LlamaAttention): def __init__( self, hidden_size: int = 4096, num_attention_heads: int = 32, num_key_value_heads: int = 8, attention_dropout: float = 0.0, + attention_bias: bool = False, dtype: ms.Type = ms.float32, ) -> None: - super().__init__() - - self.attention_dropout = attention_dropout - self.hidden_size = hidden_size - self.num_heads = num_attention_heads - self.head_dim = self.hidden_size // self.num_heads - self.num_key_value_heads = num_key_value_heads - self.num_key_value_groups = self.num_heads // self.num_key_value_heads - - if (self.head_dim * self.num_heads) != self.hidden_size: - raise ValueError( - f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" - f" and `num_heads`: {self.num_heads})." - ) - self.q_proj = nn.Dense(self.hidden_size, self.num_heads * self.head_dim, has_bias=False, dtype=dtype) - self.k_proj = nn.Dense(self.hidden_size, self.num_key_value_heads * self.head_dim, has_bias=False, dtype=dtype) - self.v_proj = nn.Dense(self.hidden_size, self.num_key_value_heads * self.head_dim, has_bias=False, dtype=dtype) - self.o_proj = nn.Dense(self.num_heads * self.head_dim, self.hidden_size, has_bias=False, dtype=dtype) - + super().__init__( + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + num_key_value_heads=num_key_value_heads, + attention_dropout=attention_dropout, + attention_bias=attention_bias, + dtype=dtype, + ) self.flash_attention = FlashAttentionScore( self.num_heads, keep_prob=1 - self.attention_dropout, scale_value=self.head_dim**-0.5, input_layout="BSND" ) - def construct( - self, - hidden_states: Tensor, - position_embeddings: Tensor, - attention_mask: Optional[Tensor] = None, - ) -> Tensor: + def construct(self, hidden_states: Tensor, encoder_hidden_states: Optional[Tensor] = None) -> Tensor: bsz, q_len, _ = hidden_states.shape + kv_hidden_states = hidden_states if encoder_hidden_states is None else encoder_hidden_states + _, kv_len, _ = kv_hidden_states.shape + query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) + key_states = self.k_proj(kv_hidden_states) + value_states = self.v_proj(kv_hidden_states) query_states = ops.reshape(query_states, (bsz, q_len, self.num_heads, self.head_dim)) query_states = ops.transpose(query_states, (0, 2, 1, 3)) - key_states = ops.reshape(key_states, (bsz, q_len, self.num_key_value_heads, self.head_dim)) + key_states = ops.reshape(key_states, (bsz, kv_len, self.num_key_value_heads, self.head_dim)) key_states = ops.transpose(key_states, (0, 2, 1, 3)) - value_states = ops.reshape(value_states, (bsz, q_len, self.num_key_value_heads, self.head_dim)) + value_states = ops.reshape(value_states, (bsz, kv_len, self.num_key_value_heads, self.head_dim)) value_states = ops.transpose(value_states, (0, 2, 1, 3)) - cos, sin = position_embeddings[0], position_embeddings[1] - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) @@ -235,10 
+177,34 @@ def construct( key_states = ops.transpose(key_states, (0, 2, 1, 3)) value_states = ops.transpose(value_states, (0, 2, 1, 3)) - _, _, _, attn_output = self.flash_attention( - query_states, key_states, value_states, None, None, None, attention_mask - ) + _, _, _, attn_output = self.flash_attention(query_states, key_states, value_states, None, None, None, None) attn_output = ops.reshape(attn_output, (bsz, q_len, -1)) attn_output = self.o_proj(attn_output) return attn_output + + +class PatchEmbed3D(nn.Cell): + def __init__( + self, + patch_size: Tuple[int, int, int] = (1, 2, 2), + in_channels: int = 8, + hidden_size: int = 4096, + dtype: ms.Type = ms.float32, + ) -> None: + super().__init__() + self.patch_size = patch_size + self.proj = nn.Conv3d( + in_channels, hidden_size, kernel_size=patch_size, stride=self.patch_size, pad_mode="pad", dtype=dtype + ) + + def construct(self, x: Tensor) -> Tensor: + _, t, _, h, w = x.shape + assert t % self.patch_size[0] == 0 + assert h % self.patch_size[1] == 0 + assert w % self.patch_size[2] == 0 + + x = ops.transpose(x, (0, 2, 1, 3, 4)) + x = self.proj(x) # (B C T H W) + x = x.flatten(start_dim=2).swapaxes(1, 2) + return x diff --git a/examples/llama3/llama/models/llama/network.py b/examples/llama3/llama/models/llama/network.py index 03a1d7ceb9..2c2c5ce6a1 100644 --- a/examples/llama3/llama/models/llama/network.py +++ b/examples/llama3/llama/models/llama/network.py @@ -1,12 +1,12 @@ -from typing import Literal, Optional +from typing import Literal, Tuple import mindspore as ms import mindspore.nn as nn import mindspore.ops as ops -from mindspore import Tensor +from mindspore import Tensor, load_checkpoint from mindspore.common.initializer import Normal, Zero, initializer -from .layer import LlamaAttention, LlamaFlashAttention, LlamaMLP, LlamaRMSNorm, LlamaRotaryEmbedding +from .layer import LlamaAttention, LlamaFlashAttention, LlamaMLP, LlamaRMSNorm, PatchEmbed3D Llama_ATTENTION_CLASSES = { "eager": LlamaAttention, @@ -23,6 +23,7 @@ def __init__( num_key_value_heads: int = 8, rms_norm_eps: float = 1e-5, attention_dropout: float = 0.0, + attention_bias: bool = False, hidden_act: str = "silu", attn_implementation: Literal["eager", "flash_attention"] = "eager", dtype: ms.Type = ms.float32, @@ -35,6 +36,16 @@ def __init__( num_attention_heads=num_attention_heads, num_key_value_heads=num_key_value_heads, attention_dropout=attention_dropout, + attention_bias=attention_bias, + dtype=dtype, + ) + + self.cross_attn = Llama_ATTENTION_CLASSES[attn_implementation]( + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + num_key_value_heads=num_key_value_heads, + attention_dropout=attention_dropout, + attention_bias=attention_bias, dtype=dtype, ) @@ -47,21 +58,23 @@ def __init__( def construct( self, hidden_states: Tensor, + encoder_hidden_states: Tensor, position_embeddings: Tensor, - attention_mask: Optional[Tensor] = None, ) -> Tensor: + # 3.1.3 Add Positional Embedding + hidden_states = hidden_states + position_embeddings.to(hidden_states.dtype) + residual = hidden_states hidden_states = self.input_layernorm(hidden_states) - # Self Attention - hidden_states = self.self_attn( - hidden_states, - position_embeddings, - attention_mask=attention_mask, - ) + # Self Attention (Bi-Directional Attention) + hidden_states = self.self_attn(hidden_states) hidden_states = residual + hidden_states + # 3.1.3 Cross Attention + hidden_states = self.cross_attn(hidden_states, encoder_hidden_states=encoder_hidden_states) + # Fully Connected residual = 
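# Illustrative sketch (not from the patch): the token count produced by PatchEmbed3D.
# A latent of shape (B, T, C, H, W) with patch_size (1, 2, 2) becomes
# (T/1) * (H/2) * (W/2) tokens of width hidden_size. The numbers match the shapes used
# in test.py (T=16, H=24, W=44); everything else is an assumption.
patch_size = (1, 2, 2)
t, h, w = 16, 24, 44
tokens = (t // patch_size[0]) * (h // patch_size[1]) * (w // patch_size[2])
print(tokens)  # 16 * 12 * 22 = 4224 tokens per sample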
hidden_states hidden_states = self.post_attention_layernorm(hidden_states) @@ -74,22 +87,25 @@ def construct( class LlamaModel(nn.Cell): def __init__( self, + in_channels: int = 8, hidden_size: int = 4096, intermediate_size: int = 14336, - max_position_embeddings: int = 32768, num_attention_heads: int = 32, num_hidden_layers: int = 32, num_key_value_heads: int = 8, rms_norm_eps: float = 1e-5, - rope_theta: float = 1000000.0, attention_dropout: float = 0.0, + attention_bias: bool = False, hidden_act: str = "silu", initializer_range: float = 0.02, + kernel_size: Tuple[int, int, int] = (1, 2, 2), + max_length: Tuple[int, int, int] = (16, 24, 44), attn_implementation: Literal["eager", "flash_attention"] = "eager", gradient_checkpointing: bool = False, dtype: ms.Type = ms.float32, ) -> None: super().__init__() + self.kernel_size = kernel_size self.layers = nn.CellList( [ LlamaDecoderLayer( @@ -99,6 +115,7 @@ def __init__( num_key_value_heads=num_key_value_heads, rms_norm_eps=rms_norm_eps, attention_dropout=attention_dropout, + attention_bias=attention_bias, hidden_act=hidden_act, attn_implementation=attn_implementation, dtype=dtype, @@ -107,11 +124,12 @@ def __init__( ] ) self.norm = LlamaRMSNorm(hidden_size, eps=rms_norm_eps, dtype=dtype) - self.rotary_emb = LlamaRotaryEmbedding( - hidden_size // num_attention_heads, - max_position_embeddings=max_position_embeddings, - base=rope_theta, - ) + + self.pos_embedding_table_h = nn.Embedding(max_length[0], hidden_size, dtype=dtype) + self.pos_embedding_table_w = nn.Embedding(max_length[1], hidden_size, dtype=dtype) + self.pos_embedding_table_t = nn.Embedding(max_length[2], hidden_size, dtype=dtype) + + self.latent_embedder = PatchEmbed3D(kernel_size, in_channels, hidden_size, dtype=dtype) # post-init self.initializer_range = initializer_range @@ -128,28 +146,68 @@ def _init_weights(module): module.weight.set_data(initializer(Normal(std, 0.0), module.weight.shape, module.weight.dtype)) if module.bias is not None: module.bias.set_data(initializer(Zero(), module.bias.shape, module.bias.dtype)) + elif isinstance(module, nn.Embedding): + module.embedding_table.set_data( + initializer(Normal(std, 0.0), module.embedding_table.shape, module.embedding_table.dtype) + ) self.apply(_init_weights) + def learnable_position_embedding(self, latent_embedding: Tensor) -> Tensor: + # 3.1.3 + _, t, _, h, w = latent_embedding.shape + t_inds = ops.arange(t // self.kernel_size[0], dtype=ms.int64) + h_inds = ops.arange(h // self.kernel_size[1], dtype=ms.int64) + w_inds = ops.arange(w // self.kernel_size[2], dtype=ms.int64) + + position_ids = ops.meshgrid(t_inds, h_inds, w_inds, indexing="ij") + position_ids = ops.stack(position_ids, axis=-1) + position_ids = ops.reshape(position_ids, (-1, 3)) + + h_inds, w_inds, t_inds = ops.unbind(position_ids, dim=-1) + pos_embed_h = self.pos_embedding_table_h(h_inds) + pos_embed_w = self.pos_embedding_table_w(w_inds) + pos_embed_t = self.pos_embedding_table_t(t_inds) + return pos_embed_h + pos_embed_w + pos_embed_t + def construct( self, - inputs_embeds: Tensor, - attention_mask: Optional[Tensor] = None, + latent_embedding: Tensor, + text_embedding: Tensor, ) -> Tensor: - position_ids = ops.arange(0, inputs_embeds.shape[1], dtype=ms.int64) - position_ids = ops.unsqueeze(position_ids, 0) - + """ + latent_embedding: (N, T, C, H, W) tensor of inputs (latent representations of video) + text_embedding: (N, L, C') tensor of the text embedding + """ # create position embeddings to be shared across the decoder layers - position_embeddings = 
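# Illustrative sketch (not from the patch): the mechanism behind
# learnable_position_embedding() -- one learnable table per axis (t, h, w), looked up
# over the token grid and summed. NumPy arrays stand in for nn.Embedding, the naming
# is simplified, and the sizes are demo assumptions.
import numpy as np

hidden, nt, nh, nw = 8, 2, 3, 4
table_t = np.random.randn(nt, hidden)
table_h = np.random.randn(nh, hidden)
table_w = np.random.randn(nw, hidden)

t_ids, h_ids, w_ids = np.meshgrid(np.arange(nt), np.arange(nh), np.arange(nw), indexing="ij")
pos = table_t[t_ids.ravel()] + table_h[h_ids.ravel()] + table_w[w_ids.ravel()]
print(pos.shape)  # (nt*nh*nw, hidden): one embedding per latent token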
self.rotary_emb(inputs_embeds, position_ids) + position_embeddings = self.learnable_position_embedding(latent_embedding) + + # patchify + inputs_embeds = self.latent_embedder(latent_embedding) hidden_states = inputs_embeds for decoder_layer in self.layers: - hidden_states = decoder_layer( - hidden_states, - position_embeddings, - attention_mask=attention_mask, - ) + hidden_states = decoder_layer(hidden_states, text_embedding, position_embeddings) hidden_states = self.norm(hidden_states) return hidden_states + + +def llama3_8B(from_pretrained=None, **kwargs): + model = LlamaModel( + attention_bias=False, + attention_dropout=0.0, + hidden_act="silu", + hidden_size=4096, + initializer_range=0.02, + intermediate_size=14336, + num_attention_heads=32, + num_hidden_layers=32, + num_key_value_heads=32, + rms_norm_eps=1e-05, + **kwargs, + ) + if from_pretrained is not None: + load_checkpoint(from_pretrained, model) + return model diff --git a/examples/llama3/test.py b/examples/llama3/test.py index 325ca6c012..2f374767f8 100644 --- a/examples/llama3/test.py +++ b/examples/llama3/test.py @@ -1,5 +1,5 @@ import numpy as np -from llama.models import LlamaModel +from llama.models import llama3_8B import mindspore as ms import mindspore.nn as nn @@ -12,14 +12,14 @@ def count_params(model: nn.Cell) -> int: def main(): ms.set_context(mode=ms.GRAPH_MODE) - network = LlamaModel(attn_implementation="flash_attention", dtype=ms.bfloat16) - ms.load_checkpoint("model.ckpt", network) + network = llama3_8B(attn_implementation="flash_attention", dtype=ms.bfloat16) params = count_params(network) print(f"Parameter number: {params:,}") - inputs = ms.Tensor(np.ones((4, 256, 4096)), dtype=ms.bfloat16) - outputs = network(inputs) + latent_embedding = ms.Tensor(np.ones((1, 16, 8, 24, 44)), dtype=ms.bfloat16) + text_embedding = ms.Tensor(np.ones((1, 64, 256)), dtype=ms.bfloat16) + outputs = network(latent_embedding, text_embedding) print(outputs.shape) print(outputs) From be0c8b65ae4ed4a52214c6fba7c77753e53f015f Mon Sep 17 00:00:00 2001 From: Mike Cheung Date: Wed, 16 Oct 2024 15:49:37 +0800 Subject: [PATCH 03/14] model adapted to movie gen --- examples/llama3/llama/__init__.py | 1 + examples/llama3/llama/models/__init__.py | 2 +- .../llama3/llama/models/llama/__init__.py | 2 +- examples/llama3/llama/models/llama/layer.py | 41 +++- examples/llama3/llama/models/llama/network.py | 215 ++++++++++++++---- examples/llama3/test.py | 10 +- 6 files changed, 225 insertions(+), 46 deletions(-) diff --git a/examples/llama3/llama/__init__.py b/examples/llama3/llama/__init__.py index e69de29bb2..aed4fa323c 100644 --- a/examples/llama3/llama/__init__.py +++ b/examples/llama3/llama/__init__.py @@ -0,0 +1 @@ +from .models import * diff --git a/examples/llama3/llama/models/__init__.py b/examples/llama3/llama/models/__init__.py index 97d9daa6cd..1dd581a786 100644 --- a/examples/llama3/llama/models/__init__.py +++ b/examples/llama3/llama/models/__init__.py @@ -1 +1 @@ -from .llama import LlamaModel, llama3_8B +from .llama import * diff --git a/examples/llama3/llama/models/llama/__init__.py b/examples/llama3/llama/models/llama/__init__.py index 45247a36e7..6cf34ce83b 100644 --- a/examples/llama3/llama/models/llama/__init__.py +++ b/examples/llama3/llama/models/llama/__init__.py @@ -1 +1 @@ -from .network import LlamaModel, llama3_8B +from .network import * diff --git a/examples/llama3/llama/models/llama/layer.py b/examples/llama3/llama/models/llama/layer.py index c21edbb4b2..35ae260390 100644 --- a/examples/llama3/llama/models/llama/layer.py 
+++ b/examples/llama3/llama/models/llama/layer.py @@ -195,7 +195,13 @@ def __init__( super().__init__() self.patch_size = patch_size self.proj = nn.Conv3d( - in_channels, hidden_size, kernel_size=patch_size, stride=self.patch_size, pad_mode="pad", dtype=dtype + in_channels, + hidden_size, + kernel_size=self.patch_size, + stride=self.patch_size, + pad_mode="pad", + has_bias=False, + dtype=dtype, ) def construct(self, x: Tensor) -> Tensor: @@ -208,3 +214,36 @@ def construct(self, x: Tensor) -> Tensor: x = self.proj(x) # (B C T H W) x = x.flatten(start_dim=2).swapaxes(1, 2) return x + + +class TimestepEmbedder(nn.Cell): + def __init__( + self, + hidden_size: int, + frequency_embedding_size: int = 256, + hidden_act: str = "silu", + dtype: ms.Type = ms.float32, + ) -> None: + super().__init__() + self.mlp = nn.SequentialCell( + nn.Dense(frequency_embedding_size, hidden_size, has_bias=False, dtype=dtype), + ACT2FN[hidden_act], + nn.Dense(hidden_size, hidden_size, has_bias=False, dtype=dtype), + ) + self.frequency_embedding_size = frequency_embedding_size + self.dtype = dtype + + @staticmethod + def timestep_embedding(t: Tensor, dim: int, max_period: int = 10000) -> Tensor: + half = dim // 2 + freqs = ops.exp(-ms.numpy.log(max_period) * ops.arange(start=0, end=half, dtype=ms.float32) / half) + args = t[:, None].to(ms.float32) * freqs[None] + embedding = ops.concat([ops.cos(args), ops.sin(args)], axis=-1) + if dim % 2: + embedding = ops.concat([embedding, ops.zeros_like(embedding[:, :1])], axis=-1) + return embedding + + def construct(self, t: Tensor) -> Tensor: + t_freq = self.timestep_embedding(t, self.frequency_embedding_size) + t_emb = self.mlp(t_freq.to(self.dtype)) + return t_emb diff --git a/examples/llama3/llama/models/llama/network.py b/examples/llama3/llama/models/llama/network.py index 2c2c5ce6a1..a075fc08f3 100644 --- a/examples/llama3/llama/models/llama/network.py +++ b/examples/llama3/llama/models/llama/network.py @@ -3,10 +3,14 @@ import mindspore as ms import mindspore.nn as nn import mindspore.ops as ops -from mindspore import Tensor, load_checkpoint -from mindspore.common.initializer import Normal, Zero, initializer +from mindspore import Parameter, Tensor, load_checkpoint -from .layer import LlamaAttention, LlamaFlashAttention, LlamaMLP, LlamaRMSNorm, PatchEmbed3D +from mindone.models.utils import normal_, zeros_ + +from ..activation import ACT2FN +from .layer import LlamaAttention, LlamaFlashAttention, LlamaMLP, LlamaRMSNorm, PatchEmbed3D, TimestepEmbedder + +__all__ = ["LlamaModel", "llama3_1B", "llama3_5B", "llama3_30B"] Llama_ATTENTION_CLASSES = { "eager": LlamaAttention, @@ -14,6 +18,10 @@ } +def t2i_modulate(x: Tensor, shift: Tensor, scale: Tensor) -> Tensor: + return x * (1 + scale) + shift + + class LlamaDecoderLayer(nn.Cell): def __init__( self, @@ -52,6 +60,9 @@ def __init__( self.mlp = LlamaMLP( intermediate_size=intermediate_size, hidden_size=hidden_size, hidden_act=hidden_act, dtype=dtype ) + + self.scale_shift_table = Parameter(ops.randn(6, hidden_size, dtype=dtype) / hidden_size**0.5) + self.input_layernorm = LlamaRMSNorm(hidden_size, eps=rms_norm_eps, dtype=dtype) self.post_attention_layernorm = LlamaRMSNorm(hidden_size, eps=rms_norm_eps, dtype=dtype) @@ -59,31 +70,66 @@ def construct( self, hidden_states: Tensor, encoder_hidden_states: Tensor, - position_embeddings: Tensor, + modulation_parameters: Tensor, + position_embedding: Tensor, ) -> Tensor: - # 3.1.3 Add Positional Embedding - hidden_states = hidden_states + position_embeddings.to(hidden_states.dtype) + 
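# Illustrative sketch (not from the patch): the sinusoidal frequency embedding computed
# by TimestepEmbedder.timestep_embedding before its two-layer MLP. dim and max_period
# follow the defaults above; the timestep values are demo inputs (the odd-dim padding
# branch is omitted since dim is even here).
import numpy as np

def timestep_embedding_reference(t, dim, max_period=10000):
    half = dim // 2
    freqs = np.exp(-np.log(max_period) * np.arange(half) / half)   # (half,)
    args = t[:, None].astype(np.float32) * freqs[None]             # (N, half)
    return np.concatenate([np.cos(args), np.sin(args)], axis=-1)   # (N, dim)

print(timestep_embedding_reference(np.array([0, 35, 999]), 256).shape)  # (3, 256)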
B = hidden_states.shape[0] - residual = hidden_states + # 3.1.3 Positional Embedding + hidden_states = hidden_states + position_embedding.to(hidden_states.dtype) - hidden_states = self.input_layernorm(hidden_states) + # 3.1.3 Adaptive Layer Norm + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ops.chunk( + self.scale_shift_table[None] + modulation_parameters.reshape(B, 6, -1), 6, axis=1 + ) # Self Attention (Bi-Directional Attention) + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + hidden_states = t2i_modulate(hidden_states, shift_msa, scale_msa) hidden_states = self.self_attn(hidden_states) + hidden_states = gate_msa * hidden_states hidden_states = residual + hidden_states # 3.1.3 Cross Attention + residual = hidden_states hidden_states = self.cross_attn(hidden_states, encoder_hidden_states=encoder_hidden_states) + hidden_states = residual + hidden_states # Fully Connected residual = hidden_states hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = t2i_modulate(hidden_states, shift_mlp, scale_mlp) hidden_states = self.mlp(hidden_states) + hidden_states = gate_mlp * hidden_states hidden_states = residual + hidden_states return hidden_states +class LlamaFinalLayer(nn.Cell): + def __init__( + self, + hidden_size: int = 4096, + patch_size: Tuple[int, int, int] = (1, 2, 2), + out_channels: int = 8, + rms_norm_eps: float = 1e-5, + dtype: ms.Type = ms.float32, + ) -> None: + super().__init__() + self.input_layernorm = LlamaRMSNorm(hidden_size, eps=rms_norm_eps, dtype=dtype) + self.proj = nn.Dense( + hidden_size, patch_size[0] * patch_size[1] * patch_size[2] * out_channels, has_bias=False, dtype=dtype + ) + self.scale_shift_table = Parameter(ops.randn(2, hidden_size, dtype=dtype) / hidden_size**0.5) + + def construct(self, hidden_states: Tensor, timestep_embedding: Tensor): + shift, scale = ops.chunk(self.scale_shift_table[None] + timestep_embedding[:, None], 2, axis=1) + hidden_states = t2i_modulate(self.input_layernorm(hidden_states), shift, scale) + hidden_states = self.proj(hidden_states) + return hidden_states + + class LlamaModel(nn.Cell): def __init__( self, @@ -98,18 +144,22 @@ def __init__( attention_bias: bool = False, hidden_act: str = "silu", initializer_range: float = 0.02, - kernel_size: Tuple[int, int, int] = (1, 2, 2), + patch_size: Tuple[int, int, int] = (1, 2, 2), max_length: Tuple[int, int, int] = (16, 24, 44), attn_implementation: Literal["eager", "flash_attention"] = "eager", gradient_checkpointing: bool = False, dtype: ms.Type = ms.float32, ) -> None: super().__init__() - self.kernel_size = kernel_size + self.patch_size = patch_size + self.in_channels = in_channels + self.out_channels = in_channels + self.hidden_size = hidden_size + self.layers = nn.CellList( [ LlamaDecoderLayer( - hidden_size=hidden_size, + hidden_size=self.hidden_size, intermediate_size=intermediate_size, num_attention_heads=num_attention_heads, num_key_value_heads=num_key_value_heads, @@ -123,13 +173,23 @@ def __init__( for _ in range(num_hidden_layers) ] ) - self.norm = LlamaRMSNorm(hidden_size, eps=rms_norm_eps, dtype=dtype) + self.final_layer = LlamaFinalLayer( + hidden_size=self.hidden_size, + patch_size=patch_size, + out_channels=self.out_channels, + rms_norm_eps=rms_norm_eps, + dtype=dtype, + ) - self.pos_embedding_table_h = nn.Embedding(max_length[0], hidden_size, dtype=dtype) - self.pos_embedding_table_w = nn.Embedding(max_length[1], hidden_size, dtype=dtype) - self.pos_embedding_table_t = nn.Embedding(max_length[2], 
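# Illustrative sketch (not from the patch): the adaLN-single style modulation used in
# LlamaDecoderLayer -- the timestep MLP output plus a learnable table is split into
# (shift, scale, gate) triples for the attention and MLP branches, applied as
# x * (1 + scale) + shift, with the gate multiplying the residual update. Sizes are
# demo assumptions.
import numpy as np

B, L, D = 2, 5, 8
x = np.random.randn(B, L, D)
table = np.random.randn(6, D) / np.sqrt(D)     # scale_shift_table
mod = np.random.randn(B, 6 * D)                # adaLN_modulation(timestep_embedding)

shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = np.split(
    table[None] + mod.reshape(B, 6, D), 6, axis=1
)
modulated = x * (1 + scale_msa) + shift_msa    # t2i_modulate before self-attention
print(modulated.shape, gate_msa.shape)         # (2, 5, 8) (2, 1, 8)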
hidden_size, dtype=dtype) + self.pos_embedding_table_h = nn.Embedding(max_length[0], self.hidden_size, dtype=dtype) + self.pos_embedding_table_w = nn.Embedding(max_length[1], self.hidden_size, dtype=dtype) + self.pos_embedding_table_t = nn.Embedding(max_length[2], self.hidden_size, dtype=dtype) - self.latent_embedder = PatchEmbed3D(kernel_size, in_channels, hidden_size, dtype=dtype) + self.latent_embedder = PatchEmbed3D(patch_size, self.in_channels, self.hidden_size, dtype=dtype) + self.timestep_embedder = TimestepEmbedder(self.hidden_size, dtype=dtype) + self.adaLN_modulation = nn.SequentialCell( + ACT2FN[hidden_act], nn.Dense(hidden_size, 6 * hidden_size, has_bias=False, dtype=dtype) + ) # post-init self.initializer_range = initializer_range @@ -140,25 +200,39 @@ def __init__( self.layers.recompute() def init_weights(self): + std = self.initializer_range + def _init_weights(module): - std = self.initializer_range if isinstance(module, nn.Dense): - module.weight.set_data(initializer(Normal(std, 0.0), module.weight.shape, module.weight.dtype)) + normal_(module.weight, mean=0.0, std=std) if module.bias is not None: - module.bias.set_data(initializer(Zero(), module.bias.shape, module.bias.dtype)) + zeros_(module.weight) elif isinstance(module, nn.Embedding): - module.embedding_table.set_data( - initializer(Normal(std, 0.0), module.embedding_table.shape, module.embedding_table.dtype) - ) + normal_(module.embedding_table, mean=0.0, std=std) self.apply(_init_weights) + # Initialize patch_embed like nn.Dense (instead of nn.Conv3d): + normal_(self.latent_embedder.proj.weight, mean=0.0, std=std) + if self.latent_embedder.proj.bias is not None: + zeros_(self.latent_embedder.proj.bias) + + # Zero-out adaLN modulation layer: + zeros_(self.adaLN_modulation[-1].weight) + if self.adaLN_modulation[-1].bias is not None: + zeros_(self.adaLN_modulation[-1].bias) + + # Zero-out final layer as DiT does + zeros_(self.final_layer.proj.weight) + if self.final_layer.proj.bias is not None: + zeros_(self.final_layer.proj.bias) + def learnable_position_embedding(self, latent_embedding: Tensor) -> Tensor: # 3.1.3 _, t, _, h, w = latent_embedding.shape - t_inds = ops.arange(t // self.kernel_size[0], dtype=ms.int64) - h_inds = ops.arange(h // self.kernel_size[1], dtype=ms.int64) - w_inds = ops.arange(w // self.kernel_size[2], dtype=ms.int64) + t_inds = ops.arange(t // self.patch_size[0], dtype=ms.int64) + h_inds = ops.arange(h // self.patch_size[1], dtype=ms.int64) + w_inds = ops.arange(w // self.patch_size[2], dtype=ms.int64) position_ids = ops.meshgrid(t_inds, h_inds, w_inds, indexing="ij") position_ids = ops.stack(position_ids, axis=-1) @@ -170,41 +244,106 @@ def learnable_position_embedding(self, latent_embedding: Tensor) -> Tensor: pos_embed_t = self.pos_embedding_table_t(t_inds) return pos_embed_h + pos_embed_w + pos_embed_t + def unpatchify(self, hidden_states: Tensor, t: int, h: int, w: int) -> Tensor: + """ + hidden_states: (N, T, patch_size[0] * patch_size[1] * patch_size[2] * C) + """ + bs = hidden_states.shape[0] + c = self.out_channels + p0, p1, p2 = self.patch_size[0], self.patch_size[1], self.patch_size[2] + nt, nh, nw = t // p0, h // p1, w // p2 + + hidden_states = ops.reshape(hidden_states, (bs, nt, nh, nw, p0, p1, p2, c)) + # bs, nt, p0, c, nh, p1, nw, p2, c + hidden_states = ops.transpose(hidden_states, (0, 1, 4, 7, 2, 5, 3, 6)) + output = ops.reshape(hidden_states, (bs, nt * p0, c, nh * p1, nw * p2)) + return output + def construct( self, latent_embedding: Tensor, + timestep: Tensor, text_embedding: 
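# Illustrative sketch (not from the patch): shapes through final_layer + unpatchify.
# Each token carries patch_size[0]*patch_size[1]*patch_size[2]*C output values, and
# unpatchify folds the token grid back to the (N, T, C, H, W) latent layout. The
# numbers follow the test.py example; everything else is an assumption.
p0, p1, p2, c = 1, 2, 2, 8
t, h, w = 16, 24, 44
tokens = (t // p0) * (h // p1) * (w // p2)
print(tokens, p0 * p1 * p2 * c)   # 4224 tokens, 32 channels per token
# unpatchify then reshapes (N, 4224, 32) back to (N, 16, 8, 24, 44)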
Tensor, ) -> Tensor: """ latent_embedding: (N, T, C, H, W) tensor of inputs (latent representations of video) + timestep: (N,) tensor to indicate denoising step text_embedding: (N, L, C') tensor of the text embedding """ - # create position embeddings to be shared across the decoder layers - position_embeddings = self.learnable_position_embedding(latent_embedding) + _, t, _, h, w = latent_embedding.shape + + # create position embedding to be shared across the decoder layers + position_embedding = self.learnable_position_embedding(latent_embedding) - # patchify - inputs_embeds = self.latent_embedder(latent_embedding) + # patchify and embed latent in transformer hidden dim. + latent_embedding = self.latent_embedder(latent_embedding) - hidden_states = inputs_embeds + # 6.1.2 shared timestep embedding & modulation. It does not mention the detail structure, we follow PixArt-Alpha here + timestep_embedding = self.timestep_embedder(timestep) + modulation_parameters = self.adaLN_modulation(timestep_embedding) + # main block + hidden_states = latent_embedding for decoder_layer in self.layers: - hidden_states = decoder_layer(hidden_states, text_embedding, position_embeddings) + hidden_states = decoder_layer(hidden_states, text_embedding, modulation_parameters, position_embedding) - hidden_states = self.norm(hidden_states) - return hidden_states + # final layer + hidden_states = self.final_layer(hidden_states, timestep_embedding) + + # unpatchify + output = self.unpatchify(hidden_states, t, h, w) + return output -def llama3_8B(from_pretrained=None, **kwargs): +def llama3_1B(from_pretrained=None, **kwargs): model = LlamaModel( attention_bias=False, attention_dropout=0.0, hidden_act="silu", - hidden_size=4096, + hidden_size=1536, initializer_range=0.02, - intermediate_size=14336, - num_attention_heads=32, + intermediate_size=4096, + num_attention_heads=16, + num_hidden_layers=24, + num_key_value_heads=16, + rms_norm_eps=1e-05, + **kwargs, + ) + if from_pretrained is not None: + load_checkpoint(from_pretrained, model) + return model + + +def llama3_5B(from_pretrained=None, **kwargs): + model = LlamaModel( + attention_bias=False, + attention_dropout=0.0, + hidden_act="silu", + hidden_size=3072, + initializer_range=0.02, + intermediate_size=8192, + num_attention_heads=24, num_hidden_layers=32, - num_key_value_heads=32, + num_key_value_heads=24, + rms_norm_eps=1e-05, + **kwargs, + ) + if from_pretrained is not None: + load_checkpoint(from_pretrained, model) + return model + + +def llama3_30B(from_pretrained=None, **kwargs): + model = LlamaModel( + attention_bias=False, + attention_dropout=0.0, + hidden_act="silu", + hidden_size=6144, + initializer_range=0.02, + intermediate_size=16384, + num_attention_heads=48, + num_hidden_layers=48, + num_key_value_heads=48, rms_norm_eps=1e-05, **kwargs, ) diff --git a/examples/llama3/test.py b/examples/llama3/test.py index 2f374767f8..510003d062 100644 --- a/examples/llama3/test.py +++ b/examples/llama3/test.py @@ -1,5 +1,5 @@ import numpy as np -from llama.models import llama3_8B +from llama.models import llama3_1B import mindspore as ms import mindspore.nn as nn @@ -12,17 +12,17 @@ def count_params(model: nn.Cell) -> int: def main(): ms.set_context(mode=ms.GRAPH_MODE) - network = llama3_8B(attn_implementation="flash_attention", dtype=ms.bfloat16) + network = llama3_1B(attn_implementation="flash_attention", dtype=ms.bfloat16) params = count_params(network) print(f"Parameter number: {params:,}") latent_embedding = ms.Tensor(np.ones((1, 16, 8, 24, 44)), 
dtype=ms.bfloat16) - text_embedding = ms.Tensor(np.ones((1, 64, 256)), dtype=ms.bfloat16) - outputs = network(latent_embedding, text_embedding) + timestep = ms.Tensor([35], dtype=ms.int64) + text_embedding = ms.Tensor(np.ones((1, 64, network.hidden_size)), dtype=ms.bfloat16) + outputs = network(latent_embedding, timestep, text_embedding) print(outputs.shape) - print(outputs) if __name__ == "__main__": From f6561e2390e64a0326fa2500550c23b25996d633 Mon Sep 17 00:00:00 2001 From: Mike Cheung Date: Wed, 16 Oct 2024 16:10:57 +0800 Subject: [PATCH 04/14] add caption embedder --- examples/llama3/llama/models/llama/layer.py | 18 ++++++++++++++++++ examples/llama3/llama/models/llama/network.py | 15 ++++++++++++++- examples/llama3/test.py | 2 +- 3 files changed, 33 insertions(+), 2 deletions(-) diff --git a/examples/llama3/llama/models/llama/layer.py b/examples/llama3/llama/models/llama/layer.py index 35ae260390..3c73dfc3b3 100644 --- a/examples/llama3/llama/models/llama/layer.py +++ b/examples/llama3/llama/models/llama/layer.py @@ -247,3 +247,21 @@ def construct(self, t: Tensor) -> Tensor: t_freq = self.timestep_embedding(t, self.frequency_embedding_size) t_emb = self.mlp(t_freq.to(self.dtype)) return t_emb + + +class CaptionEmbedder(nn.Cell): + def __init__( + self, + in_channels: int, + hidden_size: int, + dtype: ms.Type = ms.float32, + ) -> None: + super().__init__() + self.proj = nn.SequentialCell( + nn.Dense(in_channels, hidden_size, has_bias=False, dtype=dtype), + nn.LayerNorm((hidden_size,), dtype=dtype), + ) + + def construct(self, caption: Tensor) -> Tensor: + caption_emb = self.proj(caption) + return caption_emb diff --git a/examples/llama3/llama/models/llama/network.py b/examples/llama3/llama/models/llama/network.py index a075fc08f3..f6eb9ded8b 100644 --- a/examples/llama3/llama/models/llama/network.py +++ b/examples/llama3/llama/models/llama/network.py @@ -8,7 +8,15 @@ from mindone.models.utils import normal_, zeros_ from ..activation import ACT2FN -from .layer import LlamaAttention, LlamaFlashAttention, LlamaMLP, LlamaRMSNorm, PatchEmbed3D, TimestepEmbedder +from .layer import ( + CaptionEmbedder, + LlamaAttention, + LlamaFlashAttention, + LlamaMLP, + LlamaRMSNorm, + PatchEmbed3D, + TimestepEmbedder, +) __all__ = ["LlamaModel", "llama3_1B", "llama3_5B", "llama3_30B"] @@ -146,6 +154,7 @@ def __init__( initializer_range: float = 0.02, patch_size: Tuple[int, int, int] = (1, 2, 2), max_length: Tuple[int, int, int] = (16, 24, 44), + caption_channels: int = 4096, attn_implementation: Literal["eager", "flash_attention"] = "eager", gradient_checkpointing: bool = False, dtype: ms.Type = ms.float32, @@ -190,6 +199,7 @@ def __init__( self.adaLN_modulation = nn.SequentialCell( ACT2FN[hidden_act], nn.Dense(hidden_size, 6 * hidden_size, has_bias=False, dtype=dtype) ) + self.caption_embedder = CaptionEmbedder(caption_channels, hidden_size, dtype=dtype) # post-init self.initializer_range = initializer_range @@ -282,6 +292,9 @@ def construct( timestep_embedding = self.timestep_embedder(timestep) modulation_parameters = self.adaLN_modulation(timestep_embedding) + # 3.1.4 text embedding + text_embedding = self.caption_embedder(text_embedding) + # main block hidden_states = latent_embedding for decoder_layer in self.layers: diff --git a/examples/llama3/test.py b/examples/llama3/test.py index 510003d062..d57dac40b0 100644 --- a/examples/llama3/test.py +++ b/examples/llama3/test.py @@ -19,7 +19,7 @@ def main(): latent_embedding = ms.Tensor(np.ones((1, 16, 8, 24, 44)), dtype=ms.bfloat16) timestep = 
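# Illustrative sketch (not from the patch): CaptionEmbedder projects the 4096-channel
# text-encoder output used in test.py into the transformer width before cross-attention
# (1536 here, matching the llama3_1B config). Assumes examples/llama3 is on PYTHONPATH;
# the batch and sequence sizes are demo assumptions.
import numpy as np
import mindspore as ms
from llama.models.llama.layer import CaptionEmbedder

embedder = CaptionEmbedder(in_channels=4096, hidden_size=1536)
caption = ms.Tensor(np.random.randn(1, 64, 4096), dtype=ms.float32)
print(embedder(caption).shape)  # (1, 64, 1536)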
ms.Tensor([35], dtype=ms.int64) - text_embedding = ms.Tensor(np.ones((1, 64, network.hidden_size)), dtype=ms.bfloat16) + text_embedding = ms.Tensor(np.ones((1, 64, 4096)), dtype=ms.bfloat16) outputs = network(latent_embedding, timestep, text_embedding) print(outputs.shape) From 9249fc39cc11bf7e26e3c6410fdc271757258a3d Mon Sep 17 00:00:00 2001 From: Mike Cheung Date: Wed, 16 Oct 2024 16:18:25 +0800 Subject: [PATCH 05/14] use fp32 layernorm --- examples/llama3/llama/models/llama/layer.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/examples/llama3/llama/models/llama/layer.py b/examples/llama3/llama/models/llama/layer.py index 3c73dfc3b3..76458ff676 100644 --- a/examples/llama3/llama/models/llama/layer.py +++ b/examples/llama3/llama/models/llama/layer.py @@ -27,6 +27,16 @@ def construct(self, hidden_states: Tensor) -> Tensor: return self.weight * hidden_states.to(input_dtype) +class LlamaLayerNorm(nn.LayerNorm): + def construct(self, hidden_states: Tensor) -> Tensor: + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(ms.float32) + hidden_states, _, _ = self.layer_norm( + hidden_states, self.gamma.to(hidden_states.dtype), self.beta.to(hidden_states.dtype) + ) + return hidden_states.to(input_dtype) + + class LlamaMLP(nn.Cell): def __init__( self, @@ -259,7 +269,7 @@ def __init__( super().__init__() self.proj = nn.SequentialCell( nn.Dense(in_channels, hidden_size, has_bias=False, dtype=dtype), - nn.LayerNorm((hidden_size,), dtype=dtype), + LlamaLayerNorm((hidden_size,), dtype=dtype), ) def construct(self, caption: Tensor) -> Tensor: From 61d9bf575c0aa5e56fa05ca3b7ef7709e19d7387 Mon Sep 17 00:00:00 2001 From: Mike Cheung Date: Thu, 17 Oct 2024 10:51:10 +0800 Subject: [PATCH 06/14] add opensora wrapper --- examples/llama3/llama/models/llama/network.py | 20 ++++---- .../opensora/models/stdit/__init__.py | 1 + .../opensora/models/stdit/stdit_llama3.py | 50 +++++++++++++++++++ examples/opensora_hpcai/scripts/args_train.py | 6 ++- examples/opensora_hpcai/scripts/train.py | 5 +- mindone/trainers/zero.py | 14 +++++- 6 files changed, 84 insertions(+), 12 deletions(-) create mode 100644 examples/opensora_hpcai/opensora/models/stdit/stdit_llama3.py diff --git a/examples/llama3/llama/models/llama/network.py b/examples/llama3/llama/models/llama/network.py index f6eb9ded8b..69739e6f60 100644 --- a/examples/llama3/llama/models/llama/network.py +++ b/examples/llama3/llama/models/llama/network.py @@ -1,4 +1,4 @@ -from typing import Literal, Tuple +from typing import Literal, Optional, Tuple import mindspore as ms import mindspore.nn as nn @@ -45,7 +45,6 @@ def __init__( dtype: ms.Type = ms.float32, ) -> None: super().__init__() - self.hidden_size = hidden_size self.self_attn = Llama_ATTENTION_CLASSES[attn_implementation]( hidden_size=hidden_size, @@ -142,6 +141,7 @@ class LlamaModel(nn.Cell): def __init__( self, in_channels: int = 8, + out_channels: Optional[int] = None, hidden_size: int = 4096, intermediate_size: int = 14336, num_attention_heads: int = 32, @@ -162,16 +162,18 @@ def __init__( super().__init__() self.patch_size = patch_size self.in_channels = in_channels - self.out_channels = in_channels + self.out_channels = in_channels if out_channels is None else out_channels self.hidden_size = hidden_size + self.num_attention_heads = num_attention_heads + self.num_key_value_heads = num_key_value_heads self.layers = nn.CellList( [ LlamaDecoderLayer( hidden_size=self.hidden_size, intermediate_size=intermediate_size, - num_attention_heads=num_attention_heads, 
- num_key_value_heads=num_key_value_heads, + num_attention_heads=self.num_attention_heads, + num_key_value_heads=self.num_key_value_heads, rms_norm_eps=rms_norm_eps, attention_dropout=attention_dropout, attention_bias=attention_bias, @@ -184,7 +186,7 @@ def __init__( ) self.final_layer = LlamaFinalLayer( hidden_size=self.hidden_size, - patch_size=patch_size, + patch_size=self.patch_size, out_channels=self.out_channels, rms_norm_eps=rms_norm_eps, dtype=dtype, @@ -194,12 +196,12 @@ def __init__( self.pos_embedding_table_w = nn.Embedding(max_length[1], self.hidden_size, dtype=dtype) self.pos_embedding_table_t = nn.Embedding(max_length[2], self.hidden_size, dtype=dtype) - self.latent_embedder = PatchEmbed3D(patch_size, self.in_channels, self.hidden_size, dtype=dtype) + self.latent_embedder = PatchEmbed3D(self.patch_size, self.in_channels, self.hidden_size, dtype=dtype) self.timestep_embedder = TimestepEmbedder(self.hidden_size, dtype=dtype) self.adaLN_modulation = nn.SequentialCell( - ACT2FN[hidden_act], nn.Dense(hidden_size, 6 * hidden_size, has_bias=False, dtype=dtype) + ACT2FN[hidden_act], nn.Dense(self.hidden_size, 6 * self.hidden_size, has_bias=False, dtype=dtype) ) - self.caption_embedder = CaptionEmbedder(caption_channels, hidden_size, dtype=dtype) + self.caption_embedder = CaptionEmbedder(caption_channels, self.hidden_size, dtype=dtype) # post-init self.initializer_range = initializer_range diff --git a/examples/opensora_hpcai/opensora/models/stdit/__init__.py b/examples/opensora_hpcai/opensora/models/stdit/__init__.py index 7957e9000f..dc2c63cb06 100644 --- a/examples/opensora_hpcai/opensora/models/stdit/__init__.py +++ b/examples/opensora_hpcai/opensora/models/stdit/__init__.py @@ -1,3 +1,4 @@ from .stdit import STDiT_XL_2 from .stdit2 import STDiT2_XL_2 from .stdit3 import STDiT3_3B_2, STDiT3_XL_2 +from .stdit_llama3 import STDiTLlama3Wrapper diff --git a/examples/opensora_hpcai/opensora/models/stdit/stdit_llama3.py b/examples/opensora_hpcai/opensora/models/stdit/stdit_llama3.py new file mode 100644 index 0000000000..8e4a264dab --- /dev/null +++ b/examples/opensora_hpcai/opensora/models/stdit/stdit_llama3.py @@ -0,0 +1,50 @@ +from typing import Literal, Optional + +from llama3.llama import llama3_1B, llama3_5B, llama3_30B + +import mindspore.nn as nn +from mindspore import Tensor + + +class STDiTLlama3Wrapper(nn.Cell): + def __init__(self, model_size: Literal["1B", "5B", "30B"] = "1B", **kwargs): + super().__init__(auto_prefix=False) + + attn_implementation = "flash_attention" if kwargs.get("enable_flashattn", False) else "eager" + gradient_checkpointing = kwargs.get("use_recompute", False) + + model_kwargs = dict( + in_channels=4, + out_channels=8, + attn_implementatio=attn_implementation, + gradient_checkpointing=gradient_checkpointing, + ) + + if model_size == "1B": + self.llama = llama3_1B(**model_kwargs) + elif model_size == "5B": + self.llama = llama3_5B(**model_kwargs) + else: + self.llama = llama3_30B(**model_kwargs) + + self.patch_size = self.llama.patch_size + self.hidden_size = self.llama.hidden_size + self.num_heads = self.llama.num_attention_heads + self.input_sq_size = None + self.in_channels = self.llama.in_channels + + def construct( + self, + x: Tensor, + timestep: Tensor, + y: Tensor, + mask: Optional[Tensor] = None, + frames_mask: Optional[Tensor] = None, + fps: Optional[Tensor] = None, + height: Optional[Tensor] = None, + width: Optional[Tensor] = None, + **kwargs, + ): + latent_embedding = x + text_embedding = y + return self.llama(latent_embedding, timestep, 
text_embedding) diff --git a/examples/opensora_hpcai/scripts/args_train.py b/examples/opensora_hpcai/scripts/args_train.py index c140ca36bb..5d2e0a1aa1 100644 --- a/examples/opensora_hpcai/scripts/args_train.py +++ b/examples/opensora_hpcai/scripts/args_train.py @@ -49,7 +49,11 @@ def parse_train_args(parser): ) # model parser.add_argument( - "--model_version", default="v1", type=str, choices=["v1", "v1.1"], help="OpenSora model version." + "--model_version", + default="v1", + type=str, + choices=["v1", "v1.1", "v1.2", "llama3_1b"], + help="OpenSora model version.", ) parser.add_argument( "--pretrained_model_path", diff --git a/examples/opensora_hpcai/scripts/train.py b/examples/opensora_hpcai/scripts/train.py index 1b4146c666..9bae6f39db 100644 --- a/examples/opensora_hpcai/scripts/train.py +++ b/examples/opensora_hpcai/scripts/train.py @@ -26,7 +26,7 @@ from opensora.acceleration.parallel_states import create_parallel_group from opensora.datasets.aspect import ASPECT_RATIOS, get_image_size from opensora.models.layers.operation_selector import set_dynamic_mode -from opensora.models.stdit import STDiT2_XL_2, STDiT3_XL_2, STDiT_XL_2 +from opensora.models.stdit import STDiT2_XL_2, STDiT3_XL_2, STDiT_XL_2, STDiTLlama3Wrapper from opensora.models.vae.vae import SD_CONFIG, OpenSoraVAE_V1_2, VideoAutoencoderKL from opensora.pipelines import ( DiffusionWithLoss, @@ -455,6 +455,9 @@ def main(args): model_extra_args["qk_norm"] = True model_extra_args["freeze_y_embedder"] = args.freeze_y_embedder latte_model = STDiT3_XL_2(**model_extra_args) + elif args.model_version == "llama3_1b": + model_name = "Llama3-1B" + latte_model = STDiTLlama3Wrapper(model_size="1B") else: raise ValueError(f"Unknown model version: {args.model_version}") logger.info(f"{model_name} input size: {latent_size if args.bucket_config is None else 'Variable'}") diff --git a/mindone/trainers/zero.py b/mindone/trainers/zero.py index 478904a149..8d5e1c5353 100644 --- a/mindone/trainers/zero.py +++ b/mindone/trainers/zero.py @@ -587,7 +587,19 @@ def prepare_train_network( is_parallel = _get_parallel_mode() == ParallelMode.DATA_PARALLEL if not is_parallel and zero_stage == 0: _logger.info("No need prepare train_network with zero.") - return network, optimizer + train_network = TrainOneStepWrapper( + network, + optimizer, + scale_sense=scale_sense, + ema=ema, + updates=updates, + drop_overflow_update=drop_overflow_update, + gradient_accumulation_steps=gradient_accumulation_steps, + clip_grad=clip_grad, + clip_norm=clip_norm, + verbose=verbose, + ) + return train_network if zero_stage not in [0, 1, 2, 3]: raise ValueError("Not support zero_stage {zero_stage}") From bd7fe307c19dde1d0bb5f1a70b58306b1ee794f3 Mon Sep 17 00:00:00 2001 From: Mike Cheung Date: Thu, 17 Oct 2024 12:52:14 +0800 Subject: [PATCH 07/14] fix embedding & fix wrapper --- examples/llama3/llama/models/llama/network.py | 14 +++++++------- .../opensora/models/stdit/stdit_llama3.py | 11 ++++++++--- examples/opensora_hpcai/scripts/inference.py | 5 ++++- examples/opensora_hpcai/scripts/train.py | 2 +- 4 files changed, 20 insertions(+), 12 deletions(-) diff --git a/examples/llama3/llama/models/llama/network.py b/examples/llama3/llama/models/llama/network.py index 69739e6f60..2537f6e630 100644 --- a/examples/llama3/llama/models/llama/network.py +++ b/examples/llama3/llama/models/llama/network.py @@ -153,7 +153,7 @@ def __init__( hidden_act: str = "silu", initializer_range: float = 0.02, patch_size: Tuple[int, int, int] = (1, 2, 2), - max_length: Tuple[int, int, int] = (16, 24, 
44), + max_length: Tuple[int, int, int] = (32, 16, 16), caption_channels: int = 4096, attn_implementation: Literal["eager", "flash_attention"] = "eager", gradient_checkpointing: bool = False, @@ -192,9 +192,9 @@ def __init__( dtype=dtype, ) - self.pos_embedding_table_h = nn.Embedding(max_length[0], self.hidden_size, dtype=dtype) - self.pos_embedding_table_w = nn.Embedding(max_length[1], self.hidden_size, dtype=dtype) - self.pos_embedding_table_t = nn.Embedding(max_length[2], self.hidden_size, dtype=dtype) + self.pos_embedding_table_t = nn.Embedding(max_length[0], self.hidden_size, dtype=dtype) + self.pos_embedding_table_h = nn.Embedding(max_length[1], self.hidden_size, dtype=dtype) + self.pos_embedding_table_w = nn.Embedding(max_length[2], self.hidden_size, dtype=dtype) self.latent_embedder = PatchEmbed3D(self.patch_size, self.in_channels, self.hidden_size, dtype=dtype) self.timestep_embedder = TimestepEmbedder(self.hidden_size, dtype=dtype) @@ -250,11 +250,11 @@ def learnable_position_embedding(self, latent_embedding: Tensor) -> Tensor: position_ids = ops.stack(position_ids, axis=-1) position_ids = ops.reshape(position_ids, (-1, 3)) - h_inds, w_inds, t_inds = ops.unbind(position_ids, dim=-1) + t_inds, h_inds, w_inds = ops.unbind(position_ids, dim=-1) + pos_embed_t = self.pos_embedding_table_t(t_inds) pos_embed_h = self.pos_embedding_table_h(h_inds) pos_embed_w = self.pos_embedding_table_w(w_inds) - pos_embed_t = self.pos_embedding_table_t(t_inds) - return pos_embed_h + pos_embed_w + pos_embed_t + return pos_embed_t + pos_embed_h + pos_embed_w def unpatchify(self, hidden_states: Tensor, t: int, h: int, w: int) -> Tensor: """ diff --git a/examples/opensora_hpcai/opensora/models/stdit/stdit_llama3.py b/examples/opensora_hpcai/opensora/models/stdit/stdit_llama3.py index 8e4a264dab..2a2a4826bf 100644 --- a/examples/opensora_hpcai/opensora/models/stdit/stdit_llama3.py +++ b/examples/opensora_hpcai/opensora/models/stdit/stdit_llama3.py @@ -3,6 +3,7 @@ from llama3.llama import llama3_1B, llama3_5B, llama3_30B import mindspore.nn as nn +import mindspore.ops as ops from mindspore import Tensor @@ -16,7 +17,7 @@ def __init__(self, model_size: Literal["1B", "5B", "30B"] = "1B", **kwargs): model_kwargs = dict( in_channels=4, out_channels=8, - attn_implementatio=attn_implementation, + attn_implementation=attn_implementation, gradient_checkpointing=gradient_checkpointing, ) @@ -44,7 +45,11 @@ def construct( height: Optional[Tensor] = None, width: Optional[Tensor] = None, **kwargs, - ): + ) -> Tensor: + x = ops.transpose(x, (0, 2, 1, 3, 4)) + y = ops.squeeze(y, axis=1) latent_embedding = x text_embedding = y - return self.llama(latent_embedding, timestep, text_embedding) + output = self.llama(latent_embedding, timestep, text_embedding) + output = ops.transpose(output, (0, 2, 1, 3, 4)) + return output diff --git a/examples/opensora_hpcai/scripts/inference.py b/examples/opensora_hpcai/scripts/inference.py index 3e0defe0a1..2f7b02fc4d 100644 --- a/examples/opensora_hpcai/scripts/inference.py +++ b/examples/opensora_hpcai/scripts/inference.py @@ -20,7 +20,7 @@ from opensora.acceleration.parallel_states import set_sequence_parallel_group from opensora.datasets.aspect import ASPECT_RATIO_MAP, ASPECT_RATIOS, get_image_size, get_num_frames -from opensora.models.stdit import STDiT2_XL_2, STDiT3_XL_2, STDiT_XL_2 +from opensora.models.stdit import STDiT2_XL_2, STDiT3_XL_2, STDiT_XL_2, STDiTLlama3Wrapper from opensora.models.text_encoder.t5 import get_text_encoder_and_tokenizer from opensora.models.vae.vae import 
SD_CONFIG, OpenSoraVAE_V1_2, VideoAutoencoderKL from opensora.pipelines import InferPipeline, InferPipelineFiTLike @@ -253,6 +253,9 @@ def main(args): model_extra_args["qk_norm"] = True logger.info(f"{model_name} init") latte_model = STDiT3_XL_2(**model_extra_args) + elif args.model_version == "llama3_1b": + model_name = "Llama3-1B" + latte_model = STDiTLlama3Wrapper(model_size="1B", **model_extra_args) else: raise ValueError(f"Unknown model version: {args.model_version}") diff --git a/examples/opensora_hpcai/scripts/train.py b/examples/opensora_hpcai/scripts/train.py index 9bae6f39db..eea9affed9 100644 --- a/examples/opensora_hpcai/scripts/train.py +++ b/examples/opensora_hpcai/scripts/train.py @@ -457,7 +457,7 @@ def main(args): latte_model = STDiT3_XL_2(**model_extra_args) elif args.model_version == "llama3_1b": model_name = "Llama3-1B" - latte_model = STDiTLlama3Wrapper(model_size="1B") + latte_model = STDiTLlama3Wrapper(model_size="1B", **model_extra_args) else: raise ValueError(f"Unknown model version: {args.model_version}") logger.info(f"{model_name} input size: {latent_size if args.bucket_config is None else 'Variable'}") From df136b53c2810d5ea28a8770bc45e78ce3e22ab1 Mon Sep 17 00:00:00 2001 From: Mike Cheung Date: Thu, 17 Oct 2024 14:52:17 +0800 Subject: [PATCH 08/14] fix inference with cfg=1.0 --- .../opensora/models/stdit/stdit_llama3.py | 14 +++++++++++++- .../opensora/pipelines/infer_pipeline.py | 2 +- .../opensora/schedulers/rectified_flow.py | 2 ++ examples/opensora_hpcai/scripts/inference.py | 6 +++++- 4 files changed, 21 insertions(+), 3 deletions(-) diff --git a/examples/opensora_hpcai/opensora/models/stdit/stdit_llama3.py b/examples/opensora_hpcai/opensora/models/stdit/stdit_llama3.py index 2a2a4826bf..3c1eb85211 100644 --- a/examples/opensora_hpcai/opensora/models/stdit/stdit_llama3.py +++ b/examples/opensora_hpcai/opensora/models/stdit/stdit_llama3.py @@ -1,10 +1,11 @@ +import os from typing import Literal, Optional from llama3.llama import llama3_1B, llama3_5B, llama3_30B import mindspore.nn as nn import mindspore.ops as ops -from mindspore import Tensor +from mindspore import Tensor, load_checkpoint, load_param_into_net class STDiTLlama3Wrapper(nn.Cell): @@ -53,3 +54,14 @@ def construct( output = self.llama(latent_embedding, timestep, text_embedding) output = ops.transpose(output, (0, 2, 1, 3, 4)) return output + + def load_from_checkpoint(self, ckpt_path): + if not os.path.exists(ckpt_path): + print(f"WARNING: {ckpt_path} not found. No checkpoint loaded!!") + else: + sd = load_checkpoint(ckpt_path) + sd = {k.replace("network.llama.", "").replace("_backbone.", ""): v for k, v in sd.items()} + + m, u = load_param_into_net(self, sd, strict_load=True) + print("net param not load: ", m, len(m)) + print("ckpt param not load: ", u, len(u)) diff --git a/examples/opensora_hpcai/opensora/pipelines/infer_pipeline.py b/examples/opensora_hpcai/opensora/pipelines/infer_pipeline.py index 09d08dca4d..c123e9174d 100644 --- a/examples/opensora_hpcai/opensora/pipelines/infer_pipeline.py +++ b/examples/opensora_hpcai/opensora/pipelines/infer_pipeline.py @@ -128,9 +128,9 @@ def data_prepare(self, inputs): # for token/text drop in caption embedder for condition-free guidance training. The null mask is the same as text mask. 
n = x.shape[0] # (n_tokens, dim_emb) -> (b n_tokens dim_emb) - null_emb = self.model.y_embedder.y_embedding[None, :, :].repeat(n, axis=0) if self.use_cfg: + null_emb = self.model.y_embedder.y_embedding[None, :, :].repeat(n, axis=0) y = ops.cat([text_emb, null_emb], axis=0) x_in = ops.concat([x] * 2, axis=0) assert y.shape[0] == x_in.shape[0], "shape mismatch!" diff --git a/examples/opensora_hpcai/opensora/schedulers/rectified_flow.py b/examples/opensora_hpcai/opensora/schedulers/rectified_flow.py index 8a481cde8c..25522c0b4f 100644 --- a/examples/opensora_hpcai/opensora/schedulers/rectified_flow.py +++ b/examples/opensora_hpcai/opensora/schedulers/rectified_flow.py @@ -87,6 +87,8 @@ def __call__( noise_added = mask_t_upper pred = model(z, t, **model_kwargs) + # FIXME: a tmp solution for inference with cfg==1.0 + pred = pred[:, :4] # update z dt = timesteps[i] - timesteps[i + 1] if i < len(timesteps) - 1 else timesteps[i] diff --git a/examples/opensora_hpcai/scripts/inference.py b/examples/opensora_hpcai/scripts/inference.py index 2f7b02fc4d..539ae0a641 100644 --- a/examples/opensora_hpcai/scripts/inference.py +++ b/examples/opensora_hpcai/scripts/inference.py @@ -523,7 +523,11 @@ def parse_args(): help="path to load a config yaml file that describes the setting which will override the default arguments", ) parser.add_argument( - "--model_version", default="v1", type=str, choices=["v1", "v1.1", "v1.2"], help="OpenSora model version." + "--model_version", + default="v1", + type=str, + choices=["v1", "v1.1", "v1.2", "llama3_1b"], + help="OpenSora model version.", ) parser.add_argument("--image_size", type=int, nargs="+", help="image size in [256, 512]") parser.add_argument("--resolution", type=str, help=f"Supported video resolutions: {list(ASPECT_RATIOS.keys())}") From 239f8222c9d5231900f2877636278b9f9456bdde Mon Sep 17 00:00:00 2001 From: Mike Cheung Date: Fri, 18 Oct 2024 13:22:00 +0800 Subject: [PATCH 09/14] change to mint --- examples/llama3/llama/models/activation.py | 4 +- examples/llama3/llama/models/llama/layer.py | 86 ++++++++++--------- examples/llama3/llama/models/llama/network.py | 25 +++--- 3 files changed, 61 insertions(+), 54 deletions(-) diff --git a/examples/llama3/llama/models/activation.py b/examples/llama3/llama/models/activation.py index 22a4a66112..7b54d885a1 100644 --- a/examples/llama3/llama/models/activation.py +++ b/examples/llama3/llama/models/activation.py @@ -1,8 +1,8 @@ import logging from collections import OrderedDict +import mindspore.mint as mint import mindspore.nn as nn -import mindspore.ops as ops from mindspore import Tensor logger = logging.getLogger(__name__) @@ -10,7 +10,7 @@ class QuickGELU(nn.Cell): def construct(self, x: Tensor): - return x * ops.sigmoid(1.702 * x) + return x * mint.sigmoid(1.702 * x) class ClassInstantier(OrderedDict): diff --git a/examples/llama3/llama/models/llama/layer.py b/examples/llama3/llama/models/llama/layer.py index 76458ff676..d5fec37d55 100644 --- a/examples/llama3/llama/models/llama/layer.py +++ b/examples/llama3/llama/models/llama/layer.py @@ -2,6 +2,7 @@ from typing import Optional, Tuple import mindspore as ms +import mindspore.mint as mint import mindspore.nn as nn import mindspore.ops as ops from mindspore import Parameter, Tensor @@ -15,15 +16,15 @@ class LlamaRMSNorm(nn.Cell): def __init__(self, hidden_size: int, eps: float = 1e-6, dtype: ms.Type = ms.float32) -> None: super().__init__() - self.weight = Parameter(ops.ones(hidden_size, dtype=dtype)) + self.weight = Parameter(mint.ones(hidden_size, dtype=dtype)) 
self.variance_epsilon = eps def construct(self, hidden_states: Tensor) -> Tensor: input_dtype = hidden_states.dtype hidden_states = hidden_states.to(ms.float32) - variance = ops.pow(hidden_states, 2) - variance = ops.mean(variance, axis=-1, keep_dims=True) - hidden_states = hidden_states * ops.rsqrt(variance + self.variance_epsilon) + variance = mint.pow(hidden_states, 2) + variance = mint.mean(variance, dim=-1, keepdim=True) + hidden_states = hidden_states * mint.rsqrt(variance + self.variance_epsilon) return self.weight * hidden_states.to(input_dtype) @@ -31,8 +32,12 @@ class LlamaLayerNorm(nn.LayerNorm): def construct(self, hidden_states: Tensor) -> Tensor: input_dtype = hidden_states.dtype hidden_states = hidden_states.to(ms.float32) - hidden_states, _, _ = self.layer_norm( - hidden_states, self.gamma.to(hidden_states.dtype), self.beta.to(hidden_states.dtype) + hidden_states = mint.layer_norm( + hidden_states, + self.normalized_shape, + self.gamma.to(hidden_states.dtype), + self.beta.to(hidden_states.dtype), + eps=self.epsilon, ) return hidden_states.to(input_dtype) @@ -48,9 +53,9 @@ def __init__( super().__init__() self.hidden_size = hidden_size self.intermediate_size = intermediate_size - self.gate_proj = nn.Dense(self.hidden_size, self.intermediate_size, has_bias=False, dtype=dtype) - self.up_proj = nn.Dense(self.hidden_size, self.intermediate_size, has_bias=False, dtype=dtype) - self.down_proj = nn.Dense(self.intermediate_size, self.hidden_size, has_bias=False, dtype=dtype) + self.gate_proj = mint.nn.Linear(self.hidden_size, self.intermediate_size, bias=False, dtype=dtype) + self.up_proj = mint.nn.Linear(self.hidden_size, self.intermediate_size, bias=False, dtype=dtype) + self.down_proj = mint.nn.Linear(self.intermediate_size, self.hidden_size, bias=False, dtype=dtype) self.act_fn = ACT2FN[hidden_act] def construct(self, hidden_state: Tensor) -> Tensor: @@ -62,7 +67,7 @@ def repeat_kv(hidden_states: Tensor, n_rep: int) -> Tensor: return hidden_states batch, num_key_value_heads, slen, head_dim = hidden_states.shape hidden_states = hidden_states[:, :, None, :, :] - hidden_states = ops.broadcast_to(hidden_states, (batch, num_key_value_heads, n_rep, slen, head_dim)) + hidden_states = mint.broadcast_to(hidden_states, (batch, num_key_value_heads, n_rep, slen, head_dim)) hidden_states = ops.reshape(hidden_states, (batch, num_key_value_heads * n_rep, slen, head_dim)) return hidden_states @@ -91,14 +96,14 @@ def __init__( f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" f" and `num_heads`: {self.num_heads})." 
) - self.q_proj = nn.Dense(self.hidden_size, self.num_heads * self.head_dim, has_bias=attention_bias, dtype=dtype) - self.k_proj = nn.Dense( - self.hidden_size, self.num_key_value_heads * self.head_dim, has_bias=attention_bias, dtype=dtype + self.q_proj = mint.nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=attention_bias, dtype=dtype) + self.k_proj = mint.nn.Linear( + self.hidden_size, self.num_key_value_heads * self.head_dim, bias=attention_bias, dtype=dtype ) - self.v_proj = nn.Dense( - self.hidden_size, self.num_key_value_heads * self.head_dim, has_bias=attention_bias, dtype=dtype + self.v_proj = mint.nn.Linear( + self.hidden_size, self.num_key_value_heads * self.head_dim, bias=attention_bias, dtype=dtype ) - self.o_proj = nn.Dense(self.num_heads * self.head_dim, self.hidden_size, has_bias=attention_bias, dtype=dtype) + self.o_proj = mint.nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=attention_bias, dtype=dtype) def construct(self, hidden_states: Tensor, encoder_hidden_states: Optional[Tensor] = None) -> Tensor: bsz, q_len, _ = hidden_states.shape @@ -111,27 +116,27 @@ def construct(self, hidden_states: Tensor, encoder_hidden_states: Optional[Tenso value_states = self.v_proj(kv_hidden_states) query_states = ops.reshape(query_states, (bsz, q_len, self.num_heads, self.head_dim)) - query_states = ops.transpose(query_states, (0, 2, 1, 3)) + query_states = mint.permute(query_states, (0, 2, 1, 3)) key_states = ops.reshape(key_states, (bsz, kv_len, self.num_key_value_heads, self.head_dim)) - key_states = ops.transpose(key_states, (0, 2, 1, 3)) + key_states = mint.permute(key_states, (0, 2, 1, 3)) value_states = ops.reshape(value_states, (bsz, kv_len, self.num_key_value_heads, self.head_dim)) - value_states = ops.transpose(value_states, (0, 2, 1, 3)) + value_states = mint.permute(value_states, (0, 2, 1, 3)) key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) - key_states = ops.transpose(key_states, (0, 1, 3, 2)) - attn_weights = ops.matmul(query_states, key_states) / ms.numpy.sqrt(self.head_dim) + key_states = mint.permute(key_states, (0, 1, 3, 2)) + attn_weights = mint.matmul(query_states, key_states) / mint.sqrt(Tensor(self.head_dim)) # upcast attention to fp32 attn_weights = attn_weights.to(ms.float32) - attn_weights = ops.softmax(attn_weights, axis=-1).to(query_states.dtype) - attn_weights = ops.dropout(attn_weights, p=self.attention_dropout, training=self.training) - attn_output = ops.matmul(attn_weights, value_states) + attn_weights = mint.softmax(attn_weights, dim=-1).to(query_states.dtype) + attn_weights = mint.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = mint.matmul(attn_weights, value_states) - attn_output = ops.transpose(attn_output, (0, 2, 1, 3)) + attn_output = mint.permute(attn_output, (0, 2, 1, 3)) attn_output = ops.reshape(attn_output, (bsz, q_len, -1)) attn_output = self.o_proj(attn_output) @@ -171,21 +176,21 @@ def construct(self, hidden_states: Tensor, encoder_hidden_states: Optional[Tenso value_states = self.v_proj(kv_hidden_states) query_states = ops.reshape(query_states, (bsz, q_len, self.num_heads, self.head_dim)) - query_states = ops.transpose(query_states, (0, 2, 1, 3)) + query_states = mint.permute(query_states, (0, 2, 1, 3)) key_states = ops.reshape(key_states, (bsz, kv_len, self.num_key_value_heads, self.head_dim)) - key_states = ops.transpose(key_states, (0, 2, 1, 3)) + key_states = mint.permute(key_states, (0, 
2, 1, 3)) value_states = ops.reshape(value_states, (bsz, kv_len, self.num_key_value_heads, self.head_dim)) - value_states = ops.transpose(value_states, (0, 2, 1, 3)) + value_states = mint.permute(value_states, (0, 2, 1, 3)) key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) # Reshape to the expected shape and dtype for Flash Attention - query_states = ops.transpose(query_states, (0, 2, 1, 3)) - key_states = ops.transpose(key_states, (0, 2, 1, 3)) - value_states = ops.transpose(value_states, (0, 2, 1, 3)) + query_states = mint.permute(query_states, (0, 2, 1, 3)) + key_states = mint.permute(key_states, (0, 2, 1, 3)) + value_states = mint.permute(value_states, (0, 2, 1, 3)) _, _, _, attn_output = self.flash_attention(query_states, key_states, value_states, None, None, None, None) attn_output = ops.reshape(attn_output, (bsz, q_len, -1)) @@ -220,9 +225,10 @@ def construct(self, x: Tensor) -> Tensor: assert h % self.patch_size[1] == 0 assert w % self.patch_size[2] == 0 - x = ops.transpose(x, (0, 2, 1, 3, 4)) + x = mint.permute(x, (0, 2, 1, 3, 4)) x = self.proj(x) # (B C T H W) - x = x.flatten(start_dim=2).swapaxes(1, 2) + x = mint.flatten(x, start_dim=2) + x = mint.permute(x, (0, 2, 1)) return x @@ -236,9 +242,9 @@ def __init__( ) -> None: super().__init__() self.mlp = nn.SequentialCell( - nn.Dense(frequency_embedding_size, hidden_size, has_bias=False, dtype=dtype), + mint.nn.Linear(frequency_embedding_size, hidden_size, bias=False, dtype=dtype), ACT2FN[hidden_act], - nn.Dense(hidden_size, hidden_size, has_bias=False, dtype=dtype), + mint.nn.Linear(hidden_size, hidden_size, bias=False, dtype=dtype), ) self.frequency_embedding_size = frequency_embedding_size self.dtype = dtype @@ -246,11 +252,11 @@ def __init__( @staticmethod def timestep_embedding(t: Tensor, dim: int, max_period: int = 10000) -> Tensor: half = dim // 2 - freqs = ops.exp(-ms.numpy.log(max_period) * ops.arange(start=0, end=half, dtype=ms.float32) / half) - args = t[:, None].to(ms.float32) * freqs[None] - embedding = ops.concat([ops.cos(args), ops.sin(args)], axis=-1) + freqs = mint.exp(-mint.log(Tensor(max_period)) * mint.arange(start=0, end=half, dtype=ms.float32) / half) + args = ops.unsqueeze(t, 1).to(ms.float32) * ops.unsqueeze(freqs, 0) + embedding = mint.cat([mint.cos(args), mint.sin(args)], dim=-1) if dim % 2: - embedding = ops.concat([embedding, ops.zeros_like(embedding[:, :1])], axis=-1) + embedding = mint.cat([embedding, mint.zeros_like(embedding[:, :1])], dim=-1) return embedding def construct(self, t: Tensor) -> Tensor: @@ -268,7 +274,7 @@ def __init__( ) -> None: super().__init__() self.proj = nn.SequentialCell( - nn.Dense(in_channels, hidden_size, has_bias=False, dtype=dtype), + mint.nn.Linear(in_channels, hidden_size, bias=False, dtype=dtype), LlamaLayerNorm((hidden_size,), dtype=dtype), ) diff --git a/examples/llama3/llama/models/llama/network.py b/examples/llama3/llama/models/llama/network.py index 2537f6e630..f1e38189e0 100644 --- a/examples/llama3/llama/models/llama/network.py +++ b/examples/llama3/llama/models/llama/network.py @@ -1,6 +1,7 @@ from typing import Literal, Optional, Tuple import mindspore as ms +import mindspore.mint as mint import mindspore.nn as nn import mindspore.ops as ops from mindspore import Parameter, Tensor, load_checkpoint @@ -68,7 +69,7 @@ def __init__( intermediate_size=intermediate_size, hidden_size=hidden_size, hidden_act=hidden_act, dtype=dtype ) - self.scale_shift_table = Parameter(ops.randn(6, 
hidden_size, dtype=dtype) / hidden_size**0.5) + self.scale_shift_table = Parameter(mint.normal(size=(6, hidden_size)).to(dtype) / hidden_size**0.5) self.input_layernorm = LlamaRMSNorm(hidden_size, eps=rms_norm_eps, dtype=dtype) self.post_attention_layernorm = LlamaRMSNorm(hidden_size, eps=rms_norm_eps, dtype=dtype) @@ -86,8 +87,8 @@ def construct( hidden_states = hidden_states + position_embedding.to(hidden_states.dtype) # 3.1.3 Adaptive Layer Norm - shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ops.chunk( - self.scale_shift_table[None] + modulation_parameters.reshape(B, 6, -1), 6, axis=1 + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mint.chunk( + self.scale_shift_table[None] + modulation_parameters.reshape(B, 6, -1), 6, dim=1 ) # Self Attention (Bi-Directional Attention) @@ -128,10 +129,10 @@ def __init__( self.proj = nn.Dense( hidden_size, patch_size[0] * patch_size[1] * patch_size[2] * out_channels, has_bias=False, dtype=dtype ) - self.scale_shift_table = Parameter(ops.randn(2, hidden_size, dtype=dtype) / hidden_size**0.5) + self.scale_shift_table = Parameter(mint.normal(size=(2, hidden_size)).to(dtype) / hidden_size**0.5) def construct(self, hidden_states: Tensor, timestep_embedding: Tensor): - shift, scale = ops.chunk(self.scale_shift_table[None] + timestep_embedding[:, None], 2, axis=1) + shift, scale = mint.chunk(self.scale_shift_table[None] + timestep_embedding[:, None], 2, dim=1) hidden_states = t2i_modulate(self.input_layernorm(hidden_states), shift, scale) hidden_states = self.proj(hidden_states) return hidden_states @@ -199,7 +200,7 @@ def __init__( self.latent_embedder = PatchEmbed3D(self.patch_size, self.in_channels, self.hidden_size, dtype=dtype) self.timestep_embedder = TimestepEmbedder(self.hidden_size, dtype=dtype) self.adaLN_modulation = nn.SequentialCell( - ACT2FN[hidden_act], nn.Dense(self.hidden_size, 6 * self.hidden_size, has_bias=False, dtype=dtype) + ACT2FN[hidden_act], mint.nn.Linear(self.hidden_size, 6 * self.hidden_size, bias=False, dtype=dtype) ) self.caption_embedder = CaptionEmbedder(caption_channels, self.hidden_size, dtype=dtype) @@ -215,7 +216,7 @@ def init_weights(self): std = self.initializer_range def _init_weights(module): - if isinstance(module, nn.Dense): + if isinstance(module, mint.nn.Linear): normal_(module.weight, mean=0.0, std=std) if module.bias is not None: zeros_(module.weight) @@ -242,12 +243,12 @@ def _init_weights(module): def learnable_position_embedding(self, latent_embedding: Tensor) -> Tensor: # 3.1.3 _, t, _, h, w = latent_embedding.shape - t_inds = ops.arange(t // self.patch_size[0], dtype=ms.int64) - h_inds = ops.arange(h // self.patch_size[1], dtype=ms.int64) - w_inds = ops.arange(w // self.patch_size[2], dtype=ms.int64) + t_inds = mint.arange(t // self.patch_size[0], dtype=ms.int64) + h_inds = mint.arange(h // self.patch_size[1], dtype=ms.int64) + w_inds = mint.arange(w // self.patch_size[2], dtype=ms.int64) position_ids = ops.meshgrid(t_inds, h_inds, w_inds, indexing="ij") - position_ids = ops.stack(position_ids, axis=-1) + position_ids = mint.stack(position_ids, dim=-1) position_ids = ops.reshape(position_ids, (-1, 3)) t_inds, h_inds, w_inds = ops.unbind(position_ids, dim=-1) @@ -267,7 +268,7 @@ def unpatchify(self, hidden_states: Tensor, t: int, h: int, w: int) -> Tensor: hidden_states = ops.reshape(hidden_states, (bs, nt, nh, nw, p0, p1, p2, c)) # bs, nt, p0, c, nh, p1, nw, p2, c - hidden_states = ops.transpose(hidden_states, (0, 1, 4, 7, 2, 5, 3, 6)) + hidden_states = 
mint.permute(hidden_states, (0, 1, 4, 7, 2, 5, 3, 6)) output = ops.reshape(hidden_states, (bs, nt * p0, c, nh * p1, nw * p2)) return output From 915c3c9c62f2e7f895869d07b9c39186ccb91e77 Mon Sep 17 00:00:00 2001 From: Mike Cheung Date: Fri, 18 Oct 2024 14:30:11 +0800 Subject: [PATCH 10/14] add linear patch embedder --- examples/llama3/llama/models/llama/layer.py | 46 +++++++++++++------ examples/llama3/llama/models/llama/network.py | 10 +++- 2 files changed, 39 insertions(+), 17 deletions(-) diff --git a/examples/llama3/llama/models/llama/layer.py b/examples/llama3/llama/models/llama/layer.py index d5fec37d55..629e6033dc 100644 --- a/examples/llama3/llama/models/llama/layer.py +++ b/examples/llama3/llama/models/llama/layer.py @@ -28,20 +28,6 @@ def construct(self, hidden_states: Tensor) -> Tensor: return self.weight * hidden_states.to(input_dtype) -class LlamaLayerNorm(nn.LayerNorm): - def construct(self, hidden_states: Tensor) -> Tensor: - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(ms.float32) - hidden_states = mint.layer_norm( - hidden_states, - self.normalized_shape, - self.gamma.to(hidden_states.dtype), - self.beta.to(hidden_states.dtype), - eps=self.epsilon, - ) - return hidden_states.to(input_dtype) - - class LlamaMLP(nn.Cell): def __init__( self, @@ -232,6 +218,35 @@ def construct(self, x: Tensor) -> Tensor: return x +class LinearPatchEmbed3D(nn.Cell): + def __init__( + self, + patch_size: Tuple[int, int, int] = (1, 2, 2), + in_channels: int = 8, + hidden_size: int = 4096, + dtype: ms.Type = ms.float32, + ) -> None: + super().__init__() + self.patch_size = patch_size + self.proj = mint.nn.Linear( + patch_size[0] * patch_size[1] * patch_size[2] * in_channels, hidden_size, bias=False, dtype=dtype + ) + + def construct(self, x: Tensor) -> Tensor: + b, t, c, h, w = x.shape + assert t % self.patch_size[0] == 0 + assert h % self.patch_size[1] == 0 + assert w % self.patch_size[2] == 0 + + p0, p1, p2 = self.patch_size[0], self.patch_size[1], self.patch_size[2] + nt, nh, nw = t // p0, h // p1, w // p2 + x = ops.reshape(x, (b, nt, p0, c, nh, p1, nw, p2)) + x = mint.permute(x, (0, 1, 4, 6, 3, 2, 5, 7)) # (B, nt, nh, nw, c, p0, p1, p2) + x = ops.reshape(x, (b, nt * nh * nw, -1)) + x = self.proj(x) + return x + + class TimestepEmbedder(nn.Cell): def __init__( self, @@ -270,12 +285,13 @@ def __init__( self, in_channels: int, hidden_size: int, + eps: float = 1e-6, dtype: ms.Type = ms.float32, ) -> None: super().__init__() self.proj = nn.SequentialCell( mint.nn.Linear(in_channels, hidden_size, bias=False, dtype=dtype), - LlamaLayerNorm((hidden_size,), dtype=dtype), + LlamaRMSNorm((hidden_size,), epsilon=eps, dtype=dtype), ) def construct(self, caption: Tensor) -> Tensor: diff --git a/examples/llama3/llama/models/llama/network.py b/examples/llama3/llama/models/llama/network.py index f1e38189e0..69a44a44dc 100644 --- a/examples/llama3/llama/models/llama/network.py +++ b/examples/llama3/llama/models/llama/network.py @@ -11,6 +11,7 @@ from ..activation import ACT2FN from .layer import ( CaptionEmbedder, + LinearPatchEmbed3D, LlamaAttention, LlamaFlashAttention, LlamaMLP, @@ -158,6 +159,7 @@ def __init__( caption_channels: int = 4096, attn_implementation: Literal["eager", "flash_attention"] = "eager", gradient_checkpointing: bool = False, + use_linear_patch_embedder: bool = True, dtype: ms.Type = ms.float32, ) -> None: super().__init__() @@ -197,12 +199,16 @@ def __init__( self.pos_embedding_table_h = nn.Embedding(max_length[1], self.hidden_size, dtype=dtype) 
self.pos_embedding_table_w = nn.Embedding(max_length[2], self.hidden_size, dtype=dtype) - self.latent_embedder = PatchEmbed3D(self.patch_size, self.in_channels, self.hidden_size, dtype=dtype) + if use_linear_patch_embedder: + self.latent_embedder = LinearPatchEmbed3D(self.patch_size, self.in_channels, self.hidden_size, dtype=dtype) + else: + self.latent_embedder = PatchEmbed3D(self.patch_size, self.in_channels, self.hidden_size, dtype=dtype) + self.timestep_embedder = TimestepEmbedder(self.hidden_size, dtype=dtype) self.adaLN_modulation = nn.SequentialCell( ACT2FN[hidden_act], mint.nn.Linear(self.hidden_size, 6 * self.hidden_size, bias=False, dtype=dtype) ) - self.caption_embedder = CaptionEmbedder(caption_channels, self.hidden_size, dtype=dtype) + self.caption_embedder = CaptionEmbedder(caption_channels, self.hidden_size, eps=rms_norm_eps, dtype=dtype) # post-init self.initializer_range = initializer_range From 85097fa913f268cc0d9b6234b282be287bcd8220 Mon Sep 17 00:00:00 2001 From: Mike Cheung Date: Fri, 18 Oct 2024 14:50:43 +0800 Subject: [PATCH 11/14] fix error from mint and eps --- examples/llama3/llama/models/llama/layer.py | 2 +- examples/llama3/llama/models/llama/network.py | 6 ++++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/examples/llama3/llama/models/llama/layer.py b/examples/llama3/llama/models/llama/layer.py index 629e6033dc..b0ab294e57 100644 --- a/examples/llama3/llama/models/llama/layer.py +++ b/examples/llama3/llama/models/llama/layer.py @@ -291,7 +291,7 @@ def __init__( super().__init__() self.proj = nn.SequentialCell( mint.nn.Linear(in_channels, hidden_size, bias=False, dtype=dtype), - LlamaRMSNorm((hidden_size,), epsilon=eps, dtype=dtype), + LlamaRMSNorm((hidden_size,), eps=eps, dtype=dtype), ) def construct(self, caption: Tensor) -> Tensor: diff --git a/examples/llama3/llama/models/llama/network.py b/examples/llama3/llama/models/llama/network.py index 69a44a44dc..dddc605166 100644 --- a/examples/llama3/llama/models/llama/network.py +++ b/examples/llama3/llama/models/llama/network.py @@ -1,5 +1,7 @@ from typing import Literal, Optional, Tuple +import numpy as np + import mindspore as ms import mindspore.mint as mint import mindspore.nn as nn @@ -70,7 +72,7 @@ def __init__( intermediate_size=intermediate_size, hidden_size=hidden_size, hidden_act=hidden_act, dtype=dtype ) - self.scale_shift_table = Parameter(mint.normal(size=(6, hidden_size)).to(dtype) / hidden_size**0.5) + self.scale_shift_table = Parameter(Tensor(np.random.randn(6, hidden_size), dtype=dtype) / hidden_size**0.5) self.input_layernorm = LlamaRMSNorm(hidden_size, eps=rms_norm_eps, dtype=dtype) self.post_attention_layernorm = LlamaRMSNorm(hidden_size, eps=rms_norm_eps, dtype=dtype) @@ -130,7 +132,7 @@ def __init__( self.proj = nn.Dense( hidden_size, patch_size[0] * patch_size[1] * patch_size[2] * out_channels, has_bias=False, dtype=dtype ) - self.scale_shift_table = Parameter(mint.normal(size=(2, hidden_size)).to(dtype) / hidden_size**0.5) + self.scale_shift_table = Parameter(Tensor(np.random.randn(2, hidden_size), dtype=dtype) / hidden_size**0.5) def construct(self, hidden_states: Tensor, timestep_embedding: Tensor): shift, scale = mint.chunk(self.scale_shift_table[None] + timestep_embedding[:, None], 2, dim=1) From 02e535de55eec33a17ee7b045436dc72ec56073c Mon Sep 17 00:00:00 2001 From: Mike Cheung Date: Mon, 21 Oct 2024 09:58:39 +0800 Subject: [PATCH 12/14] fix nan loss for abs embedding --- examples/llama3/llama/models/llama/network.py | 26 +++++++++++++------ 1 file changed, 18 
insertions(+), 8 deletions(-) diff --git a/examples/llama3/llama/models/llama/network.py b/examples/llama3/llama/models/llama/network.py index dddc605166..43a392ce4d 100644 --- a/examples/llama3/llama/models/llama/network.py +++ b/examples/llama3/llama/models/llama/network.py @@ -90,8 +90,8 @@ def construct( hidden_states = hidden_states + position_embedding.to(hidden_states.dtype) # 3.1.3 Adaptive Layer Norm - shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mint.chunk( - self.scale_shift_table[None] + modulation_parameters.reshape(B, 6, -1), 6, dim=1 + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ops.chunk( + ops.unsqueeze(self.scale_shift_table, 0) + modulation_parameters.reshape(B, 6, -1), 6, axis=1 ) # Self Attention (Bi-Directional Attention) @@ -135,7 +135,9 @@ def __init__( self.scale_shift_table = Parameter(Tensor(np.random.randn(2, hidden_size), dtype=dtype) / hidden_size**0.5) def construct(self, hidden_states: Tensor, timestep_embedding: Tensor): - shift, scale = mint.chunk(self.scale_shift_table[None] + timestep_embedding[:, None], 2, dim=1) + shift, scale = ops.chunk( + ops.unsqueeze(self.scale_shift_table, 0) + ops.unsqueeze(timestep_embedding, 1), 2, axis=1 + ) hidden_states = t2i_modulate(self.input_layernorm(hidden_states), shift, scale) hidden_states = self.proj(hidden_states) return hidden_states @@ -157,7 +159,7 @@ def __init__( hidden_act: str = "silu", initializer_range: float = 0.02, patch_size: Tuple[int, int, int] = (1, 2, 2), - max_length: Tuple[int, int, int] = (32, 16, 16), + max_length: Tuple[int, int, int] = (128, 64, 64), caption_channels: int = 4096, attn_implementation: Literal["eager", "flash_attention"] = "eager", gradient_checkpointing: bool = False, @@ -171,6 +173,7 @@ def __init__( self.hidden_size = hidden_size self.num_attention_heads = num_attention_heads self.num_key_value_heads = num_key_value_heads + self.max_length = max_length self.layers = nn.CellList( [ @@ -251,12 +254,19 @@ def _init_weights(module): def learnable_position_embedding(self, latent_embedding: Tensor) -> Tensor: # 3.1.3 _, t, _, h, w = latent_embedding.shape - t_inds = mint.arange(t // self.patch_size[0], dtype=ms.int64) - h_inds = mint.arange(h // self.patch_size[1], dtype=ms.int64) - w_inds = mint.arange(w // self.patch_size[2], dtype=ms.int64) + p0, p1, p2 = self.patch_size[0], self.patch_size[1], self.patch_size[2] + nt, nh, nw = t // p0, h // p1, w // p2 + + assert nt < self.max_length[0] + assert nh < self.max_length[1] + assert nw < self.max_length[2] + + t_inds = mint.arange(nt, dtype=ms.int64) + h_inds = mint.arange(nh, dtype=ms.int64) + w_inds = mint.arange(nw, dtype=ms.int64) position_ids = ops.meshgrid(t_inds, h_inds, w_inds, indexing="ij") - position_ids = mint.stack(position_ids, dim=-1) + position_ids = ops.stack(position_ids, axis=-1) position_ids = ops.reshape(position_ids, (-1, 3)) t_inds, h_inds, w_inds = ops.unbind(position_ids, dim=-1) From a29741829401c8f0261a3a045e9d5505a3b895e8 Mon Sep 17 00:00:00 2001 From: Mike Cheung Date: Mon, 21 Oct 2024 11:05:50 +0800 Subject: [PATCH 13/14] add Llama3 5B wrapper --- examples/opensora_hpcai/scripts/args_train.py | 2 +- examples/opensora_hpcai/scripts/inference.py | 5 ++++- examples/opensora_hpcai/scripts/train.py | 3 +++ 3 files changed, 8 insertions(+), 2 deletions(-) diff --git a/examples/opensora_hpcai/scripts/args_train.py b/examples/opensora_hpcai/scripts/args_train.py index 5d2e0a1aa1..6c31042169 100644 --- a/examples/opensora_hpcai/scripts/args_train.py +++ 
b/examples/opensora_hpcai/scripts/args_train.py @@ -52,7 +52,7 @@ def parse_train_args(parser): "--model_version", default="v1", type=str, - choices=["v1", "v1.1", "v1.2", "llama3_1b"], + choices=["v1", "v1.1", "v1.2", "llama3_1b", "llama3_5b"], help="OpenSora model version.", ) parser.add_argument( diff --git a/examples/opensora_hpcai/scripts/inference.py b/examples/opensora_hpcai/scripts/inference.py index 539ae0a641..80fa19acb6 100644 --- a/examples/opensora_hpcai/scripts/inference.py +++ b/examples/opensora_hpcai/scripts/inference.py @@ -256,6 +256,9 @@ def main(args): elif args.model_version == "llama3_1b": model_name = "Llama3-1B" latte_model = STDiTLlama3Wrapper(model_size="1B", **model_extra_args) + elif args.model_version == "llama3_5b": + model_name = "Llama3-5B" + latte_model = STDiTLlama3Wrapper(model_size="5B", **model_extra_args) else: raise ValueError(f"Unknown model version: {args.model_version}") @@ -526,7 +529,7 @@ def parse_args(): "--model_version", default="v1", type=str, - choices=["v1", "v1.1", "v1.2", "llama3_1b"], + choices=["v1", "v1.1", "v1.2", "llama3_1b", "llama3_5b"], help="OpenSora model version.", ) parser.add_argument("--image_size", type=int, nargs="+", help="image size in [256, 512]") diff --git a/examples/opensora_hpcai/scripts/train.py b/examples/opensora_hpcai/scripts/train.py index eea9affed9..d273d8d4b0 100644 --- a/examples/opensora_hpcai/scripts/train.py +++ b/examples/opensora_hpcai/scripts/train.py @@ -458,6 +458,9 @@ def main(args): elif args.model_version == "llama3_1b": model_name = "Llama3-1B" latte_model = STDiTLlama3Wrapper(model_size="1B", **model_extra_args) + elif args.model_version == "llama3_5b": + model_name = "Llama3-5B" + latte_model = STDiTLlama3Wrapper(model_size="5B", **model_extra_args) else: raise ValueError(f"Unknown model version: {args.model_version}") logger.info(f"{model_name} input size: {latent_size if args.bucket_config is None else 'Variable'}") From 69f2a6ec25f5877176ebde4f5bfbd9ffa4d3bc5f Mon Sep 17 00:00:00 2001 From: Mike Cheung Date: Mon, 21 Oct 2024 15:11:46 +0800 Subject: [PATCH 14/14] add column parallel linear for MP --- examples/llama3/llama/parallel/__init__.py | 1 + examples/llama3/llama/parallel/layers.py | 93 +++++++++++++++ .../llama3/llama/parallel/parallel_states.py | 36 ++++++ .../llama3/tests/run_test_parallel_layer.sh | 13 +++ examples/llama3/tests/test_parallel_layer.py | 109 ++++++++++++++++++ 5 files changed, 252 insertions(+) create mode 100644 examples/llama3/llama/parallel/__init__.py create mode 100644 examples/llama3/llama/parallel/layers.py create mode 100644 examples/llama3/llama/parallel/parallel_states.py create mode 100755 examples/llama3/tests/run_test_parallel_layer.sh create mode 100644 examples/llama3/tests/test_parallel_layer.py diff --git a/examples/llama3/llama/parallel/__init__.py b/examples/llama3/llama/parallel/__init__.py new file mode 100644 index 0000000000..69a388db1d --- /dev/null +++ b/examples/llama3/llama/parallel/__init__.py @@ -0,0 +1 @@ +from .layers import * diff --git a/examples/llama3/llama/parallel/layers.py b/examples/llama3/llama/parallel/layers.py new file mode 100644 index 0000000000..9bd60a175d --- /dev/null +++ b/examples/llama3/llama/parallel/layers.py @@ -0,0 +1,93 @@ +import numbers +from typing import Callable, Optional, Tuple, Union + +import mindspore as ms +import mindspore.mint as mint +import mindspore.nn as nn +import mindspore.ops as ops +from mindspore import Tensor +from mindspore.common.initializer import Initializer +from 
mindspore.communication import GlobalComm, get_group_size, get_rank + +__all__ = ["ColumnParallelLinear"] + + +def _communicate_along_dim(x: Tensor, dim: int, func: Callable[[Tensor], Tensor]) -> Tensor: + x = x.swapaxes(0, dim) + x = func(x) + x = x.swapaxes(dim, 0) + return x + + +def _split(x: Tensor, dim: int, rank: int, world_size: int) -> Tensor: + dim_size = x.shape[dim] + tensor_list = x.split(dim_size // world_size, axis=dim) + x = tensor_list[rank] + return x + + +class _CopyToModelParallelRegion(nn.Cell): + def __init__(self, group: str = GlobalComm.WORLD_COMM_GROUP) -> None: + super().__init__(auto_prefix=False) + self.reduce = ops.AllReduce(op=ops.ReduceOp.SUM, group=group) + + def construct(self, x: Tensor) -> Tensor: + return x + + def bprop(self, x: Tensor, out: Tensor, dout: Tensor) -> Tuple[Tensor]: + dout = self.reduce(dout) + return (dout,) + + +class _GatherFromModelParallelRegion(nn.Cell): + def __init__(self, group: str = GlobalComm.WORLD_COMM_GROUP) -> None: + super().__init__(auto_prefix=False) + self.gather = ops.AllGather(group=group) + self.rank = get_rank(group) + self.world_size = get_group_size(group) + + def construct(self, x: Tensor) -> Tensor: + return _communicate_along_dim(x, -1, self.gather) + + def bprop(self, x: Tensor, out: Tensor, dout: Tensor) -> Tuple[Tensor]: + dout = _split(dout, -1, self.rank, self.world_size) + return (dout,) + + +class ColumnParallelLinear(nn.Cell): + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + weight_init: Union[None, Tensor, str, Initializer, numbers.Number] = None, + bias_init: Union[None, Tensor, str, Initializer, numbers.Number] = None, + gather_output: bool = True, + group: str = GlobalComm.WORLD_COMM_GROUP, + dtype: Optional[ms.Type] = None, + ): + super().__init__(auto_prefix=False) + + self.group_size = get_group_size(group) + assert out_features % self.group_size == 0 + self.out_features_per_partition = out_features // self.group_size + self.gather_output = gather_output + + self.copy_to_model_parallel_region = _CopyToModelParallelRegion(group) + self.linear = mint.nn.Linear( + in_features, + self.out_features_per_partition, + bias=bias, + weight_init=weight_init, + bias_init=bias_init, + dtype=dtype, + ) + if self.gather_output: + self.gather_from_model_parallel_region = _GatherFromModelParallelRegion(group) + + def construct(self, x: Tensor) -> Tensor: + x = self.copy_to_model_parallel_region(x) + x = self.linear(x) + if self.gather_output: + x = self.gather_from_model_parallel_region(x) + return x diff --git a/examples/llama3/llama/parallel/parallel_states.py b/examples/llama3/llama/parallel/parallel_states.py new file mode 100644 index 0000000000..3e6c5f73e1 --- /dev/null +++ b/examples/llama3/llama/parallel/parallel_states.py @@ -0,0 +1,36 @@ +from typing import Optional + +from mindspore.communication import create_group, get_group_size, get_rank + +_GLOBAL_PARALLEL_GROUPS = dict() + + +def set_model_parallel_group(group: str) -> None: + _GLOBAL_PARALLEL_GROUPS["model"] = group + + +def get_model_parallel_group() -> Optional[str]: + return _GLOBAL_PARALLEL_GROUPS.get("model", None) + + +def set_sequence_parallel_group(group: str) -> None: + _GLOBAL_PARALLEL_GROUPS["sequence"] = group + + +def get_sequence_parallel_group() -> Optional[str]: + return _GLOBAL_PARALLEL_GROUPS.get("sequence", None) + + +def create_parallel_group(model_parallel_shards: int = 1) -> None: + device_num = get_group_size() + if device_num % model_parallel_shards != 0: + raise ValueError( + 
f"Total number of devices ({device_num}) must be devisible by the number of model parallel shards ({model_parallel_shards})." + ) + + rank_id = get_rank() + mp_group_id = rank_id // model_parallel_shards + mp_group_rank_ids = list(range(mp_group_id * model_parallel_shards, (mp_group_id + 1) * model_parallel_shards)) + mp_group_name = f"mp_group_{mp_group_id}" + create_group(mp_group_name, mp_group_rank_ids) + set_model_parallel_group(mp_group_name) diff --git a/examples/llama3/tests/run_test_parallel_layer.sh b/examples/llama3/tests/run_test_parallel_layer.sh new file mode 100755 index 0000000000..d37225d562 --- /dev/null +++ b/examples/llama3/tests/run_test_parallel_layer.sh @@ -0,0 +1,13 @@ +#!/bin/sh + +SCRIPT_DIR="$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )" +PROJECT_DIR="$(dirname "${SCRIPT_DIR}")" +EXAMPLE_DIR="$(dirname "${PROJECT_DIR}")" +PACKAGE_DIR="$(dirname "${EXAMPLE_DIR}")" + +export PYTHONPATH="${PROJECT_DIR}:${PACKAGE_DIR}:${PYTHONPATH}" + +LOGDIR="./log_test_parallel_layer_graph" +echo "Graph Mode:" +msrun --master_port=1234 --worker_num=2 --local_worker_num=2 --log_dir=$LOGDIR --join True ${SCRIPT_DIR}/test_parallel_layer.py --mode 0 +echo "Done. Check the log at '$LOGDIR'." diff --git a/examples/llama3/tests/test_parallel_layer.py b/examples/llama3/tests/test_parallel_layer.py new file mode 100644 index 0000000000..6c8afb3255 --- /dev/null +++ b/examples/llama3/tests/test_parallel_layer.py @@ -0,0 +1,109 @@ +import argparse + +import numpy as np +from llama.parallel import ColumnParallelLinear +from llama.parallel.parallel_states import create_parallel_group, get_model_parallel_group + +import mindspore as ms +import mindspore.mint as mint +import mindspore.nn as nn +import mindspore.ops as ops +from mindspore.communication import get_group_size, get_rank, init + +from mindone.utils.seed import set_random_seed + + +class MeanNet(nn.Cell): + def __init__(self, net: nn.Cell) -> None: + super().__init__() + self.net = net + + def construct(self, *inputs): + output = self.net(*inputs) + return output.mean() + + +def get_sample_data(): + x = ops.rand([4, 64, 256], dtype=ms.float32) # (N, T, H) + return x + + +def get_layer_config(): + config = dict(in_features=256, out_features=32, bias=True) + return config + + +def run_layer(mode: int = 0, dtype: ms.Type = ms.float32): + ms.set_context(mode=mode) + ms.set_auto_parallel_context(enable_alltoall=True) + init() + + # prepare data + set_random_seed(1024) + data = get_sample_data() + + # non parallel layer + set_random_seed(1024) + non_parallel_layer_cfg = get_layer_config() + non_parallel_layer = mint.nn.Linear(**non_parallel_layer_cfg, dtype=dtype) + + # parallel layer + create_parallel_group(get_group_size()) + group = get_model_parallel_group() + set_random_seed(1024) + parallel_layer_cfg = get_layer_config() + parallel_layer = ColumnParallelLinear(**parallel_layer_cfg, gather_output=True, group=group, dtype=dtype) + + mp_size = get_group_size(group) + mp_rank = get_rank(group) + for (_, w0), (_, w1) in zip(non_parallel_layer.parameters_and_names(), parallel_layer.parameters_and_names()): + w0_p = ops.chunk(w0, mp_size, axis=0)[mp_rank] + w1.set_data(w0_p) + + # test forward + non_parallel_out = non_parallel_layer(data) + parallel_out = parallel_layer(data) + + np.testing.assert_equal(non_parallel_out.shape, parallel_out.shape) + np.testing.assert_allclose(non_parallel_out.asnumpy(), parallel_out.asnumpy(), atol=1e-5) + print("Test 1 (Forward): Passed.") + + # test backward + non_parallel_mean_net = 
MeanNet(non_parallel_layer) + parallel_mean_net = MeanNet(parallel_layer) + + # check the parameter gradient + grad_fn = ops.grad(non_parallel_mean_net, grad_position=None, weights=non_parallel_mean_net.trainable_params()) + non_parallel_grads = grad_fn(data) + + grad_fn = ops.grad(parallel_mean_net, grad_position=None, weights=parallel_mean_net.trainable_params()) + parallel_grads = grad_fn(data) + + allgather = ops.AllGather(group=group) + syn_parallel_grads = list() + for x in parallel_grads: + syn_parallel_grads.append(allgather(x)) + + for grad_0, grad_1 in zip(non_parallel_grads, syn_parallel_grads): + np.testing.assert_allclose(grad_0.asnumpy(), grad_1.asnumpy(), atol=1e-5) + print("Test 2 (Backward: Parameter Gradient): Passed.") + + # check the input gradient + grad_fn = ops.grad(non_parallel_mean_net, grad_position=0) + non_parallel_grads = grad_fn(data) + + grad_fn = ops.grad(parallel_mean_net, grad_position=0) + parallel_grads = grad_fn(data) + + for grad_0, grad_1 in zip(non_parallel_grads, parallel_grads): + np.testing.assert_allclose(grad_0.asnumpy(), grad_1.asnumpy(), atol=1e-5) + print("Test 3 (Backward: Input Gradient): Passed.") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--mode", default=0, type=int, choices=[0, 1], help="Mode to test. (0: Graph Mode; 1: Pynative mode)" + ) + args = parser.parse_args() + run_layer(mode=args.mode)
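A minimal usage sketch of the ColumnParallelLinear added in the last patch, assuming a two-device msrun launch (as in tests/run_test_parallel_layer.sh) with examples/llama3 on PYTHONPATH; the layer config and input shape mirror get_layer_config and get_sample_data from tests/test_parallel_layer.py, and the script name sketch.py is hypothetical:

    import mindspore as ms
    import mindspore.ops as ops
    from mindspore.communication import get_group_size, init

    from llama.parallel import ColumnParallelLinear
    from llama.parallel.parallel_states import create_parallel_group, get_model_parallel_group

    ms.set_context(mode=ms.GRAPH_MODE)
    init()

    # Shard over every device in the job (model_parallel_shards == world size).
    create_parallel_group(model_parallel_shards=get_group_size())
    group = get_model_parallel_group()

    # Each rank holds out_features // world_size columns of the weight;
    # gather_output=True all-gathers the partial outputs back to the full width.
    layer = ColumnParallelLinear(256, 32, bias=True, gather_output=True, group=group, dtype=ms.float32)

    x = ops.rand([4, 64, 256], dtype=ms.float32)  # (N, T, H)
    y = layer(x)                                  # (4, 64, 32), identical on every rank
    print(y.shape)

Launch, for example, with: msrun --master_port=1234 --worker_num=2 --local_worker_num=2 --join True sketch.py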
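The token layout produced by LinearPatchEmbed3D (patch 10/14) can be checked with plain NumPy; the sketch below follows the same reshape/permute order as the MindSpore code, uses the latent shape from examples/llama3/test.py with the default patch_size (1, 2, 2), and omits the final linear projection to hidden_size:

    import numpy as np

    # (N, T, C, H, W) latent, as in examples/llama3/test.py
    b, t, c, h, w = 1, 16, 8, 24, 44
    p0, p1, p2 = 1, 2, 2
    nt, nh, nw = t // p0, h // p1, w // p2   # 16, 12, 22 patches per axis

    x = np.random.randn(b, t, c, h, w).astype(np.float32)
    x = x.reshape(b, nt, p0, c, nh, p1, nw, p2)
    x = x.transpose(0, 1, 4, 6, 3, 2, 5, 7)  # (b, nt, nh, nw, c, p0, p1, p2)
    tokens = x.reshape(b, nt * nh * nw, -1)

    print(tokens.shape)  # (1, 4224, 32): nt*nh*nw tokens of p0*p1*p2*c features each

unpatchify in network.py inverts this mapping back to (N, T, C_out, H, W) after the final layer.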