Improve vlm support (add idefics3 support) #2437

Draft · wants to merge 4 commits into base: main
4 changes: 4 additions & 0 deletions integration-tests/conftest.py
@@ -330,6 +330,7 @@ def local_launcher(
dtype: Optional[str] = None,
revision: Optional[str] = None,
max_input_length: Optional[int] = None,
max_input_tokens: Optional[int] = None,
max_batch_prefill_tokens: Optional[int] = None,
max_total_tokens: Optional[int] = None,
lora_adapters: Optional[List[str]] = None,
@@ -374,6 +375,9 @@ def local_launcher(
if max_input_length:
args.append("--max-input-length")
args.append(str(max_input_length))
if max_input_tokens:
args.append("--max-input-tokens")
args.append(str(max_input_tokens))
if max_batch_prefill_tokens:
args.append("--max-batch-prefill-tokens")
args.append(str(max_batch_prefill_tokens))
@@ -0,0 +1,73 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [],
"seed": null,
"tokens": [
{
"id": 578,
"logprob": -0.2475586,
"special": false,
"text": " The"
},
{
"id": 2217,
"logprob": -0.017303467,
"special": false,
"text": " image"
},
{
"id": 62991,
"logprob": -0.7368164,
"special": false,
"text": " depicts"
},
{
"id": 279,
"logprob": -0.39990234,
"special": false,
"text": " the"
},
{
"id": 89675,
"logprob": -0.34350586,
"special": false,
"text": " Statue"
},
{
"id": 315,
"logprob": -0.0002901554,
"special": false,
"text": " of"
},
{
"id": 32492,
"logprob": -0.0009598732,
"special": false,
"text": " Liberty"
},
{
"id": 11,
"logprob": -0.2355957,
"special": false,
"text": ","
},
{
"id": 264,
"logprob": -0.66503906,
"special": false,
"text": " a"
},
{
"id": 97937,
"logprob": -0.9199219,
"special": false,
"text": " colossal"
}
],
"top_tokens": null
},
"generated_text": " The image depicts the Statue of Liberty, a colossal"
}
64 changes: 64 additions & 0 deletions integration-tests/models/test_idefics3.py
@@ -0,0 +1,64 @@
import pytest
import base64


def get_chicken():
with open("integration-tests/images/chicken_on_money.png", "rb") as image_file:
encoded_string = base64.b64encode(image_file.read())
return f"data:image/png;base64,{encoded_string.decode('utf-8')}"


@pytest.fixture(scope="module")
def flash_idefics3_next_handle(launcher):
with launcher(
"HuggingFaceM4/Idefics3-8B-Llama3",
max_total_tokens=3000,
max_batch_prefill_tokens=2501,
max_input_tokens=2500,
) as handle:
yield handle


@pytest.fixture(scope="module")
async def flash_idefics3_next(flash_idefics3_next_handle):
await flash_idefics3_next_handle.health(300)
return flash_idefics3_next_handle.client


# TODO: don't skip when token issue is resolved
@pytest.mark.skip
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_idefics3_next_simple_base64(
flash_idefics3_next, response_snapshot
):
chicken = get_chicken()
query = "Write me a short story"
response = await flash_idefics3_next.generate(
f"<|begin_of_text|><|begin_of_text|>User:![]({chicken}){query}<end_of_utterance>\nAssistant:",
max_new_tokens=10,
)
assert (
response.generated_text == " A chicken is sitting on a pile of money."
), f"{repr(response.generated_text)}"
# assert response.details.generated_tokens == 10
# assert response == response_snapshot


@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_idefics3_next_simple_url(flash_idefics3_next, response_snapshot):
ny_skyline = "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg"
query = "What is in this image?"
response = await flash_idefics3_next.generate(
f"<|begin_of_text|><|begin_of_text|>User:![]({ny_skyline}){query}<end_of_utterance>\nAssistant:",
max_new_tokens=10,
seed=1337,
)
print(response)
assert (
response.generated_text
== " The image depicts the Statue of Liberty, a colossal"
), f"{repr(response.generated_text)}"
assert response.details.generated_tokens == 10
assert response == response_snapshot
19 changes: 19 additions & 0 deletions router/src/config.rs
@@ -110,6 +110,24 @@ pub struct ClipVisionModel {
patch_size: usize,
}

#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub struct Idefics3 {}

impl Idefics3 {
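// These defaults appear to mirror the Idefics3 processor settings: images are
// split into 364x364 crops, each crop contributes 169 image tokens, and inputs
// are first resized so the longest edge is at most 4 * 364 = 1456 pixels.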
pub fn get_max_longest_edge(&self) -> usize {
364
}

pub fn get_number_of_features(&self) -> usize {
169
}

pub fn get_max_longest_edge_for_image_resize(&self) -> usize {
1456
}
}

#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub struct Idefics2 {}
@@ -147,6 +165,7 @@ pub enum Config {
Mistral,
Idefics,
Idefics2(Idefics2),
Idefics3(Idefics3),
Ssm,
GptBigcode,
Santacoder,
1 change: 1 addition & 0 deletions router/src/lib.rs
@@ -126,6 +126,7 @@ impl TokenizerConfigToken {
#[serde(tag = "processor_class")]
pub enum HubPreprocessorConfig {
Idefics2Processor(Idefics2Preprocessor),
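// Idefics3 reuses the Idefics2 preprocessor struct here; the hub processor
// config is assumed to expose the same `do_image_splitting` field.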
Idefics3Processor(Idefics2Preprocessor),
}

impl HubPreprocessorConfig {
112 changes: 110 additions & 2 deletions router/src/validation.rs
@@ -571,6 +571,73 @@ fn image_tokens(

image_string
}
Idefics3(config) => {
const FAKE: &str = "<fake_token_around_image>";
const IMAGE: &str = "<image>";
const GLOBAL_IMG: &str = "<global-img>";

let max_longest_edge_for_image_resize = config.get_max_longest_edge_for_image_resize();

// resize the image if it is larger than max_longest_edge_for_image_resize, keeping the aspect ratio
let (height, width) = if height > max_longest_edge_for_image_resize
|| width > max_longest_edge_for_image_resize
{
let aspect_ratio = height as f32 / width as f32;
if height > width {
(
max_longest_edge_for_image_resize,
(max_longest_edge_for_image_resize as f32 / aspect_ratio) as usize,
)
} else {
(
(max_longest_edge_for_image_resize as f32 * aspect_ratio) as usize,
max_longest_edge_for_image_resize,
)
}
} else {
(height, width)
};

let image_seq_len = config.get_number_of_features();
let max_edge = config.get_max_longest_edge();

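// Number of 364x364 crops needed to cover the (possibly resized) image;
// (0, 0) means the image fits in a single crop and only the global image is emitted.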
let (image_rows, image_cols) = if height > max_edge || width > max_edge {
(
(height as f32 / max_edge as f32).ceil() as usize,
(width as f32 / max_edge as f32).ceil() as usize,
)
} else {
(0, 0)
};

let mut image_string = String::new();

if image_rows == 0 && image_cols == 0 {
// Single image case
image_string.push_str(FAKE);
image_string.push_str(GLOBAL_IMG);
image_string.push_str(&IMAGE.repeat(image_seq_len));
image_string.push_str(FAKE);
} else {
// Split image case
for n_h in 0..image_rows {
for n_w in 0..image_cols {
image_string.push_str(FAKE);
image_string.push_str(&format!("<row_{}_col_{}>", n_h + 1, n_w + 1));
image_string.push_str(&IMAGE.repeat(image_seq_len));
}
image_string.push('\n');
}

image_string.push('\n');
image_string.push_str(FAKE);
image_string.push_str(GLOBAL_IMG);
image_string.push_str(&IMAGE.repeat(image_seq_len));
image_string.push_str(FAKE);
}

image_string
}
Paligemma(config) => "<image>".repeat(config.get_number_of_features(height, width)),
LlavaNext(config) => "<image>".repeat(config.get_number_of_features(height, width)),
_ => unimplemented!("Images tokens are not supported for this model configuration"),
@@ -598,7 +665,7 @@ fn prepare_input(
use Config::*;
static RE: Lazy<Regex> = Lazy::new(|| Regex::new(r"!\[\]\([^\)]*\)").unwrap());
let (tokenizer_query, input_chunks) = match config {
Some(config @ (Idefics | Idefics2(_) | Paligemma(_) | LlavaNext(_))) => {
Some(config @ (Idefics | Idefics2(_) | Idefics3(_) | Paligemma(_) | LlavaNext(_))) => {
let mut input_chunks = Vec::new();
let mut tokenizer_query = String::with_capacity(inputs.len());
let mut start = 0;
@@ -796,7 +863,7 @@ pub enum ValidationError {
#[cfg(test)]
mod tests {
use super::*;
use crate::config::{Idefics2, PaliTextConfig, Paligemma};
use crate::config::{Idefics2, Idefics3, PaliTextConfig, Paligemma};
use crate::default_parameters;
use crate::tests::get_tokenizer;

@@ -1189,4 +1256,45 @@ mod tests {
11
);
}

#[tokio::test]
async fn test_idefics3_image_tokens() {
let config = Config::Idefics3(Idefics3 {});

let preprocessor_config = Some(&HubPreprocessorConfig::Idefics2Processor(
Idefics2Preprocessor {
do_image_splitting: true,
},
));

let height = 1067;
let width = 1600;

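// With the defaults above, 1067x1600 is resized to 970x1456 (longest edge capped
// at 1456), then tiled into ceil(970/364) = 3 rows and ceil(1456/364) = 4 columns
// of crops, plus one global image.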
let tokens = image_tokens(&config, preprocessor_config, height, width);

// get all unique tags `<tag>` from the tokens
let tags: std::collections::HashSet<&str> = tokens
.split(|c| c == '<' || c == '>')
.filter(|s| !s.is_empty())
.collect();

assert_eq!(tags.len(), 17); // all below and `\n` and `\n\n`
assert_eq!(tags.contains(&"row_1_col_1"), true);
assert_eq!(tags.contains(&"row_1_col_2"), true);
assert_eq!(tags.contains(&"row_1_col_3"), true);
assert_eq!(tags.contains(&"row_1_col_4"), true);
assert_eq!(tags.contains(&"row_2_col_1"), true);
assert_eq!(tags.contains(&"row_2_col_2"), true);
assert_eq!(tags.contains(&"row_2_col_3"), true);
assert_eq!(tags.contains(&"row_2_col_4"), true);
assert_eq!(tags.contains(&"row_3_col_1"), true);
assert_eq!(tags.contains(&"row_3_col_2"), true);
assert_eq!(tags.contains(&"row_3_col_3"), true);
assert_eq!(tags.contains(&"row_3_col_4"), true);
assert_eq!(tags.contains(&"global-img"), true);
assert_eq!(tags.contains(&"image"), true);
assert_eq!(tags.contains(&"fake_token_around_image"), true);

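// 12 crops * (25 + 13 + 169 * 7) chars + 3 row newlines + the trailing
// global-image block (1 + 25 + 12 + 169 * 7 + 25 chars) = 15_901.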
assert_eq!(tokens.len(), 15_901)
}
}
24 changes: 24 additions & 0 deletions server/text_generation_server/models/__init__.py
@@ -137,6 +137,7 @@
)
from text_generation_server.models.custom_modeling.idefics2 import (
Idefics2ForConditionalGeneration,
Idefics3ForConditionalGeneration,
)
from text_generation_server.layers.attention import SUPPORTS_WINDOWING
except ImportError as e:
@@ -171,6 +172,12 @@ class ModelType(enum.Enum):
"url": "https://huggingface.co/HuggingFaceM4/idefics2-8b",
"multimodal": True,
}
IDEFICS3 = {
"type": "idefics3",
"name": "Idefics 3",
"url": "https://huggingface.co/HuggingFaceM4/Idefics3-8B-Llama3",
"multimodal": True,
}
LLAVA_NEXT = {
"type": "llava_next",
"name": "Llava Next (1.6)",
@@ -1091,6 +1098,23 @@ def get_model(
)
else:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
if model_type == IDEFICS3:
if FLASH_ATTENTION:
return VlmCausalLM(
model_id=model_id,
model_class=Idefics3ForConditionalGeneration,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
# XXX: Extremely important to cap resolution in order to limit
# VRAM usage.
processor_kwargs={"size": {"longest_edge": 448, "shortest_edge": 378}},
)
else:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics3"))
if model_type == PALIGEMMA:
if FLASH_ATTENTION:
return VlmCausalLM(