From ede4a57a1cf59de4f73b787788e743a374d932d8 Mon Sep 17 00:00:00 2001 From: Mike Cheung Date: Fri, 27 Sep 2024 10:54:41 +0800 Subject: [PATCH] LLaVA-Next for opensora --- .../tools/caption/llava_next/README.md | 53 +++ .../tools/caption/llava_next/assets/.gitkeep | 0 .../caption/llava_next/llava/__init__.py | 0 .../llava_next/llava/model/__init__.py | 0 .../llava_next/llava/model/activation.py | 28 ++ .../llava_next/llava/model/clip/__init__.py | 1 + .../llava_next/llava/model/clip/layer.py | 130 +++++ .../llava_next/llava/model/clip/network.py | 228 +++++++++ .../llava_next/llava/model/common_layer.py | 112 +++++ .../llava/model/llava_next/__init__.py | 1 + .../llava/model/llava_next/network.py | 449 ++++++++++++++++++ .../llava/model/llava_next/utils.py | 95 ++++ .../llava/model/mistral/__init__.py | 1 + .../llava_next/llava/model/mistral/layer.py | 288 +++++++++++ .../llava_next/llava/model/mistral/network.py | 295 ++++++++++++ .../caption/llava_next/llava/model/padding.py | 30 ++ .../llava_next/llava/pipeline/__init__.py | 1 + .../llava/pipeline/helpers/__init__.py | 1 + .../pipeline/helpers/stopping_criteria.py | 56 +++ .../llava/pipeline/text_generation.py | 275 +++++++++++ .../tools/caption/llava_next/models/.gitkeep | 0 .../tools/caption/llava_next/predict.py | 94 ++++ .../caption/llava_next/tools/convert_llava.py | 67 +++ .../llava_next/tools/fileio/__init__.py | 1 + .../llava_next/tools/fileio/safetensors.py | 21 + 25 files changed, 2227 insertions(+) create mode 100644 examples/opensora_hpcai/tools/caption/llava_next/README.md create mode 100644 examples/opensora_hpcai/tools/caption/llava_next/assets/.gitkeep create mode 100644 examples/opensora_hpcai/tools/caption/llava_next/llava/__init__.py create mode 100644 examples/opensora_hpcai/tools/caption/llava_next/llava/model/__init__.py create mode 100644 examples/opensora_hpcai/tools/caption/llava_next/llava/model/activation.py create mode 100644 examples/opensora_hpcai/tools/caption/llava_next/llava/model/clip/__init__.py create mode 100644 examples/opensora_hpcai/tools/caption/llava_next/llava/model/clip/layer.py create mode 100644 examples/opensora_hpcai/tools/caption/llava_next/llava/model/clip/network.py create mode 100644 examples/opensora_hpcai/tools/caption/llava_next/llava/model/common_layer.py create mode 100644 examples/opensora_hpcai/tools/caption/llava_next/llava/model/llava_next/__init__.py create mode 100644 examples/opensora_hpcai/tools/caption/llava_next/llava/model/llava_next/network.py create mode 100644 examples/opensora_hpcai/tools/caption/llava_next/llava/model/llava_next/utils.py create mode 100644 examples/opensora_hpcai/tools/caption/llava_next/llava/model/mistral/__init__.py create mode 100644 examples/opensora_hpcai/tools/caption/llava_next/llava/model/mistral/layer.py create mode 100644 examples/opensora_hpcai/tools/caption/llava_next/llava/model/mistral/network.py create mode 100644 examples/opensora_hpcai/tools/caption/llava_next/llava/model/padding.py create mode 100644 examples/opensora_hpcai/tools/caption/llava_next/llava/pipeline/__init__.py create mode 100644 examples/opensora_hpcai/tools/caption/llava_next/llava/pipeline/helpers/__init__.py create mode 100644 examples/opensora_hpcai/tools/caption/llava_next/llava/pipeline/helpers/stopping_criteria.py create mode 100644 examples/opensora_hpcai/tools/caption/llava_next/llava/pipeline/text_generation.py create mode 100644 examples/opensora_hpcai/tools/caption/llava_next/models/.gitkeep create mode 100755 
examples/opensora_hpcai/tools/caption/llava_next/predict.py
 create mode 100755 examples/opensora_hpcai/tools/caption/llava_next/tools/convert_llava.py
 create mode 100644 examples/opensora_hpcai/tools/caption/llava_next/tools/fileio/__init__.py
 create mode 100644 examples/opensora_hpcai/tools/caption/llava_next/tools/fileio/safetensors.py

diff --git a/examples/opensora_hpcai/tools/caption/llava_next/README.md b/examples/opensora_hpcai/tools/caption/llava_next/README.md
new file mode 100644
index 0000000000..e107c53dba
--- /dev/null
+++ b/examples/opensora_hpcai/tools/caption/llava_next/README.md
@@ -0,0 +1,53 @@
+# LLaVA-NeXT: Open Large Multimodal Models (MindSpore)
+
+This repo contains MindSpore model definitions, pre-trained weights, and inference/sampling code for the [LLaVA-NeXT model](https://llava-vl.github.io/blog/2024-01-30-llava-next/). Please refer to the [official project page](https://github.com/LLaVA-VL/LLaVA-NeXT) for more details.
+
+## Dependencies and Installation
+
+- CANN: 8.0.RC2 or later
+- Python: 3.9 or later
+- MindSpore: 2.3.1
+
+## Getting Started
+
+### Downloading Pretrained Checkpoints
+
+Please download the model [llava-v1.6-mistral-7b-hf](https://huggingface.co/llava-hf/llava-v1.6-mistral-7b-hf) to the `./models` directory, and then run
+
+```bash
+python tools/convert_llava.py models/llava-v1.6-mistral-7b-hf -o models/llava-v1.6-mistral-7b-hf/model.ckpt
+```
+
+to convert the model weights into the MindSpore `ckpt` format.
+
+### Inference
+
+To run inference, you may use `predict.py` with the following command:
+
+```bash
+python predict.py --input_image path_to_your_input_image --prompt input_prompt
+```
+
+For example, running `python predict.py` with the default image [llava_v1_5_radar.jpg](https://github.com/user-attachments/assets/8e016871-82fd-488a-8629-5ca71222e0e3) and default prompt `What is shown in this image?` will give the following result:
+
+```text
+[INST]
+What is shown in this image? [/INST] The image appears to be a radar chart, which is a type of multivariate chart that displays values for multiple variables represented on axes
+starting from the same point. This particular radar chart is showing the performance of different models or systems across various metrics.
+
+The axes represent different metrics or benchmarks, such as MM-Vet, MM-Vet, MM-Vet, MM-Vet, MM-Vet, MM-V
+```
+
+## Benchmark
+
+### Inference
+
+To perform the benchmark, first download the image [llava_v1_5_radar.jpg](https://github.com/user-attachments/assets/8e016871-82fd-488a-8629-5ca71222e0e3) and save it in `./assets`, then run `python predict.py --benchmark` to get the throughput.
+
+| Model                 | Context       | Batch Size | Throughput (tokens/second) |
+|-----------------------|---------------|------------|----------------------------|
+| llava-v1.6-mistral-7b | D910*x1-MS2.3 | 1          | 21.2                       |
+
+> Context: {Ascend chip}-{number of NPUs}-{MindSpore version}.\
+> Throughput (tokens/second): number of generated tokens per second.\
+> We use the second round of inference as the benchmark result.
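As a quick sanity check after the conversion step above, the short sketch below is an illustrative snippet (not part of the patch) that assumes the default output path `models/llava-v1.6-mistral-7b-hf/model.ckpt` from the conversion command in the README. It loads the converted weights with MindSpore and prints a few parameter names and shapes; the actual model construction and inference entry point remains `predict.py`.

```python
import mindspore as ms

# Illustrative check of the checkpoint produced by tools/convert_llava.py.
# The path below is the output path used in the conversion command above.
ckpt_path = "models/llava-v1.6-mistral-7b-hf/model.ckpt"

# ms.load_checkpoint returns a dict mapping parameter names to mindspore.Parameter.
param_dict = ms.load_checkpoint(ckpt_path)
print(f"Loaded {len(param_dict)} parameters from {ckpt_path}")

# Print a few entries to confirm that the names and shapes look reasonable.
for name in list(param_dict)[:5]:
    param = param_dict[name]
    print(f"{name}: shape={tuple(param.shape)}, dtype={param.dtype}")
```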
diff --git a/examples/opensora_hpcai/tools/caption/llava_next/assets/.gitkeep b/examples/opensora_hpcai/tools/caption/llava_next/assets/.gitkeep new file mode 100644 index 0000000000..e69de29bb2 diff --git a/examples/opensora_hpcai/tools/caption/llava_next/llava/__init__.py b/examples/opensora_hpcai/tools/caption/llava_next/llava/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/examples/opensora_hpcai/tools/caption/llava_next/llava/model/__init__.py b/examples/opensora_hpcai/tools/caption/llava_next/llava/model/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/examples/opensora_hpcai/tools/caption/llava_next/llava/model/activation.py b/examples/opensora_hpcai/tools/caption/llava_next/llava/model/activation.py new file mode 100644 index 0000000000..22a4a66112 --- /dev/null +++ b/examples/opensora_hpcai/tools/caption/llava_next/llava/model/activation.py @@ -0,0 +1,28 @@ +import logging +from collections import OrderedDict + +import mindspore.nn as nn +import mindspore.ops as ops +from mindspore import Tensor + +logger = logging.getLogger(__name__) + + +class QuickGELU(nn.Cell): + def construct(self, x: Tensor): + return x * ops.sigmoid(1.702 * x) + + +class ClassInstantier(OrderedDict): + def __getitem__(self, key): + content = super().__getitem__(key) + cls, kwargs = content if isinstance(content, tuple) else (content, {}) + return cls(**kwargs) + + +ACT2CLS = { + "quick_gelu": QuickGELU, + "gelu": nn.GELU, + "silu": nn.SiLU, +} +ACT2FN = ClassInstantier(ACT2CLS) diff --git a/examples/opensora_hpcai/tools/caption/llava_next/llava/model/clip/__init__.py b/examples/opensora_hpcai/tools/caption/llava_next/llava/model/clip/__init__.py new file mode 100644 index 0000000000..3f7f715f79 --- /dev/null +++ b/examples/opensora_hpcai/tools/caption/llava_next/llava/model/clip/__init__.py @@ -0,0 +1 @@ +from .network import CLIPVisionModel, CLIPVisionModelWithProjection diff --git a/examples/opensora_hpcai/tools/caption/llava_next/llava/model/clip/layer.py b/examples/opensora_hpcai/tools/caption/llava_next/llava/model/clip/layer.py new file mode 100644 index 0000000000..75f8da3a97 --- /dev/null +++ b/examples/opensora_hpcai/tools/caption/llava_next/llava/model/clip/layer.py @@ -0,0 +1,130 @@ +import mindspore as ms +import mindspore.nn as nn +import mindspore.ops as ops +from mindspore import Tensor + +from ..activation import ACT2FN +from ..common_layer import Embedding + + +class CLIPVisionEmbeddings(nn.Cell): + def __init__( + self, + hidden_size: int = 1024, + image_size: int = 336, + patch_size: int = 14, + num_channels: int = 3, + dtype: ms.dtype = ms.float32, + ) -> None: + super().__init__() + self.embed_dim = hidden_size + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels + + self.class_embedding = ms.Parameter(Tensor(ops.randn(self.embed_dim), dtype=dtype)) + + self.patch_embedding = nn.Conv2d( + in_channels=num_channels, + out_channels=self.embed_dim, + kernel_size=self.patch_size, + stride=self.patch_size, + has_bias=False, + dtype=dtype, + ) + + self.num_patches = (self.image_size // self.patch_size) ** 2 + self.num_positions = self.num_patches + 1 + self.position_embedding = Embedding(self.num_positions, self.embed_dim, dtype=dtype) + self.position_ids = ops.broadcast_to(ops.arange(self.num_positions), (1, -1)) + + def construct(self, pixel_values: Tensor) -> Tensor: + batch_size = pixel_values.shape[0] + patch_embeds = self.patch_embedding(pixel_values) + patch_embeds = 
ops.flatten(patch_embeds, start_dim=2) + patch_embeds = ops.transpose(patch_embeds, (0, 2, 1)) + + class_embeds = ops.broadcast_to(self.class_embedding, (batch_size, 1, -1)) + embeddings = ops.concat([class_embeds, patch_embeds], axis=1) + embeddings = embeddings + self.position_embedding(self.position_ids) + return embeddings + + +class CLIPAttention(nn.Cell): + def __init__( + self, + hidden_size: int = 1024, + num_attention_heads: int = 16, + attention_dropout: float = 0.0, + dtype: ms.dtype = ms.float32, + ) -> None: + super().__init__() + self.embed_dim = hidden_size + self.num_heads = num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" + f" {self.num_heads})." + ) + self.scale = self.head_dim**-0.5 + self.dropout = attention_dropout + + self.k_proj = nn.Dense(self.embed_dim, self.embed_dim, dtype=dtype) + self.v_proj = nn.Dense(self.embed_dim, self.embed_dim, dtype=dtype) + self.q_proj = nn.Dense(self.embed_dim, self.embed_dim, dtype=dtype) + self.out_proj = nn.Dense(self.embed_dim, self.embed_dim, dtype=dtype) + + def _shape(self, tensor: Tensor, seq_len: int, bsz: int) -> Tensor: + tensor = ops.reshape(tensor, (bsz, seq_len, self.num_heads, self.head_dim)) + tensor = ops.transpose(tensor, (0, 2, 1, 3)) + return tensor + + def construct(self, hidden_states: Tensor) -> Tensor: + """Input shape: Batch x Time x Channel""" + + bsz, tgt_len, embed_dim = hidden_states.shape + + # get query proj + query_states = self.q_proj(hidden_states) * self.scale + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + + proj_shape = (bsz * self.num_heads, -1, self.head_dim) + query_states = self._shape(query_states, tgt_len, bsz) + query_states = ops.reshape(query_states, proj_shape) + key_states = ops.reshape(key_states, proj_shape) + value_states = ops.reshape(value_states, proj_shape) + + attn_weights = ops.bmm(query_states, ops.transpose(key_states, (0, 2, 1))) + attn_weights = ops.softmax(attn_weights, axis=-1) + attn_probs = ops.dropout(attn_weights, p=self.dropout, training=self.training) + + attn_output = ops.bmm(attn_probs, value_states) + attn_output = ops.reshape(attn_output, (bsz, self.num_heads, tgt_len, self.head_dim)) + attn_output = ops.transpose(attn_output, (0, 2, 1, 3)) + attn_output = ops.reshape(attn_output, (bsz, tgt_len, embed_dim)) + + attn_output = self.out_proj(attn_output) + + return attn_output + + +class CLIPMLP(nn.Cell): + def __init__( + self, + hidden_size: int = 1024, + intermediate_size: int = 4096, + hidden_act: str = "quick_gelu", + dtype: ms.dtype = ms.float32, + ) -> None: + super().__init__() + self.activation_fn = ACT2FN[hidden_act] + self.fc1 = nn.Dense(hidden_size, intermediate_size, dtype=dtype) + self.fc2 = nn.Dense(intermediate_size, hidden_size, dtype=dtype) + + def construct(self, hidden_states: ms.Tensor) -> ms.Tensor: + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states = self.fc2(hidden_states) + return hidden_states diff --git a/examples/opensora_hpcai/tools/caption/llava_next/llava/model/clip/network.py b/examples/opensora_hpcai/tools/caption/llava_next/llava/model/clip/network.py new file mode 100644 index 0000000000..78e102a69b --- /dev/null +++ b/examples/opensora_hpcai/tools/caption/llava_next/llava/model/clip/network.py @@ 
-0,0 +1,228 @@ +from typing import Tuple + +import mindspore as ms +import mindspore.nn as nn +from mindspore import Tensor + +from ..common_layer import LayerNorm +from .layer import CLIPMLP, CLIPAttention, CLIPVisionEmbeddings + + +class CLIPEncoderLayer(nn.Cell): + def __init__( + self, + hidden_size: int = 1024, + intermediate_size: int = 4096, + num_attention_heads: int = 16, + attention_dropout: float = 0.0, + layer_norm_eps: float = 1e-5, + hidden_act: str = "quick_gelu", + dtype: ms.dtype = ms.float32, + ) -> None: + super().__init__() + self.embed_dim = hidden_size + self.self_attn = CLIPAttention( + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + attention_dropout=attention_dropout, + dtype=dtype, + ) + self.layer_norm1 = LayerNorm((self.embed_dim,), epsilon=layer_norm_eps, dtype=dtype) + self.mlp = CLIPMLP( + hidden_size=hidden_size, intermediate_size=intermediate_size, hidden_act=hidden_act, dtype=dtype + ) + self.layer_norm2 = LayerNorm((self.embed_dim,), epsilon=layer_norm_eps, dtype=dtype) + + def construct(self, hidden_states: Tensor) -> Tensor: + residual = hidden_states + + hidden_states = self.layer_norm1(hidden_states) + hidden_states = self.self_attn(hidden_states) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states + + +class CLIPEncoder(nn.Cell): + def __init__( + self, + num_hidden_layers: int = 24, + hidden_size: int = 1024, + intermediate_size: int = 4096, + num_attention_heads: int = 16, + attention_dropout: float = 0.0, + layer_norm_eps: float = 1e-5, + hidden_act: str = "quick_gelu", + dtype: ms.dtype = ms.float32, + ) -> None: + super().__init__() + self.layers = nn.CellList( + [ + CLIPEncoderLayer( + hidden_size=hidden_size, + intermediate_size=intermediate_size, + num_attention_heads=num_attention_heads, + attention_dropout=attention_dropout, + layer_norm_eps=layer_norm_eps, + hidden_act=hidden_act, + dtype=dtype, + ) + for _ in range(num_hidden_layers) + ] + ) + + def construct(self, inputs_embeds: Tensor) -> Tuple[Tensor, Tuple[Tensor, ...]]: + encoder_states = () + + hidden_states = inputs_embeds + for encoder_layer in self.layers: + encoder_states = encoder_states + (hidden_states,) + hidden_states = encoder_layer(hidden_states) + + encoder_states = encoder_states + (hidden_states,) + return hidden_states, encoder_states + + +class CLIPVisionTransformer(nn.Cell): + def __init__( + self, + image_size: int = 336, + patch_size: int = 14, + num_channels: int = 3, + num_hidden_layers: int = 24, + hidden_size: int = 1024, + intermediate_size: int = 4096, + num_attention_heads: int = 16, + attention_dropout: float = 0.0, + layer_norm_eps: float = 1e-5, + hidden_act: str = "quick_gelu", + dtype: ms.dtype = ms.float32, + ) -> None: + super().__init__() + embed_dim = hidden_size + + self.embeddings = CLIPVisionEmbeddings( + hidden_size=hidden_size, + image_size=image_size, + patch_size=patch_size, + num_channels=num_channels, + dtype=dtype, + ) + self.pre_layrnorm = LayerNorm((embed_dim,), epsilon=layer_norm_eps, dtype=dtype) + self.encoder = CLIPEncoder( + num_hidden_layers=num_hidden_layers, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + num_attention_heads=num_attention_heads, + attention_dropout=attention_dropout, + layer_norm_eps=layer_norm_eps, + hidden_act=hidden_act, + dtype=dtype, + ) + self.post_layernorm = LayerNorm((embed_dim,), 
epsilon=layer_norm_eps, dtype=dtype) + + def construct(self, pixel_values: Tensor) -> Tuple[Tensor, Tuple[Tensor, ...]]: + hidden_states = self.embeddings(pixel_values) + hidden_states = self.pre_layrnorm(hidden_states) + + last_hidden_state, encoder_states = self.encoder( + inputs_embeds=hidden_states, + ) + + pooled_output = last_hidden_state[:, 0, :] + pooled_output = self.post_layernorm(pooled_output) + + return pooled_output, encoder_states + + +class CLIPVisionModel(nn.Cell): + def __init__( + self, + image_size: int = 336, + patch_size: int = 14, + num_channels: int = 3, + num_hidden_layers: int = 24, + hidden_size: int = 1024, + intermediate_size: int = 4096, + num_attention_heads: int = 16, + attention_dropout: float = 0.0, + layer_norm_eps: float = 1e-5, + hidden_act: str = "quick_gelu", + dtype: ms.dtype = ms.float32, + **kwargs, + ) -> None: + super().__init__() + + self.vision_model = CLIPVisionTransformer( + image_size=image_size, + patch_size=patch_size, + num_channels=num_channels, + num_hidden_layers=num_hidden_layers, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + num_attention_heads=num_attention_heads, + attention_dropout=attention_dropout, + layer_norm_eps=layer_norm_eps, + hidden_act=hidden_act, + dtype=dtype, + ) + + def get_input_embeddings(self) -> nn.Cell: + return self.vision_model.embeddings.patch_embedding + + @ms.jit + def construct(self, pixel_values: Tensor) -> Tuple[Tensor, Tuple[Tensor, ...]]: + pooled_output, hidden_states = self.vision_model(pixel_values) + + return pooled_output, hidden_states + + +class CLIPVisionModelWithProjection(nn.Cell): + def __init__( + self, + projection_dim: int = 768, + image_size: int = 336, + patch_size: int = 14, + num_channels: int = 3, + num_hidden_layers: int = 24, + hidden_size: int = 1024, + intermediate_size: int = 4096, + num_attention_heads: int = 16, + attention_dropout: float = 0.0, + layer_norm_eps: float = 1e-5, + hidden_act: str = "quick_gelu", + dtype: ms.dtype = ms.float32, + **kwargs, + ) -> None: + super().__init__() + + self.vision_model = CLIPVisionTransformer( + image_size=image_size, + patch_size=patch_size, + num_channels=num_channels, + num_hidden_layers=num_hidden_layers, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + num_attention_heads=num_attention_heads, + attention_dropout=attention_dropout, + layer_norm_eps=layer_norm_eps, + hidden_act=hidden_act, + dtype=dtype, + ) + + self.visual_projection = nn.Dense(hidden_size, projection_dim, has_bias=False, dtype=dtype) + + def get_input_embeddings(self) -> nn.Cell: + return self.vision_model.embeddings.patch_embedding + + @ms.jit + def construct(self, pixel_values: Tensor) -> Tuple[Tensor, Tuple[Tensor, ...]]: + pooled_output, hidden_states = self.vision_model(pixel_values) + image_embeds = self.visual_projection(pooled_output) + + return image_embeds, hidden_states diff --git a/examples/opensora_hpcai/tools/caption/llava_next/llava/model/common_layer.py b/examples/opensora_hpcai/tools/caption/llava_next/llava/model/common_layer.py new file mode 100644 index 0000000000..950ad0c41b --- /dev/null +++ b/examples/opensora_hpcai/tools/caption/llava_next/llava/model/common_layer.py @@ -0,0 +1,112 @@ +"""Holding the layer where the parameter name and type is consistent with Pytorch""" +import numbers +from typing import List, Optional, Tuple, Union + +import mindspore as ms +import mindspore.nn as nn +import mindspore.ops as ops +from mindspore import Parameter, Tensor +from mindspore.common.initializer import 
Initializer, initializer + +__all__ = ["LayerNorm", "Embedding"] + + +class LayerNorm(nn.LayerNorm): + def __init__( + self, + normalized_shape: Union[Tuple[int], List[int]], + begin_norm_axis: int = -1, + begin_params_axis: int = -1, + gamma_init: Union[Tensor, str, Initializer, numbers.Number] = "ones", + beta_init: Union[Tensor, str, Initializer, numbers.Number] = "zeros", + epsilon: float = 1e-7, + dtype: ms.dtype = ms.float32, + ): + """Initialize LayerNorm.""" + super(nn.LayerNorm, self).__init__() + if not isinstance(normalized_shape, (tuple, list)): + raise TypeError( + f"For '{self.cls_name}', the type of 'normalized_shape' must be tuple[int] or list[int], " + f"but got {normalized_shape} and the type is {type(normalized_shape)}." + ) + if not normalized_shape: + raise ValueError( + f"Expected normalized_shape to be at least 1-dimensional, i.e., containing at " + f"least one element, but got normalized_shape = {normalized_shape}" + ) + self.normalized_shape = normalized_shape + self.begin_norm_axis = begin_norm_axis + self.begin_params_axis = begin_params_axis + self.epsilon = epsilon + self.weight = Parameter(initializer(gamma_init, normalized_shape, dtype=dtype)) + self.bias = Parameter(initializer(beta_init, normalized_shape, dtype=dtype)) + self.layer_norm = ops.LayerNorm( + begin_norm_axis=self.begin_norm_axis, begin_params_axis=self.begin_params_axis, epsilon=self.epsilon + ) + + def construct(self, input_x): + y, _, _ = self.layer_norm(input_x, self.weight.astype(input_x.dtype), self.bias.astype(input_x.dtype)) + return y + + def extend_repr(self): + return "normalized_shape={}, begin_norm_axis={}, begin_params_axis={}, gamma{}, beta={}".format( + self.normalized_shape, self.begin_norm_axis, self.begin_params_axis, self.weight, self.bias + ) + + +class Embedding(nn.Embedding): + def __init__( + self, + vocab_size: int, + embedding_size: int, + use_one_hot: bool = False, + embedding_table: Union[Tensor, str, Initializer, numbers.Number] = "normal", + dtype: ms.dtype = ms.float32, + padding_idx: Optional[int] = None, + ): + """Initialize Embedding.""" + super(nn.Embedding, self).__init__() + self.vocab_size = vocab_size + self.embedding_size = embedding_size + self.use_one_hot = use_one_hot + self.dtype = dtype + self.init_tensor = initializer(embedding_table, [vocab_size, embedding_size], dtype=dtype) + self.padding_idx = padding_idx + if padding_idx is not None: + self.padding_idx = padding_idx + if isinstance(self.init_tensor, Tensor) and self.init_tensor.init is not None: + self.init_tensor = self.init_tensor.init_data() + self.init_tensor = self.init_tensor.asnumpy() + self.init_tensor[self.padding_idx] = 0 + self.init_tensor = Tensor(self.init_tensor) + self.weight = Parameter(self.init_tensor) + self.expand = ops.ExpandDims() + self.reshape_flat = ops.Reshape() + self.shp_flat = (-1,) + self.gather = ops.Gather() + self.one_hot = ops.OneHot() + self.on_value = Tensor(1.0, self.dtype) + self.off_value = Tensor(0.0, self.dtype) + self.array_mul = ops.MatMul() + self.reshape = ops.Reshape() + self.get_shp = ops.Shape() + self.concat = ops.Concat() + + def construct(self, ids): + out_shape = self.get_shp(ids) + (self.embedding_size,) + flat_ids = self.reshape_flat(ids, self.shp_flat) + + if self.use_one_hot: + one_hot_ids = self.one_hot(flat_ids, self.vocab_size, self.on_value, self.off_value) + output_for_reshape = self.array_mul(one_hot_ids, self.weight) + else: + output_for_reshape = self.gather(self.weight, flat_ids, 0) + + output = self.reshape(output_for_reshape, 
out_shape) + return output + + def extend_repr(self): + return ( + f"vocab_size={self.vocab_size}, embedding_size={self.embedding_size}, use_one_hot={self.use_one_hot}, " + f"embedding_table={self.weight}, dtype={self.dtype}, padding_idx={self.padding_idx}" + ) diff --git a/examples/opensora_hpcai/tools/caption/llava_next/llava/model/llava_next/__init__.py b/examples/opensora_hpcai/tools/caption/llava_next/llava/model/llava_next/__init__.py new file mode 100644 index 0000000000..782f49447c --- /dev/null +++ b/examples/opensora_hpcai/tools/caption/llava_next/llava/model/llava_next/__init__.py @@ -0,0 +1 @@ +from .network import LlavaNextForConditionalGeneration diff --git a/examples/opensora_hpcai/tools/caption/llava_next/llava/model/llava_next/network.py b/examples/opensora_hpcai/tools/caption/llava_next/llava/model/llava_next/network.py new file mode 100644 index 0000000000..5ca25b8a7f --- /dev/null +++ b/examples/opensora_hpcai/tools/caption/llava_next/llava/model/llava_next/network.py @@ -0,0 +1,449 @@ +import math +from typing import Any, Dict, List, Literal, Optional, Tuple + +import mindspore as ms +import mindspore.nn as nn +import mindspore.ops as ops +from mindspore import Parameter, Tensor + +from ..activation import ACT2FN +from ..clip import CLIPVisionModel +from ..mistral import MistralForCausalLM +from ..padding import pad_along_axis +from .utils import get_anyres_image_grid_shape, image_size_to_num_patches, unpad_image + + +class LlavaNextMultiModalProjector(nn.Cell): + def __init__( + self, + vision_hidden_size: int = 1024, + text_hidden_size: int = 4096, + projector_hidden_act: str = "gelu", + dtype: ms.dtype = ms.float32, + ) -> None: + super().__init__() + + self.linear_1 = nn.Dense(vision_hidden_size, text_hidden_size, has_bias=True, dtype=dtype) + self.act = ACT2FN[projector_hidden_act] + self.linear_2 = nn.Dense(text_hidden_size, text_hidden_size, has_bias=True, dtype=dtype) + + def construct(self, image_features: Tensor) -> Tensor: + hidden_states = self.linear_1(image_features) + hidden_states = self.act(hidden_states) + hidden_states = self.linear_2(hidden_states) + return hidden_states + + +class LlavaNextForConditionalGeneration(nn.Cell): + def __init__( + self, + vision_config: Dict[str, Any], + text_config: Dict[str, Any], + image_grid_pinpoints: List[Tuple[int, int]], + projector_hidden_act: str = "gelu", + ignore_index: int = -100, + image_token_index: int = 32000, + vision_feature_select_strategy: str = "default", + vision_feature_layer: int = -2, + attn_implementation: Literal["eager", "flash_attention"] = "eager", + language_model_input_method: Literal["padding", "dynamic"] = "padding", + dtype: ms.dtype = ms.float32, + **kwargs: Any, + ) -> None: + super().__init__() + self.dtype = dtype + + self.vision_tower = CLIPVisionModel(**vision_config, dtype=dtype) + self.multi_modal_projector = LlavaNextMultiModalProjector( + vision_config["hidden_size"], + text_config["hidden_size"], + projector_hidden_act=projector_hidden_act, + dtype=dtype, + ) + embed_std = 1 / math.sqrt(text_config["hidden_size"]) + self.image_newline = Parameter(Tensor(ops.randn(text_config["hidden_size"]) * embed_std, dtype=dtype)) + + self.vocab_size = text_config["vocab_size"] + self.language_model = MistralForCausalLM(**text_config, attn_implementation=attn_implementation, dtype=dtype) + + self.text_config = text_config + self.vision_config = vision_config + self.ignore_index = ignore_index + self.image_token_index = image_token_index + self.image_grid_pinpoints = image_grid_pinpoints + 
self.vision_feature_select_strategy = vision_feature_select_strategy + self.vision_feature_layer = vision_feature_layer + self.pad_token_id = -1 + self._padding_side = "left" + + self.language_model_input_method = language_model_input_method + if self.language_model_input_method == "dynamic": + self._is_language_model_compiled = False + + def get_input_embeddings(self) -> nn.Cell: + return self.language_model.get_input_embeddings() + + def set_input_embeddings(self, value: nn.Cell) -> None: + self.language_model.set_input_embeddings(value) + + def get_output_embeddings(self) -> nn.Cell: + return self.language_model.get_output_embeddings() + + def set_output_embeddings(self, new_embeddings: nn.Cell) -> None: + self.language_model.set_output_embeddings(new_embeddings) + + def set_decoder(self, decoder: nn.Cell) -> None: + self.language_model.set_decoder(decoder) + + def get_decoder(self) -> nn.Cell: + return self.language_model.get_decoder() + + def _merge_input_ids_with_image_features( + self, + image_features: Tensor, + feature_lens: Tensor, + inputs_embeds: Tensor, + input_ids: Tensor, + attention_mask: Tensor, + ) -> Tuple[Tensor, Tensor, Tensor]: + # ! in llava 1.6, number of patches is variable + num_images = feature_lens.shape[0] + num_image_features, embed_dim = image_features.shape + if feature_lens.sum() != num_image_features: + raise ValueError(f"{feature_lens=} / {feature_lens.sum()} != {image_features.shape=}") + batch_size = input_ids.shape[0] + _left_padding = ops.any(attention_mask[:, 0] == 0) + _right_padding = ops.any(attention_mask[:, -1] == 0) + + left_padding = True + if batch_size > 1: + if _left_padding and not _right_padding: + left_padding = True + elif not _left_padding and _right_padding: + left_padding = False + elif not _left_padding and not _right_padding: + # both side is 1, so cannot tell + left_padding = self._padding_side == "left" + else: + # invalid attention_mask + raise ValueError(f"both side of attention_mask has zero, invalid. {attention_mask}") + + # Whether to turn off right padding + # 1. Create a mask to know where special image tokens are + special_image_token_mask = input_ids == self.image_token_index + # special_image_token_mask: [bsz, seqlen] + num_special_image_tokens = ops.sum(special_image_token_mask, dim=-1) + # num_special_image_tokens: [bsz] + # Reserve for padding of num_images + total_num_special_image_tokens = ops.sum(special_image_token_mask) + if total_num_special_image_tokens != num_images: + raise ValueError( + f"Number of image tokens in input_ids ({total_num_special_image_tokens}) different from num_images ({num_images})." + ) + # Compute the maximum embed dimension + # max_image_feature_lens is max_feature_lens per batch + feature_lens_batch = ops.split(feature_lens, num_special_image_tokens.tolist(), axis=0) + feature_lens_batch_sum = ops.stack([x.sum() for x in feature_lens_batch]) + embed_sequence_lengths = ( + (attention_mask == 1).to(ms.int32).sum(-1) - num_special_image_tokens + feature_lens_batch_sum + ) + max_embed_dim = embed_sequence_lengths.max().item() + + batch_indices, non_image_indices = ops.nonzero( + ops.logical_and(input_ids != self.image_token_index, attention_mask == 1) + ).unbind(dim=1) + # 2. Compute the positions where text should be written + # Calculate new positions for text tokens in merged image-text sequence. + # `special_image_token_mask` identifies image tokens. Each image token will be replaced by `nb_text_tokens_per_images` text tokens. 
+ # `torch.cumsum` computes how each image token shifts subsequent text token positions. + # - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one. + # ! instead of special_image_token_mask * (num_image_patches - 1) + # special_image_token_mask * (num_feature_len - 1) + special_image_token_mask = special_image_token_mask.to(ms.int32) + special_image_token_mask[special_image_token_mask == 1] = feature_lens - 1 + new_token_positions = ops.cumsum((special_image_token_mask + 1), -1) - 1 + if left_padding: + # shift right token positions so that they are ending at the same number + # the below here was incorrect? new_token_positions += new_token_positions[:, -1].max() - new_token_positions[:, -1:] + new_token_positions += max_embed_dim - 1 - new_token_positions[:, -1:] + + text_to_overwrite = new_token_positions[batch_indices, non_image_indices] + + # 3. Create the full embedding, already padded to the maximum position + final_embedding = ops.zeros((batch_size, max_embed_dim, embed_dim), dtype=inputs_embeds.dtype) + final_attention_mask = ops.zeros((batch_size, max_embed_dim), dtype=attention_mask.dtype) + + # 4. Fill the embeddings based on the mask. If we have ["hey" "", "how", "are"] + # we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the image features + final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_image_indices] + final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_image_indices] + + # 5. Fill the embeddings corresponding to the images. Anything that is not `text_positions` needs filling (#29835) + image_to_overwrite = ops.full((batch_size, max_embed_dim), True, dtype=ms.bool_) + image_to_overwrite[batch_indices, text_to_overwrite] = False + embed_indices = ops.arange(max_embed_dim).unsqueeze(0) + embed_indices = ops.broadcast_to(embed_indices, (batch_size, max_embed_dim)) + embed_seq_lens = embed_sequence_lengths[:, None] + + if left_padding: + # exclude padding on the left + val = (max_embed_dim - embed_indices) <= embed_seq_lens + else: + # exclude padding on the right + val = embed_indices < embed_seq_lens + image_to_overwrite = ops.logical_and(image_to_overwrite, val) + + if image_to_overwrite.sum() != num_image_features: + raise ValueError( + f"{image_to_overwrite.sum()=} != {num_image_features=} The input provided to the model are wrong. " + f"The number of image tokens is {ops.sum(special_image_token_mask)} while" + f" the number of image given to the model is {num_images}. " + f"This prevents correct indexing and breaks batch generation." 
+ ) + final_embedding[image_to_overwrite] = image_features.reshape(-1, embed_dim) + final_attention_mask |= image_to_overwrite + position_ids = ops.masked_fill( + ops.cumsum(final_attention_mask.to(ms.int32), -1) - 1, final_attention_mask == 0, Tensor(1, dtype=ms.int32) + ) + + return final_embedding, final_attention_mask, position_ids + + def pack_image_features( + self, image_features: Tensor, image_sizes: Tensor, image_newline: Optional[Tensor] = None + ) -> Tuple[Tensor, Tensor]: + new_image_features = [] + feature_lens = [] + for image_idx, image_feature in enumerate(image_features): + if image_feature.shape[0] > 1: + base_image_feature = image_feature[0] + image_feature = image_feature[1:] + height = width = self.vision_config["image_size"] // self.vision_config["patch_size"] + if height * width != base_image_feature.shape[0]: + raise ValueError("The number of patches is not consistent with the image size.") + num_patch_width, num_patch_height = get_anyres_image_grid_shape( + image_sizes[image_idx], + self.image_grid_pinpoints, + self.vision_config["image_size"], + ) + image_feature = ops.reshape(image_feature, (num_patch_height, num_patch_width, height, width, -1)) + image_feature = ops.transpose(image_feature, (4, 0, 2, 1, 3)) + image_feature = ops.flatten(image_feature, start_dim=1, end_dim=2) + image_feature = ops.flatten(image_feature, start_dim=2, end_dim=3) + image_feature = unpad_image(image_feature, image_sizes[image_idx]) + if image_newline is not None: + image_feature = ops.concat( + ( + image_feature, + ops.broadcast_to(image_newline[:, None, None], (*image_feature.shape[:-1], 1)).to( + image_feature.dtype + ), + ), + axis=-1, + ) + image_feature = ops.flatten(image_feature, start_dim=1, end_dim=2) + image_feature = ops.transpose(image_feature, (1, 0)) + image_feature = ops.concat((base_image_feature, image_feature), axis=0) + else: + image_feature = image_feature[0] + if image_newline is not None: + image_feature = ops.concat((image_feature, image_newline[None].to(image_feature)), axis=0) + new_image_features.append(image_feature) + feature_lens.append(image_feature.shape[0]) + image_features = ops.concat(new_image_features, axis=0) + feature_lens = Tensor(feature_lens, dtype=ms.int32) + return image_features, feature_lens + + def construct( + self, + input_ids: Optional[Tensor] = None, + pixel_values: Optional[Tensor] = None, + image_sizes: Optional[Tensor] = None, + attention_mask: Optional[Tensor] = None, + position_ids: Optional[Tensor] = None, + past_key_cache_list: Optional[Tensor] = None, + past_value_cache_list: Optional[Tensor] = None, + return_key_value_cache: bool = False, + ) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]: + # 1. Extract the input embeddings + # In case image_token_index is not in the embeddings (extra token but embedding don't have it) + for_inputs_embeds_ids = input_ids.copy() + for_inputs_embeds_ids[input_ids == self.image_token_index] = 0 + inputs_embeds = self.get_input_embeddings()(for_inputs_embeds_ids) + + # 2. Merge text and images + if pixel_values is not None and input_ids.shape[1] != 1 and pixel_values.size > 0: + # ! 
infer image_num_patches from image_sizes + image_num_patches = [ + image_size_to_num_patches( + image_size=imsize, + grid_pinpoints=self.image_grid_pinpoints, + patch_size=self.vision_config["image_size"], + ) + for imsize in image_sizes + ] + # figure out if pixel_values is concatenated or stacked + if len(pixel_values.shape) == 5: + # stacking when input is (batch_size, num_patches, num_channels, height, width) + _pixel_values_list = [ + pix_val[:num_patch] for pix_val, num_patch in zip(pixel_values, image_num_patches) + ] + pixel_values = ops.concat(_pixel_values_list, axis=0) + elif len(pixel_values.shape) != 4: + # otherwise has to be stacked from list of (num_patches, num_channels, height, width) + raise ValueError(f"pixel_values of shape {pixel_values.shape}, expect to be of 4 or 5 dimensions") + + _, hidden_states = self.vision_tower(pixel_values) + selected_image_feature = hidden_states[self.vision_feature_layer] + + if self.vision_feature_select_strategy == "default": + selected_image_feature = selected_image_feature[:, 1:] + elif self.vision_feature_select_strategy == "full": + selected_image_feature = selected_image_feature + + image_features = self.multi_modal_projector(selected_image_feature) + image_features = ops.split(image_features, image_num_patches, axis=0) + image_features, feature_lens = self.pack_image_features( + image_features, + image_sizes, + image_newline=self.image_newline, + ) + + inputs_embeds = inputs_embeds.to(image_features.dtype) + inputs_embeds, attention_mask, position_ids = self._merge_input_ids_with_image_features( + image_features, + feature_lens, + inputs_embeds, + input_ids, + attention_mask, + ) + + elif pixel_values is not None and input_ids.shape[1] != 1 and pixel_values.size == 0: + # there are no images + pass + + elif ( + past_key_cache_list is not None + and past_value_cache_list is not None + and pixel_values is not None + and input_ids.shape[1] == 1 + ): + # Retrieve the first layer to inspect the logits and mask out the hidden states + # that are set to 0 + first_layer_past_key_value = past_key_cache_list[0, :, :, :, 0] + + # Sum all dimensions of head_dim (-2) to avoid random errors such as: + # https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941 + batch_index, non_attended_tokens = ops.nonzero( + first_layer_past_key_value.to(ms.float32).sum(-2) == 0 + ).unbind(dim=1) + + # Get the target length + target_length = input_ids.shape[1] + past_length = first_layer_past_key_value.shape[-1] + + extended_attention_mask = ops.ones( + (attention_mask.shape[0], past_length), + dtype=attention_mask.dtype, + ) + + # Filter out only the tokens that can be un-attended, this can happen + # if one uses Llava + Fused modules where the cache on the + # first iteration is already big enough, or if one passes custom cache + valid_indices = non_attended_tokens < extended_attention_mask.shape[-1] + new_batch_index = batch_index[valid_indices] + new_non_attended_tokens = non_attended_tokens[valid_indices] + + # Zero-out the places where we don't need to attend + if len(new_batch_index) > 0 and len(new_non_attended_tokens) > 0: + extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0 + + attention_mask = ops.concat((extended_attention_mask, attention_mask[:, -target_length:]), axis=1) + + position_ids = ops.sum(attention_mask, dim=1).unsqueeze(-1) - 1 + + if self.language_model_input_method == "padding": + attention_mask = pad_along_axis(attention_mask, axis=-1) + past_key_cache_list = 
pad_along_axis(past_key_cache_list, axis=-2, shift=-1) + past_value_cache_list = pad_along_axis(past_value_cache_list, axis=-2, shift=-1) + + if self.language_model_input_method == "dynamic" and not self._is_language_model_compiled: + if past_key_cache_list is not None or past_value_cache_list is not None: + raise ValueError("Dynamic shape compling is not supported with KV caching yet.") + attention_mask_shape = list(attention_mask.shape) + position_ids_shape = list(position_ids.shape) + inputs_embeds_shape = list(inputs_embeds.shape) + + attention_mask_shape[-1] = None + position_ids_shape[-1] = None + inputs_embeds_shape[-2] = None + + self.language_model.set_inputs( + attention_mask=Tensor(shape=attention_mask_shape, dtype=attention_mask.dtype), + position_ids=Tensor(shape=position_ids_shape, dtype=position_ids.dtype), + inputs_embeds=Tensor(shape=inputs_embeds_shape, dtype=inputs_embeds.dtype), + ) + self._is_language_model_compiled = True + + logits, key_cache_list, value_cache_list = self.language_model( + attention_mask=attention_mask, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + past_key_cache_list=past_key_cache_list, + past_value_cache_list=past_value_cache_list, + return_key_value_cache=return_key_value_cache, + ) + + return logits, key_cache_list, value_cache_list + + def prepare_inputs_for_generation( + self, + input_ids: Tensor, + pixel_values: Optional[Tensor] = None, + image_sizes: Optional[Tensor] = None, + attention_mask: Optional[Tensor] = None, + past_key_cache_list: Optional[Tensor] = None, + past_value_cache_list: Optional[Tensor] = None, + return_key_value_cache: bool = False, + **kwargs, + ) -> Dict[str, Optional[Tensor]]: + if past_key_cache_list is not None and past_value_cache_list is not None: + past_length = past_value_cache_list.shape[-2] + + # Keep only the unprocessed tokens: + # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where + # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as + # input) + if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: + input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] + + # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard + # input_ids based on the past_length. + elif past_length < input_ids.shape[1]: + input_ids = input_ids[:, past_length:] + + # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. 
+ elif self.image_token_index in input_ids.asnumpy(): + input_ids = input_ids[:, input_ids.shape[1] - 1 :] + + position_ids = kwargs.get("position_ids", None) + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = ops.cumsum(attention_mask.to(ms.int32), -1) - 1 + position_ids = ops.masked_fill(position_ids, attention_mask == 0, Tensor(1, dtype=ms.int32)) + if past_key_cache_list is not None and past_value_cache_list is not None: + position_ids = position_ids[:, -input_ids.shape[1] :] + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + model_inputs = dict( + { + "input_ids": input_ids, + "position_ids": position_ids, + "attention_mask": attention_mask, + "pixel_values": pixel_values.to(self.dtype), + "image_sizes": image_sizes, + "past_key_cache_list": past_key_cache_list, + "past_value_cache_list": past_value_cache_list, + "return_key_value_cache": return_key_value_cache, + } + ) + return model_inputs diff --git a/examples/opensora_hpcai/tools/caption/llava_next/llava/model/llava_next/utils.py b/examples/opensora_hpcai/tools/caption/llava_next/llava/model/llava_next/utils.py new file mode 100644 index 0000000000..959d74db1b --- /dev/null +++ b/examples/opensora_hpcai/tools/caption/llava_next/llava/model/llava_next/utils.py @@ -0,0 +1,95 @@ +import math +from typing import List, Tuple, Union + +import numpy as np + +from mindspore import Tensor + + +def get_anyres_image_grid_shape( + image_size: Union[Tuple[int, int], List[int], Tensor, np.ndarray], + grid_pinpoints: List[Tuple[int, int]], + patch_size: int, +) -> Tuple[int, int]: + if not isinstance(grid_pinpoints, list): + raise ValueError("grid_pinpoints should be a list of tuples or lists") + + # ! VERY IMPORTANT if image_size is tensor, must convert to into tuple, otherwise it will cause wrong calculate + if not isinstance(image_size, (list, tuple)): + if not isinstance(image_size, (Tensor, np.ndarray)): + raise ValueError( + f"image_size invalid type: {type(image_size)} not valid, should be either list, tuple, np.ndarray or tensor" + ) + image_size = image_size.tolist() + + height, width = select_best_resolution(image_size, grid_pinpoints) + return height // patch_size, width // patch_size + + +def image_size_to_num_patches( + image_size: Union[Tuple[int, int], List[int], Tensor, np.ndarray], + grid_pinpoints: List[Tuple[int, int]], + patch_size: int, +) -> int: + if not isinstance(grid_pinpoints, list): + raise ValueError("grid_pinpoints should be a list of tuples or lists") + + # ! 
VERY IMPORTANT if image_size is tensor, must convert to into tuple, otherwise it will cause wrong calculate + if not isinstance(image_size, (list, tuple)): + if not isinstance(image_size, (Tensor, np.ndarray)): + raise ValueError(f"image_size invalid type {type(image_size)} with value {image_size}") + image_size = image_size.tolist() + + best_resolution = select_best_resolution(image_size, grid_pinpoints) + height, width = best_resolution + num_patches = math.ceil(height / patch_size) * math.ceil(width / patch_size) + 1 + return num_patches + + +def unpad_image(tensor: Tensor, original_size: Union[Tuple[int, int], Tensor]) -> Tensor: + original_height, original_width = original_size + current_height, current_width = tensor.shape[1:] + + # HACK: need to be int here, since for ms, ms.int / ms.int -> ms.int + if isinstance(original_size, Tensor): + original_height, original_width = original_height.item(), original_width.item() + + original_aspect_ratio = original_width / original_height + current_aspect_ratio = current_width / current_height + + if original_aspect_ratio > current_aspect_ratio: + scale_factor = current_width / original_width + new_height = int(original_height * scale_factor) + padding = (current_height - new_height) // 2 + unpadded_tensor = tensor[:, padding : current_height - padding, :] + else: + scale_factor = current_height / original_height + new_width = int(original_width * scale_factor) + padding = (current_width - new_width) // 2 + unpadded_tensor = tensor[:, :, padding : current_width - padding] + + return unpadded_tensor + + +def select_best_resolution( + original_size: Tuple[int, int], possible_resolutions: List[Tuple[int, int]] +) -> Tuple[int, int]: + original_height, original_width = original_size + best_fit = None + max_effective_resolution = 0 + min_wasted_resolution = float("inf") + + for height, width in possible_resolutions: + scale = min(width / original_width, height / original_height) + downscaled_width, downscaled_height = int(original_width * scale), int(original_height * scale) + effective_resolution = min(downscaled_width * downscaled_height, original_width * original_height) + wasted_resolution = (width * height) - effective_resolution + + if effective_resolution > max_effective_resolution or ( + effective_resolution == max_effective_resolution and wasted_resolution < min_wasted_resolution + ): + max_effective_resolution = effective_resolution + min_wasted_resolution = wasted_resolution + best_fit = (height, width) + + return best_fit diff --git a/examples/opensora_hpcai/tools/caption/llava_next/llava/model/mistral/__init__.py b/examples/opensora_hpcai/tools/caption/llava_next/llava/model/mistral/__init__.py new file mode 100644 index 0000000000..094ce7f2cf --- /dev/null +++ b/examples/opensora_hpcai/tools/caption/llava_next/llava/model/mistral/__init__.py @@ -0,0 +1 @@ +from .network import MistralForCausalLM diff --git a/examples/opensora_hpcai/tools/caption/llava_next/llava/model/mistral/layer.py b/examples/opensora_hpcai/tools/caption/llava_next/llava/model/mistral/layer.py new file mode 100644 index 0000000000..6b46cdadaf --- /dev/null +++ b/examples/opensora_hpcai/tools/caption/llava_next/llava/model/mistral/layer.py @@ -0,0 +1,288 @@ +import logging +from typing import Optional, Tuple + +import mindspore as ms +import mindspore.nn as nn +import mindspore.ops as ops +from mindspore import Parameter, Tensor +from mindspore.ops.operations.nn_ops import FlashAttentionScore + +from ..activation import ACT2FN + +logger = logging.getLogger(__name__) 
+ + +class MistralRMSNorm(nn.Cell): + def __init__(self, hidden_size: int, eps: float = 1e-6, dtype: ms.dtype = ms.float32) -> None: + super().__init__() + self.weight = Parameter(ops.ones(hidden_size, dtype=dtype)) + self.variance_epsilon = eps + + def construct(self, hidden_states: Tensor) -> Tensor: + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(ms.float32) + variance = ops.pow(hidden_states, 2) + variance = ops.mean(variance, axis=-1, keep_dims=True) + hidden_states = hidden_states * ops.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + +class MistralRotaryEmbedding(nn.Cell): + def __init__(self, dim: int, max_position_embeddings: int = 2048, base: float = 10000.0) -> None: + super().__init__() + + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + self.inv_freq = 1.0 / (self.base ** (ops.arange(0, self.dim, 2, dtype=ms.float32) / self.dim)) + + def construct(self, x: Tensor, position_ids: Tensor) -> Tuple[Tensor, Tensor]: + # x: [bs, num_attention_heads, seq_len, head_size] + inv_freq_expanded = ops.broadcast_to(self.inv_freq[None, :, None], (position_ids.shape[0], -1, 1)) + position_ids_expanded = position_ids[:, None, :].to(ms.float32) + # Force float32 since bfloat16 loses precision on long contexts + # See https://github.com/huggingface/transformers/pull/29285 + freqs = ops.matmul(inv_freq_expanded.to(ms.float32), position_ids_expanded.to(ms.float32)) + freqs = ops.transpose(freqs, (0, 2, 1)) + emb = ops.concat((freqs, freqs), axis=-1) + cos = ops.cos(emb) + sin = ops.sin(emb) + return cos.to(x.dtype), sin.to(x.dtype) + + +def rotate_half(x: Tensor) -> Tensor: + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return ops.concat((-x2, x1), axis=-1) + + +def apply_rotary_pos_emb( + q: Tensor, k: Tensor, cos: Tensor, sin: Tensor, unsqueeze_dim: int = 1 +) -> Tuple[Tensor, Tensor]: + cos = ops.unsqueeze(cos, unsqueeze_dim) + sin = ops.unsqueeze(sin, unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +class MistralMLP(nn.Cell): + def __init__( + self, + intermediate_size: int = 14336, + hidden_size: int = 4096, + hidden_act: str = "silu", + dtype: ms.dtype = ms.float32, + ) -> None: + super().__init__() + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.gate_proj = nn.Dense(self.hidden_size, self.intermediate_size, has_bias=False, dtype=dtype) + self.up_proj = nn.Dense(self.hidden_size, self.intermediate_size, has_bias=False, dtype=dtype) + self.down_proj = nn.Dense(self.intermediate_size, self.hidden_size, has_bias=False, dtype=dtype) + self.act_fn = ACT2FN[hidden_act] + + def construct(self, hidden_state: Tensor) -> Tensor: + return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state)) + + +def repeat_kv(hidden_states: Tensor, n_rep: int) -> Tensor: + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :] + hidden_states = ops.broadcast_to(hidden_states, (batch, num_key_value_heads, n_rep, slen, head_dim)) + hidden_states = ops.reshape(hidden_states, (batch, num_key_value_heads * n_rep, slen, head_dim)) + return hidden_states + + +class MistralAttention(nn.Cell): + def __init__( + self, + hidden_size: int = 4096, + num_attention_heads: int = 32, + 
num_key_value_heads: int = 8, + max_position_embeddings: int = 32768, + rope_theta: float = 1000000.0, + attention_dropout: float = 0.0, + dtype: ms.dtype = ms.float32, + ) -> None: + super().__init__() + + self.attention_dropout = attention_dropout + self.hidden_size = hidden_size + self.num_heads = num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = max_position_embeddings + self.rope_theta = rope_theta + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." + ) + self.q_proj = nn.Dense(self.hidden_size, self.num_heads * self.head_dim, has_bias=False, dtype=dtype) + self.k_proj = nn.Dense(self.hidden_size, self.num_key_value_heads * self.head_dim, has_bias=False, dtype=dtype) + self.v_proj = nn.Dense(self.hidden_size, self.num_key_value_heads * self.head_dim, has_bias=False, dtype=dtype) + self.o_proj = nn.Dense(self.num_heads * self.head_dim, self.hidden_size, has_bias=False, dtype=dtype) + + self.rotary_emb = MistralRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + ) + + def construct( + self, + hidden_states: Tensor, + attention_mask: Optional[Tensor] = None, + position_ids: Optional[Tensor] = None, + past_key_cache: Optional[Tensor] = None, + past_value_cache: Optional[Tensor] = None, + return_key_value_cache: bool = False, + ) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]: + bsz, q_len, _ = hidden_states.shape + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = ops.reshape(query_states, (bsz, q_len, self.num_heads, self.head_dim)) + query_states = ops.transpose(query_states, (0, 2, 1, 3)) + + key_states = ops.reshape(key_states, (bsz, q_len, self.num_key_value_heads, self.head_dim)) + key_states = ops.transpose(key_states, (0, 2, 1, 3)) + + value_states = ops.reshape(value_states, (bsz, q_len, self.num_key_value_heads, self.head_dim)) + value_states = ops.transpose(value_states, (0, 2, 1, 3)) + + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if return_key_value_cache: + key_cache, value_cache = key_states, value_states + else: + key_cache, value_cache = None, None + + if past_key_cache is not None and past_value_cache is not None: + key_states = ops.concat([past_key_cache, key_states], axis=-2) + value_states = ops.concat([past_value_cache, value_states], axis=-2) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + key_states = ops.transpose(key_states, (0, 1, 3, 2)) + attn_weights = ops.matmul(query_states, key_states) / ms.numpy.sqrt(self.head_dim) + + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + # upcast attention to fp32 + attn_weights = ops.softmax(attn_weights.to(ms.float32), axis=-1).to(query_states.dtype) + attn_weights = ops.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = ops.matmul(attn_weights, value_states) + + attn_output = ops.transpose(attn_output, (0, 2, 1, 3)) + attn_output = ops.reshape(attn_output, (bsz, q_len, 
-1)) + attn_output = self.o_proj(attn_output) + + return attn_output, key_cache, value_cache + + +class MistralFlashAttention(nn.Cell): + def __init__( + self, + hidden_size: int = 4096, + num_attention_heads: int = 32, + num_key_value_heads: int = 8, + max_position_embeddings: int = 32768, + rope_theta: float = 1000000.0, + attention_dropout: float = 0.0, + dtype: ms.dtype = ms.float32, + ) -> None: + super().__init__() + + self.attention_dropout = attention_dropout + self.hidden_size = hidden_size + self.num_heads = num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = max_position_embeddings + self.rope_theta = rope_theta + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." + ) + self.q_proj = nn.Dense(self.hidden_size, self.num_heads * self.head_dim, has_bias=False, dtype=dtype) + self.k_proj = nn.Dense(self.hidden_size, self.num_key_value_heads * self.head_dim, has_bias=False, dtype=dtype) + self.v_proj = nn.Dense(self.hidden_size, self.num_key_value_heads * self.head_dim, has_bias=False, dtype=dtype) + self.o_proj = nn.Dense(self.num_heads * self.head_dim, self.hidden_size, has_bias=False, dtype=dtype) + + self.rotary_emb = MistralRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + ) + + self.flash_attention = FlashAttentionScore( + self.num_heads, keep_prob=1 - self.attention_dropout, scale_value=self.head_dim**-0.5, input_layout="BSND" + ) + + def construct( + self, + hidden_states: Tensor, + attention_mask: Optional[Tensor] = None, + position_ids: Optional[Tensor] = None, + past_key_cache: Optional[Tensor] = None, + past_value_cache: Optional[Tensor] = None, + return_key_value_cache: bool = False, + ) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]: + bsz, q_len, _ = hidden_states.shape + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = ops.reshape(query_states, (bsz, q_len, self.num_heads, self.head_dim)) + query_states = ops.transpose(query_states, (0, 2, 1, 3)) + + key_states = ops.reshape(key_states, (bsz, q_len, self.num_key_value_heads, self.head_dim)) + key_states = ops.transpose(key_states, (0, 2, 1, 3)) + + value_states = ops.reshape(value_states, (bsz, q_len, self.num_key_value_heads, self.head_dim)) + value_states = ops.transpose(value_states, (0, 2, 1, 3)) + + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if return_key_value_cache: + key_cache, value_cache = key_states, value_states + else: + key_cache, value_cache = None, None + + if past_key_cache is not None and past_value_cache is not None: + key_states = ops.concat([key_states, past_key_cache], axis=-2) + value_states = ops.concat([value_states, past_value_cache], axis=-2) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + # Reshape to the expected shape and dtype for Flash Attention + query_states = ops.transpose(query_states, (0, 2, 1, 3)) + key_states = ops.transpose(key_states, (0, 2, 1, 3)) + value_states = ops.transpose(value_states, 
(0, 2, 1, 3)) + attention_mask = attention_mask.to(ms.uint8) + + _, _, _, attn_output = self.flash_attention( + query_states, key_states, value_states, None, None, None, attention_mask + ) + attn_output = ops.reshape(attn_output, (bsz, q_len, -1)) + attn_output = self.o_proj(attn_output) + + return attn_output, key_cache, value_cache diff --git a/examples/opensora_hpcai/tools/caption/llava_next/llava/model/mistral/network.py b/examples/opensora_hpcai/tools/caption/llava_next/llava/model/mistral/network.py new file mode 100644 index 0000000000..108ccb6cc4 --- /dev/null +++ b/examples/opensora_hpcai/tools/caption/llava_next/llava/model/mistral/network.py @@ -0,0 +1,295 @@ +from typing import Literal, Optional, Tuple + +import mindspore as ms +import mindspore.nn as nn +import mindspore.ops as ops +from mindspore import Tensor + +from ..common_layer import Embedding +from .layer import MistralAttention, MistralFlashAttention, MistralMLP, MistralRMSNorm + +MISTRAL_ATTENTION_CLASSES = { + "eager": MistralAttention, + "flash_attention": MistralFlashAttention, +} + + +class MistralDecoderLayer(nn.Cell): + def __init__( + self, + hidden_size: int = 4096, + intermediate_size: int = 14336, + num_attention_heads: int = 32, + num_key_value_heads: int = 8, + rms_norm_eps: float = 1e-5, + max_position_embeddings: int = 32768, + rope_theta: float = 1000000.0, + attention_dropout: float = 0.0, + hidden_act: str = "silu", + attn_implementation: Literal["eager", "flash_attention"] = "eager", + dtype: ms.dtype = ms.float32, + ) -> None: + super().__init__() + self.hidden_size = hidden_size + + self.self_attn = MISTRAL_ATTENTION_CLASSES[attn_implementation]( + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + num_key_value_heads=num_key_value_heads, + max_position_embeddings=max_position_embeddings, + rope_theta=rope_theta, + attention_dropout=attention_dropout, + dtype=dtype, + ) + + self.mlp = MistralMLP( + intermediate_size=intermediate_size, hidden_size=hidden_size, hidden_act=hidden_act, dtype=dtype + ) + self.input_layernorm = MistralRMSNorm(hidden_size, eps=rms_norm_eps, dtype=dtype) + self.post_attention_layernorm = MistralRMSNorm(hidden_size, eps=rms_norm_eps, dtype=dtype) + + def construct( + self, + hidden_states: Tensor, + attention_mask: Optional[Tensor] = None, + position_ids: Optional[Tensor] = None, + past_key_cache: Optional[Tensor] = None, + past_value_cache: Optional[Tensor] = None, + return_key_value_cache: bool = False, + ) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]: + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, key_cache, value_cache = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_cache=past_key_cache, + past_value_cache=past_value_cache, + return_key_value_cache=return_key_value_cache, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states, key_cache, value_cache + + +class MistralModel(nn.Cell): + def __init__( + self, + hidden_size: int = 4096, + intermediate_size: int = 14336, + max_position_embeddings: int = 32768, + num_attention_heads: int = 32, + num_hidden_layers: int = 32, + num_key_value_heads: int = 8, + rms_norm_eps: float = 1e-5, + rope_theta: float = 1000000.0, + vocab_size: int = 32064, 
+ attention_dropout: float = 0.0, + hidden_act: str = "silu", + pad_token_id: Optional[int] = None, + attn_implementation: Literal["eager", "flash_attention"] = "eager", + dtype: ms.dtype = ms.float32, + ) -> None: + super().__init__() + self.padding_idx = pad_token_id + self.vocab_size = vocab_size + self.attn_implementation = attn_implementation + + self.embed_tokens = Embedding(vocab_size, hidden_size, padding_idx=self.padding_idx, dtype=dtype) + self.layers = nn.CellList( + [ + MistralDecoderLayer( + hidden_size=hidden_size, + intermediate_size=intermediate_size, + num_attention_heads=num_attention_heads, + num_key_value_heads=num_key_value_heads, + rms_norm_eps=rms_norm_eps, + max_position_embeddings=max_position_embeddings, + rope_theta=rope_theta, + attention_dropout=attention_dropout, + hidden_act=hidden_act, + attn_implementation=attn_implementation, + dtype=dtype, + ) + for _ in range(num_hidden_layers) + ] + ) + self.norm = MistralRMSNorm(hidden_size, eps=rms_norm_eps, dtype=dtype) + + def get_input_embeddings(self) -> nn.Cell: + return self.embed_tokens + + def set_input_embeddings(self, value: nn.Cell) -> None: + self.embed_tokens = value + + def construct( + self, + input_ids: Optional[Tensor] = None, + attention_mask: Optional[Tensor] = None, + position_ids: Optional[Tensor] = None, + inputs_embeds: Optional[Tensor] = None, + past_key_cache_list: Optional[Tensor] = None, + past_value_cache_list: Optional[Tensor] = None, + return_key_value_cache: bool = False, + ) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]: + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + past_seen_tokens = past_key_cache_list.shape[-2] if past_key_cache_list is not None else 0 + cache_position = ops.arange(past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], dtype=ms.int32) + if position_ids is None: + position_ids = ops.unsqueeze(cache_position, 0) + + causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position) + + hidden_states = inputs_embeds + + if return_key_value_cache: + key_cache_list, value_cache_list = [], [] + else: + key_cache_list, value_cache_list = None, None + + for i, decoder_layer in enumerate(self.layers): + if past_key_cache_list is not None and past_value_cache_list is not None: + past_key_cache, past_value_cache = past_key_cache_list[i], past_value_cache_list[i] + else: + past_key_cache, past_value_cache = None, None + + hidden_states, key_cache, value_cache = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_cache=past_key_cache, + past_value_cache=past_value_cache, + return_key_value_cache=return_key_value_cache, + ) + + if return_key_value_cache: + key_cache_list.append(key_cache) + value_cache_list.append(value_cache) + + hidden_states = self.norm(hidden_states) + + if return_key_value_cache: + key_cache_list = ops.stack(key_cache_list) + value_cache_list = ops.stack(value_cache_list) + + return hidden_states, key_cache_list, value_cache_list + + def _update_causal_mask(self, attention_mask: Tensor, input_tensor: Tensor, cache_position: Tensor) -> Tensor: + dtype = input_tensor.dtype + sequence_length = input_tensor.shape[1] + target_length = attention_mask.shape[-1] + + if len(attention_mask.shape) == 4: + # in this case we assume that the mask comes already in inverted form and requires no inversion or slicing + if attention_mask.max() != 0: + raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`") + causal_mask = 
attention_mask + else: + fill_value = -ms.numpy.inf if self.attn_implementation == "eager" else 1.0 + causal_mask = ops.full((sequence_length, target_length), fill_value=fill_value, dtype=dtype) + exclude_mask = ops.arange(target_length) > cache_position.reshape(-1, 1) + causal_mask = ops.masked_fill(causal_mask, ~exclude_mask, Tensor(0, dtype=dtype)) + causal_mask = ops.broadcast_to(causal_mask[None, None, :, :], (input_tensor.shape[0], 1, -1, -1)) + if len(attention_mask.shape) == 2: + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = ops.masked_fill( + causal_mask[:, :, :, :mask_length], padding_mask, Tensor(fill_value, dtype=dtype) + ) + + return causal_mask + + +class MistralForCausalLM(nn.Cell): + def __init__( + self, + hidden_size: int = 4096, + intermediate_size: int = 14336, + max_position_embeddings: int = 32768, + num_attention_heads: int = 32, + num_hidden_layers: int = 32, + num_key_value_heads: int = 8, + rms_norm_eps: float = 1e-5, + rope_theta: float = 1000000.0, + vocab_size: int = 32000, + attention_dropout: float = 0.0, + hidden_act: str = "silu", + pad_token_id: Optional[int] = None, + attn_implementation: Literal["eager", "flash_attention"] = "eager", + dtype: ms.dtype = ms.float32, + **kwargs, + ) -> None: + super().__init__() + + self.model = MistralModel( + hidden_size=hidden_size, + intermediate_size=intermediate_size, + max_position_embeddings=max_position_embeddings, + num_attention_heads=num_attention_heads, + num_hidden_layers=num_hidden_layers, + num_key_value_heads=num_key_value_heads, + rms_norm_eps=rms_norm_eps, + rope_theta=rope_theta, + vocab_size=vocab_size, + attention_dropout=attention_dropout, + hidden_act=hidden_act, + pad_token_id=pad_token_id, + attn_implementation=attn_implementation, + dtype=dtype, + ) + self.vocab_size = vocab_size + self.lm_head = nn.Dense(hidden_size, vocab_size, has_bias=False, dtype=dtype) + + def get_input_embeddings(self) -> nn.Cell: + return self.model.embed_tokens + + def set_input_embeddings(self, value: nn.Cell) -> None: + self.model.embed_tokens = value + + def get_output_embeddings(self) -> nn.Cell: + return self.lm_head + + def set_output_embeddings(self, new_embeddings: nn.Cell) -> None: + self.lm_head = new_embeddings + + def set_decoder(self, decoder: nn.Cell) -> None: + self.model = decoder + + def get_decoder(self) -> nn.Cell: + return self.model + + @ms.jit + def construct( + self, + input_ids: Optional[Tensor] = None, + attention_mask: Optional[Tensor] = None, + position_ids: Optional[Tensor] = None, + inputs_embeds: Optional[Tensor] = None, + past_key_cache_list: Optional[Tensor] = None, + past_value_cache_list: Optional[Tensor] = None, + return_key_value_cache: bool = False, + ) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]: + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + hidden_states, key_cache_list, value_cache_list = self.model( + input_ids=input_ids, + position_ids=position_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + past_key_cache_list=past_key_cache_list, + past_value_cache_list=past_value_cache_list, + return_key_value_cache=return_key_value_cache, + ) + logits = self.lm_head(hidden_states).to(ms.float32) + return logits, key_cache_list, value_cache_list diff --git a/examples/opensora_hpcai/tools/caption/llava_next/llava/model/padding.py 
b/examples/opensora_hpcai/tools/caption/llava_next/llava/model/padding.py new file mode 100644 index 0000000000..367db56352 --- /dev/null +++ b/examples/opensora_hpcai/tools/caption/llava_next/llava/model/padding.py @@ -0,0 +1,30 @@ +import math +from numbers import Number +from typing import Optional + +import mindspore.ops as ops +from mindspore import Tensor + + +def pad_along_axis( + x: Tensor, + value: Optional[Number] = None, + multiplier: int = 512, + axis: int = -1, + shift: int = 0, + padding_direction: str = "right", +) -> Tensor: + if axis >= 0: + raise ValueError("Input `axis` must be a negative number.") + + shape = x.shape + max_value = math.ceil(shape[axis] / multiplier) * multiplier + pad_num = max(max_value - shape[axis] + shift, 0) + + if pad_num == 0: + return x + + padding = (0, pad_num) if padding_direction == "right" else (pad_num, 0) + padding = (-axis - 1) * (0, 0) + padding + x = ops.pad(x, padding=padding, value=value) + return x diff --git a/examples/opensora_hpcai/tools/caption/llava_next/llava/pipeline/__init__.py b/examples/opensora_hpcai/tools/caption/llava_next/llava/pipeline/__init__.py new file mode 100644 index 0000000000..bfa5257cff --- /dev/null +++ b/examples/opensora_hpcai/tools/caption/llava_next/llava/pipeline/__init__.py @@ -0,0 +1 @@ +from .text_generation import TextGenerator diff --git a/examples/opensora_hpcai/tools/caption/llava_next/llava/pipeline/helpers/__init__.py b/examples/opensora_hpcai/tools/caption/llava_next/llava/pipeline/helpers/__init__.py new file mode 100644 index 0000000000..a1d0001774 --- /dev/null +++ b/examples/opensora_hpcai/tools/caption/llava_next/llava/pipeline/helpers/__init__.py @@ -0,0 +1 @@ +from .stopping_criteria import * diff --git a/examples/opensora_hpcai/tools/caption/llava_next/llava/pipeline/helpers/stopping_criteria.py b/examples/opensora_hpcai/tools/caption/llava_next/llava/pipeline/helpers/stopping_criteria.py new file mode 100644 index 0000000000..9103b252a5 --- /dev/null +++ b/examples/opensora_hpcai/tools/caption/llava_next/llava/pipeline/helpers/stopping_criteria.py @@ -0,0 +1,56 @@ +import abc +import logging +from typing import List, Optional, Union + +import mindspore as ms +import mindspore.ops as ops +from mindspore import Tensor + +logger = logging.getLogger(__name__) + + +__all__ = ["MaxLengthCriteria", "EosTokenCriteria", "StoppingCriteriaList"] + + +class StoppingCriteria(abc.ABC): + @abc.abstractmethod + def __call__(self, input_ids: Tensor) -> Tensor: + raise NotImplementedError("StoppingCriteria needs to be subclassed") + + +class MaxLengthCriteria(StoppingCriteria): + def __init__(self, max_length: int) -> None: + self.max_length = max_length + + def __call__(self, input_ids: Tensor) -> Tensor: + cur_len = input_ids.shape[-1] + is_done = cur_len >= self.max_length + return ops.full((input_ids.shape[0],), is_done, dtype=ms.bool_) + + +class EosTokenCriteria(StoppingCriteria): + def __init__(self, eos_token_id: Union[int, List[int], Tensor]) -> None: + if not isinstance(eos_token_id, Tensor): + if isinstance(eos_token_id, int): + eos_token_id = [eos_token_id] + eos_token_id = Tensor(eos_token_id) + self.eos_token_id = eos_token_id + + def __call__(self, input_ids: Tensor) -> Tensor: + is_done = ms.numpy.isin(input_ids[:, -1], self.eos_token_id) + return is_done + + +class StoppingCriteriaList(list): + def __call__(self, input_ids: Tensor) -> Tensor: + is_done = ops.full((input_ids.shape[0],), False, dtype=ms.bool_) + for criteria in self: + is_done = ops.logical_or(is_done, 
criteria(input_ids)) + return is_done + + @property + def max_length(self) -> Optional[int]: + for stopping_criterium in self: + if isinstance(stopping_criterium, MaxLengthCriteria): + return stopping_criterium.max_length + return None diff --git a/examples/opensora_hpcai/tools/caption/llava_next/llava/pipeline/text_generation.py b/examples/opensora_hpcai/tools/caption/llava_next/llava/pipeline/text_generation.py new file mode 100644 index 0000000000..5003c62d53 --- /dev/null +++ b/examples/opensora_hpcai/tools/caption/llava_next/llava/pipeline/text_generation.py @@ -0,0 +1,275 @@ +import logging +from typing import Dict, Optional, Tuple + +import mindspore as ms +import mindspore.nn as nn +import mindspore.ops as ops +from mindspore import Tensor + +from .helpers import EosTokenCriteria, MaxLengthCriteria, StoppingCriteriaList + +logger = logging.getLogger(__name__) + + +class TextGenerator: + def __init__( + self, + model: nn.Cell, + bos_token_id: int = 1, + eos_token_id: int = 2, + pad_token_id: Optional[int] = None, + max_new_tokens: Optional[int] = 100, + min_new_tokens: Optional[int] = None, + use_kv_cache: bool = False, + ) -> None: + self.model = model.set_train(False) + for param in self.model.trainable_params(): + param.requires_grad = False + + self._bos_token_id = bos_token_id + self._eos_token_id = eos_token_id + self._pad_token_id = pad_token_id + self._max_new_tokens = max_new_tokens + self._min_new_tokens = min_new_tokens + self._use_kv_cache = use_kv_cache + + self._max_length: Optional[int] = None + self._min_length: Optional[int] = None + + if not hasattr(self.model, "prepare_inputs_for_generation"): + raise NotImplementedError( + "A model class needs to define a `prepare_inputs_for_generation` method in order to use `.generate()`." + ) + + if self._use_kv_cache: + self._past_key_cache_list: Optional[Tensor] = None + self._past_value_cache_list: Optional[Tensor] = None + + def _prepare_model_inputs( + self, bos_token_id: Optional[Tensor] = None, model_kwargs: Optional[Dict[str, Tensor]] = None + ) -> Tuple[Tensor, Dict[str, Tensor]]: + input_name = "input_ids" # support inputs id only + model_kwargs = {k: v for k, v in model_kwargs.items() if v is not None or k != input_name} + + inputs = model_kwargs.pop(input_name, None) + + # if `inputs` is still None, try to create `input_ids` from BOS token + inputs = self._maybe_initialize_input_ids_for_generation(inputs, bos_token_id, model_kwargs) + return inputs, model_kwargs + + def _maybe_initialize_input_ids_for_generation( + self, + inputs: Optional[Tensor] = None, + bos_token_id: Optional[Tensor] = None, + model_kwargs: Optional[Dict[str, Tensor]] = None, + ) -> Tensor: + """Initializes input ids for generation, if necessary.""" + if inputs is not None: + return inputs + + # If there is some tensor in `model_kwargs`, we can infer the batch size from it. This is helpful with + # soft-prompting or in multimodal implementations built on top of decoder-only language models. 
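+ # Fall back to a batch size of 1 when no tensor-valued kwarg is available to infer it from.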
+ batch_size = 1 + for value in model_kwargs.values(): + if isinstance(value, Tensor): + batch_size = value.shape[0] + break + + if bos_token_id is None: + raise ValueError("`bos_token_id` has to be defined when no `input_ids` are provided.") + + return ops.ones((batch_size, 1), dtype=ms.int32) * bos_token_id + + def _prepare_attention_mask_for_generation( + self, inputs: Tensor, pad_token_id: Optional[Tensor], eos_token_id: Optional[Tensor] + ) -> Tensor: + # No information for attention mask inference -> return default attention mask + default_attention_mask = ops.ones(inputs.shape[:2], dtype=ms.int32) + if pad_token_id is None: + return default_attention_mask + + is_input_ids = len(inputs.shape) == 2 and inputs.dtype in [ms.int32, ms.int64] + if not is_input_ids: + return default_attention_mask + + is_pad_token_in_inputs = (pad_token_id is not None) and ( + ms.numpy.isin(element=inputs, test_elements=pad_token_id).any() + ) + is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or ~( + ms.numpy.isin(element=eos_token_id, test_elements=pad_token_id).any() + ) + can_infer_attention_mask = is_pad_token_in_inputs * is_pad_token_not_equal_to_eos_token_id + attention_mask_from_padding = inputs.ne(pad_token_id).to(ms.int32) + + attention_mask = ( + attention_mask_from_padding * can_infer_attention_mask + default_attention_mask * ~can_infer_attention_mask + ) + return attention_mask + + def _update_model_kwargs_for_generation( + self, + model_kwargs: Dict[str, Tensor], + key_cache_list: Optional[Tensor] = None, + value_cache_list: Optional[Tensor] = None, + ) -> Dict[str, Tensor]: + # update attention mask + if "attention_mask" in model_kwargs: + attention_mask = model_kwargs["attention_mask"] + model_kwargs["attention_mask"] = ops.concat( + [attention_mask, ops.ones((attention_mask.shape[0], 1), dtype=attention_mask.dtype)], axis=-1 + ) + + # update kv cache + if key_cache_list is not None and value_cache_list is not None: + if self._past_key_cache_list is not None and self._past_value_cache_list is not None: + self._past_key_cache_list = ops.concat([self._past_key_cache_list, key_cache_list], axis=-2) + self._past_value_cache_list = ops.concat([self._past_value_cache_list, value_cache_list], axis=-2) + else: + self._past_key_cache_list = key_cache_list + self._past_value_cache_list = value_cache_list + + model_kwargs["past_key_cache_list"] = self._past_key_cache_list + model_kwargs["past_value_cache_list"] = self._past_value_cache_list + + return model_kwargs + + def _get_stopping_criteria(self) -> StoppingCriteriaList: + criteria = StoppingCriteriaList() + if self._max_length is not None: + criteria.append(MaxLengthCriteria(self._max_length)) + + if self._eos_token_id is not None: + criteria.append(EosTokenCriteria(eos_token_id=self._eos_token_id)) + return criteria + + def _prepare_generated_length(self, input_ids_length: int) -> None: + """Prepare max and min length in generation configs to avoid clashes between similar attributes.""" + if self._max_new_tokens is not None: + self._max_length = self._max_new_tokens + input_ids_length + if self._min_new_tokens is not None: + self._min_length = self._min_new_tokens + input_ids_length + + def _prepare_special_tokens(self, kwargs_has_attention_mask: Optional[bool] = None): + # Convert special tokens to tensors (if they exist either in kwargs or in self.config) + def _tensor_or_none(token): + if token is None or isinstance(token, Tensor): + return token + return Tensor(token, dtype=ms.int32) + + bos_token_id =
_tensor_or_none(self._bos_token_id) + eos_token_id = _tensor_or_none(self._eos_token_id) + pad_token_id = _tensor_or_none(self._pad_token_id) + + # We can have more than one eos token. Always treat it as a 1D tensor (when it exists). + if eos_token_id is not None and eos_token_id.ndim == 0: + eos_token_id = eos_token_id.unsqueeze(0) + + # Set pad token if unset (and there are conditions to do so) + if pad_token_id is None and eos_token_id is not None: + if kwargs_has_attention_mask is not None and not kwargs_has_attention_mask: + logger.warning( + "The attention mask and the pad token id were not set. As a consequence, you may observe " + "unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results." + ) + pad_token_id = eos_token_id[0] + logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{pad_token_id} for open-end generation.") + + # we can't infer attn mask if pad token is set to be eos token in model's generation config + if eos_token_id is not None and ms.numpy.isin(element=eos_token_id, test_elements=pad_token_id).any(): + if kwargs_has_attention_mask is not None and not kwargs_has_attention_mask: + logger.warning( + "The attention mask is not set and cannot be inferred from input because pad token is same as eos token." + "As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` " + "to obtain reliable results." + ) + + # Sanity checks/warnings + if eos_token_id is not None and (eos_token_id < 0).any(): + logger.warning( + f"`eos_token_id` should consist of positive integers, but is {eos_token_id}. Your generation will not " + "stop until the maximum length is reached. Depending on other flags, it may even crash." + ) + + # Update generation config with the updated special tokens tensors + self._bos_token_id = bos_token_id + self._eos_token_id = eos_token_id + self._pad_token_id = pad_token_id + + def generate(self, **model_kwargs) -> Tensor: + kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None + + # Define model inputs + input_ids, model_kwargs = self._prepare_model_inputs(self._bos_token_id, model_kwargs) + batch_size = input_ids.shape[0] + self._prepare_special_tokens(kwargs_has_attention_mask) + + # decoder-only models must use left-padding for batched generation. + # If `input_ids` was given, check if the last id in any sequence is `pad_token_id` + if ( + self._pad_token_id is not None + and batch_size > 1 + and len(input_ids.shape) == 2 + and ops.sum(input_ids[:, -1] == self._pad_token_id) > 0 + ): + logger.warning( + "A decoder-only architecture is being used, but right-padding was detected! For correct " + "generation results, please set `padding_side='left'` when initializing the tokenizer." + ) + + if not kwargs_has_attention_mask: + model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation( + input_ids, self._pad_token_id, self._eos_token_id + ) + + # prepare `max_length` depending on other stopping criteria. 
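+ # max_length/min_length are the prompt length plus max_new_tokens/min_new_tokens (see _prepare_generated_length).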
+ input_ids_length = input_ids.shape[-1] + self._prepare_generated_length(input_ids_length) + + # reset cache if necessary + if self._use_kv_cache: + self._past_key_cache_list, self._past_value_cache_list = None, None + + # prepare stopping criteria + prepared_stopping_criteria = self._get_stopping_criteria() + + # run sample + result = self._sample(input_ids, stopping_criteria=prepared_stopping_criteria, **model_kwargs) + + return result + + def _sample(self, input_ids: Tensor, stopping_criteria: StoppingCriteriaList, **model_kwargs: Tensor) -> Tensor: + # init values + pad_token_id = self._pad_token_id + has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria) + + # keep track of which sequences are already finished + batch_size = input_ids.shape[0] + this_peer_finished = False + unfinished_sequences = ops.ones(batch_size, dtype=ms.int32) + + while not this_peer_finished: + # prepare model inputs + model_inputs = self.model.prepare_inputs_for_generation(input_ids, **model_kwargs) + + # inject kv cache state + model_inputs["return_key_value_cache"] = self._use_kv_cache + + # forward pass to get next token + logits, key_cache_list, value_cache_list = self.model(**model_inputs) + next_token_scores = logits[:, -1, :] + + # token selection + next_tokens = ops.argmax(next_token_scores, dim=-1) + + # finished sentences should have their next token be a padding token + if has_eos_stopping_criteria: + next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences) + + # update generated ids, model inputs, and length for next step + input_ids = ops.concat([input_ids, next_tokens[:, None]], axis=-1) + model_kwargs = self._update_model_kwargs_for_generation(model_kwargs, key_cache_list, value_cache_list) + + unfinished_sequences = ops.logical_and(unfinished_sequences, ~stopping_criteria(input_ids)) + this_peer_finished = unfinished_sequences.max() == 0 + + return input_ids diff --git a/examples/opensora_hpcai/tools/caption/llava_next/models/.gitkeep b/examples/opensora_hpcai/tools/caption/llava_next/models/.gitkeep new file mode 100644 index 0000000000..e69de29bb2 diff --git a/examples/opensora_hpcai/tools/caption/llava_next/predict.py b/examples/opensora_hpcai/tools/caption/llava_next/predict.py new file mode 100755 index 0000000000..6d9a7d672c --- /dev/null +++ b/examples/opensora_hpcai/tools/caption/llava_next/predict.py @@ -0,0 +1,94 @@ +#!/usr/bin/env python +import argparse +import copy +import json +import logging +import os +import time +from typing import Any, Dict + +from llava.model.llava_next import LlavaNextForConditionalGeneration +from llava.pipeline import TextGenerator +from PIL import Image +from transformers import LlavaNextProcessor + +import mindspore as ms +import mindspore.nn as nn +from mindspore import Tensor + + +def parse_args(): + parser = argparse.ArgumentParser( + description="LLaVA-NeXT prediction", formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + parser.add_argument("--model_path", default="models/llava-v1.6-mistral-7b-hf", help="Path of the model root") + parser.add_argument("--input_image", default="assets/llava_v1_5_radar.jpg", help="Input Image") + parser.add_argument("--prompt", default="What is shown in this image?", help="Input Prompt") + parser.add_argument("--benchmark", action="store_true", help="Do performance benchmark") + args = parser.parse_args() + return args + + +def load_network(config: Dict[str, Any], ckpt_path: str) -> nn.Cell: + config_ = copy.copy(config) +
config_["vision_config"]["hidden_size"] = 1024 + config_["text_config"]["hidden_size"] = 4096 + + vision_config = config_.pop("vision_config") + text_config = config_.pop("text_config") + network = LlavaNextForConditionalGeneration( + vision_config, + text_config, + dtype=ms.float16, + attn_implementation="flash_attention", + language_model_input_method="padding", # dynamic + **config_, + ) + ms.load_checkpoint(ckpt_path, net=network, strict_load=True) + return network + + +def main(): + args = parse_args() + + ms.set_context(jit_config=dict(jit_level="O1")) + + with open(os.path.join(args.model_path, "config.json"), "r") as f: + config = json.load(f) + + model_path = os.path.join(args.model_path, "model.ckpt") + logging.info(f"Loading the network from {model_path}") + network = load_network(config, model_path) + + # prepare image and text prompt, using the appropriate prompt template + logging.info(f"Loading the processer from {args.model_path}") + processor = LlavaNextProcessor.from_pretrained(args.model_path) + + image = Image.open(args.input_image) + logging.info(f"Input Image: {args.input_image}") + + input_prompt = f"[INST] \n{args.prompt} [/INST]" + logging.info(f"Input Prompt: {input_prompt}") + + inputs = processor(input_prompt, image, return_tensors="np") + inputs = {k: Tensor(v) for k, v in inputs.items()} + + # autoregressively complete prompt + trials = 2 if args.benchmark else 1 + logging.info("Starting inference...") + for trial in range(trials): + logging.info(f"Trial: {trial}") + pipeline = TextGenerator(network, max_new_tokens=100, use_kv_cache=True) + start = time.time() + output = pipeline.generate(**inputs) + end = time.time() + logging.info(f"Time Taken: {end-start:.3f}, Tokens/Second: {len(output[0]) / (end - start):.1f}") + + print(processor.decode(output[0], skip_special_tokens=True)) + + +if __name__ == "__main__": + fmt = "%(asctime)s %(levelname)s: %(message)s" + datefmt = "[%Y-%m-%d %H:%M:%S]" + logging.basicConfig(level=logging.INFO, format=fmt, datefmt=datefmt) + main() diff --git a/examples/opensora_hpcai/tools/caption/llava_next/tools/convert_llava.py b/examples/opensora_hpcai/tools/caption/llava_next/tools/convert_llava.py new file mode 100755 index 0000000000..d9a4e51f82 --- /dev/null +++ b/examples/opensora_hpcai/tools/caption/llava_next/tools/convert_llava.py @@ -0,0 +1,67 @@ +#!/usr/bin/env python +import argparse +import glob +import logging +import os +from typing import Dict + +from fileio import load_safetensors +from tqdm import tqdm + +import mindspore as ms +from mindspore import Tensor + +logger = logging.getLogger(__name__) + + +def load(root_path: str) -> Dict[str, Tensor]: + # TODO: this method may cause OOM on computer with low memory + # use a better solution later + pattern = os.path.join(root_path, "*.safetensors") + filelist = sorted(glob.glob(pattern)) + + filenames = [os.path.basename(x) for x in filelist] + logger.info(f"Files need to be converted: `{filenames}`") + + params: Dict[str, Tensor] = dict() + for x in tqdm(filelist, desc="Loading the safetensors"): + params_chunk = load_safetensors(x) + if params.keys().isdisjoint(params_chunk.keys()): + params.update(params_chunk) + else: + same_keys = set(params.keys()).intersection(params_chunk.keys()) + raise RuntimeError(f"Duplicated keys found: `{same_keys}`.") + return params + + +def convert(params: Dict[str, Tensor]) -> Dict[str, Tensor]: + # compatibility between MS and PyTorch naming and formating + return params + + +def save(ckpt: Dict[str, Tensor], output: str) -> None: + 
output = os.path.abspath(output) + logger.info(f"Saving to {output}...") + ms.save_checkpoint(ckpt, output) + logger.info(f"Saving to {output}...Done!") + + +def main(): + parser = argparse.ArgumentParser(description="Convert LLaVA checkpoints into Mindspore format") + parser.add_argument("src", help="Directory storing the safetensors") + parser.add_argument( + "-o", "--output", default="models/llava_1_6.ckpt", help="Name of the output Mindspore checkpoint" + ) + + args = parser.parse_args() + + params = load(args.src) + params = convert(params) + save(params, args.output) + + +if __name__ == "__main__": + fmt = "%(asctime)s %(levelname)s: %(message)s" + datefmt = "[%Y-%m-%d %H:%M:%S]" + logging.basicConfig(level=logging.INFO, format=fmt, datefmt=datefmt) + main() diff --git a/examples/opensora_hpcai/tools/caption/llava_next/tools/fileio/__init__.py b/examples/opensora_hpcai/tools/caption/llava_next/tools/fileio/__init__.py new file mode 100644 index 0000000000..e64328c9f6 --- /dev/null +++ b/examples/opensora_hpcai/tools/caption/llava_next/tools/fileio/__init__.py @@ -0,0 +1 @@ +from .safetensors import load_safetensors diff --git a/examples/opensora_hpcai/tools/caption/llava_next/tools/fileio/safetensors.py b/examples/opensora_hpcai/tools/caption/llava_next/tools/fileio/safetensors.py new file mode 100644 index 0000000000..9f80901190 --- /dev/null +++ b/examples/opensora_hpcai/tools/caption/llava_next/tools/fileio/safetensors.py @@ -0,0 +1,21 @@ +import os +from typing import Dict, Union + +import numpy as np +import safetensors.numpy + +from mindspore import Parameter, Tensor + + +def load_safetensors(filename: Union[str, os.PathLike], force_fp32: bool = False) -> Dict[str, Tensor]: + flat = safetensors.numpy.load_file(filename) + output = _np2ms(flat, force_fp32=force_fp32) + return output + + +def _np2ms(np_dict: Dict[str, np.ndarray], force_fp32: bool = False) -> Dict[str, Tensor]: + for k, v in np_dict.items(): + if force_fp32 and v.dtype == np.float16: + v = v.astype(np.float32) + np_dict[k] = Parameter(v, name=k) + return np_dict
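For a quick sanity check of the downloaded weights, the `fileio` helper can also be used on its own. The snippet below is a minimal sketch for illustration: the shard filename and the relative path are assumptions (they depend on how the Hugging Face repository splits its weights), and it assumes being run from the `tools/` directory so that the `fileio` package resolves.

```python
# Minimal sketch: load one safetensors shard as MindSpore parameters and inspect it.
# The shard name below is hypothetical; list the *.safetensors files in the model
# directory to find the actual names.
from fileio import load_safetensors

# force_fp32=True upcasts float16 weights to float32 while loading.
params = load_safetensors(
    "../models/llava-v1.6-mistral-7b-hf/model-00001-of-00004.safetensors", force_fp32=True
)
for name, param in list(params.items())[:5]:
    print(name, tuple(param.shape), param.dtype)
```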