Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[OpenSora-HPCAI] Llava-Next Captioner #672

Merged
merged 1 commit into from
Sep 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 53 additions & 0 deletions examples/opensora_hpcai/tools/caption/llava_next/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# LLaVA-NeXT: Open Large Multimodal Models (MindSpore)

This repo contains Mindspore model definitions, pre-trained weights and inference/sampling code for the [model](https://llava-vl.github.io/blog/2024-01-30-llava-next/). Referring to the [official project page](https://github.com/LLaVA-VL/LLaVA-NeXT).

## Dependencies and Installation

- CANN: 8.0.RC2 or later
- Python: 3.9 or later
- Mindspore: 2.3.1

## Getting Start

### Downloading Pretrained Checkpoints

Please download the model [llava-v1.6-mistral-7b-hf](https://huggingface.co/llava-hf/llava-v1.6-mistral-7b-hf) to the `./models` directory. And run

```bash
python tools/convert_llava.py models/llava-v1.6-mistral-7b-hf -o models/llava-v1.6-mistral-7b-hf/model.ckpt
```

to convert the model weight in Mindspore `ckpt` format.

### Inference

To run the inference, you may use `predict.py` with the following command

```bash
python predict.py --input_image path_to_your_input_image --prompt input_prompt
```

For example, running `python predict.py` with the default image [llava_v1_5_radar.jpg](https://github.com/user-attachments/assets/8e016871-82fd-488a-8629-5ca71222e0e3) and default prompt `What is shown in this image?` will give the following result:

```text
[INST]
What is shown in this image? [/INST] The image appears to be a radar chart, which is a type of multivariate chart that displays values for multiple variables represented on axes
starting from the same point. This particular radar chart is showing the performance of different models or systems across various metrics.

The axes represent different metrics or benchmarks, such as MM-Vet, MM-Vet, MM-Vet, MM-Vet, MM-Vet, MM-V
```

## Benchmark

### Inference

To perform the benchmark, you may first download the image [llava_v1_5_radar.jpg](https://github.com/user-attachments/assets/8e016871-82fd-488a-8629-5ca71222e0e3) and save it in `./assets`, and then run `python predict --benchmark` to get the throughput.

| Model | Context | Batch Size | Throughput (tokens/second)|
|-----------------------|---------------|------------|---------------------------|
| llava-v1.6-mistral-7b | D910*x1-MS2.3 | 1 | 21.2 |

> Context: {Ascend chip}-{number of NPUs}-{mindspore version}.\
> Throughput (tokens/second): number of generated tokens per second.\
> We use the second round of inference as the benchmark result.
Empty file.
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import logging
from collections import OrderedDict

import mindspore.nn as nn
import mindspore.ops as ops
from mindspore import Tensor

logger = logging.getLogger(__name__)


class QuickGELU(nn.Cell):
def construct(self, x: Tensor):
return x * ops.sigmoid(1.702 * x)


class ClassInstantier(OrderedDict):
def __getitem__(self, key):
content = super().__getitem__(key)
cls, kwargs = content if isinstance(content, tuple) else (content, {})
return cls(**kwargs)


ACT2CLS = {
"quick_gelu": QuickGELU,
"gelu": nn.GELU,
"silu": nn.SiLU,
}
ACT2FN = ClassInstantier(ACT2CLS)
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .network import CLIPVisionModel, CLIPVisionModelWithProjection
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
import mindspore as ms
import mindspore.nn as nn
import mindspore.ops as ops
from mindspore import Tensor

from ..activation import ACT2FN
from ..common_layer import Embedding


class CLIPVisionEmbeddings(nn.Cell):
def __init__(
self,
hidden_size: int = 1024,
image_size: int = 336,
patch_size: int = 14,
num_channels: int = 3,
dtype: ms.dtype = ms.float32,
) -> None:
super().__init__()
self.embed_dim = hidden_size
self.image_size = image_size
self.patch_size = patch_size
self.num_channels = num_channels

self.class_embedding = ms.Parameter(Tensor(ops.randn(self.embed_dim), dtype=dtype))

self.patch_embedding = nn.Conv2d(
in_channels=num_channels,
out_channels=self.embed_dim,
kernel_size=self.patch_size,
stride=self.patch_size,
has_bias=False,
dtype=dtype,
)

self.num_patches = (self.image_size // self.patch_size) ** 2
self.num_positions = self.num_patches + 1
self.position_embedding = Embedding(self.num_positions, self.embed_dim, dtype=dtype)
self.position_ids = ops.broadcast_to(ops.arange(self.num_positions), (1, -1))

def construct(self, pixel_values: Tensor) -> Tensor:
batch_size = pixel_values.shape[0]
patch_embeds = self.patch_embedding(pixel_values)
patch_embeds = ops.flatten(patch_embeds, start_dim=2)
patch_embeds = ops.transpose(patch_embeds, (0, 2, 1))

class_embeds = ops.broadcast_to(self.class_embedding, (batch_size, 1, -1))
embeddings = ops.concat([class_embeds, patch_embeds], axis=1)
embeddings = embeddings + self.position_embedding(self.position_ids)
return embeddings


class CLIPAttention(nn.Cell):
def __init__(
self,
hidden_size: int = 1024,
num_attention_heads: int = 16,
attention_dropout: float = 0.0,
dtype: ms.dtype = ms.float32,
) -> None:
super().__init__()
self.embed_dim = hidden_size
self.num_heads = num_attention_heads
self.head_dim = self.embed_dim // self.num_heads
if self.head_dim * self.num_heads != self.embed_dim:
raise ValueError(
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
f" {self.num_heads})."
)
self.scale = self.head_dim**-0.5
self.dropout = attention_dropout

self.k_proj = nn.Dense(self.embed_dim, self.embed_dim, dtype=dtype)
self.v_proj = nn.Dense(self.embed_dim, self.embed_dim, dtype=dtype)
self.q_proj = nn.Dense(self.embed_dim, self.embed_dim, dtype=dtype)
self.out_proj = nn.Dense(self.embed_dim, self.embed_dim, dtype=dtype)

def _shape(self, tensor: Tensor, seq_len: int, bsz: int) -> Tensor:
tensor = ops.reshape(tensor, (bsz, seq_len, self.num_heads, self.head_dim))
tensor = ops.transpose(tensor, (0, 2, 1, 3))
return tensor

def construct(self, hidden_states: Tensor) -> Tensor:
"""Input shape: Batch x Time x Channel"""

bsz, tgt_len, embed_dim = hidden_states.shape

# get query proj
query_states = self.q_proj(hidden_states) * self.scale
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)

proj_shape = (bsz * self.num_heads, -1, self.head_dim)
query_states = self._shape(query_states, tgt_len, bsz)
query_states = ops.reshape(query_states, proj_shape)
key_states = ops.reshape(key_states, proj_shape)
value_states = ops.reshape(value_states, proj_shape)

attn_weights = ops.bmm(query_states, ops.transpose(key_states, (0, 2, 1)))
attn_weights = ops.softmax(attn_weights, axis=-1)
attn_probs = ops.dropout(attn_weights, p=self.dropout, training=self.training)

attn_output = ops.bmm(attn_probs, value_states)
attn_output = ops.reshape(attn_output, (bsz, self.num_heads, tgt_len, self.head_dim))
attn_output = ops.transpose(attn_output, (0, 2, 1, 3))
attn_output = ops.reshape(attn_output, (bsz, tgt_len, embed_dim))

attn_output = self.out_proj(attn_output)

return attn_output


class CLIPMLP(nn.Cell):
def __init__(
self,
hidden_size: int = 1024,
intermediate_size: int = 4096,
hidden_act: str = "quick_gelu",
dtype: ms.dtype = ms.float32,
) -> None:
super().__init__()
self.activation_fn = ACT2FN[hidden_act]
self.fc1 = nn.Dense(hidden_size, intermediate_size, dtype=dtype)
self.fc2 = nn.Dense(intermediate_size, hidden_size, dtype=dtype)

def construct(self, hidden_states: ms.Tensor) -> ms.Tensor:
hidden_states = self.fc1(hidden_states)
hidden_states = self.activation_fn(hidden_states)
hidden_states = self.fc2(hidden_states)
return hidden_states
Loading