Add support for idefics2 #10

Merged 16 commits on May 3, 2024
24 changes: 18 additions & 6 deletions mlx_vlm/generate.py
@@ -3,6 +3,7 @@

import mlx.core as mx

from .prompt_utils import get_message_json
from .utils import generate, get_model_path, load, load_config, load_image_processor

MODEL_TYPE = ""
@@ -39,14 +40,21 @@ def parse_arguments():
parser.add_argument(
"--temp", type=float, default=0.3, help="Temperature for sampling."
)
parser.add_argument(
"--verbose",
type=bool,
help="Detailed output.",
default=True,
)
return parser.parse_args()


def get_model_and_processors(model_path):
model_path = get_model_path(model_path)
config = load_config(model_path)
model, processor = load(model_path, {"trust_remote_code": True})
image_processor = load_image_processor(model_path)
return model, processor, image_processor
return model, processor, image_processor, config


def sample(logits, temperature=0.0):
@@ -58,37 +66,41 @@ def sample(logits, temperature=0.0):

def main():
args = parse_arguments()
model, processor, image_processor = get_model_and_processors(args.model)
model, processor, image_processor, config = get_model_and_processors(args.model)

prompt = codecs.decode(args.prompt, "unicode_escape")

if "chat_template" in processor.__dict__.keys():
prompt = processor.apply_chat_template(
[{"role": "user", "content": f"<image>\n{prompt}"}],
[get_message_json(config["model_type"], prompt)],
tokenize=False,
add_generation_prompt=True,
)

elif "tokenizer" in processor.__dict__.keys():
prompt = processor.tokenizer.apply_chat_template(
[{"role": "user", "content": f"<image>\n{prompt}"}],
[get_message_json(config["model_type"], prompt)],
tokenize=False,
add_generation_prompt=True,
)

else:
        raise ValueError(
            "Error: processor does not have 'chat_template' or 'tokenizer' attribute."
        )

generate(
output = generate(
model,
processor,
args.image,
prompt,
image_processor,
args.temp,
args.max_tokens,
True,
args.verbose,
)
if not args.verbose:
print(output)


if __name__ == "__main__":
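Note: prompt_utils.py itself is not shown in this diff. Judging from how it is used above, where get_message_json(config["model_type"], prompt) builds a single chat message that is wrapped in a list and passed to apply_chat_template, a minimal sketch might look like the following. The idefics2 branch assumes the typed-content message format used by the Hugging Face Idefics2 processor, and the fallback keeps the previous "<image>\n{prompt}" convention; the branching in the real file may differ.

def get_message_json(model_type, prompt):
    # Build one user message in the shape the model's chat template expects.
    if model_type == "idefics2":
        # Idefics2-style content: a list of typed parts, with the image slot
        # declared explicitly and the text carried in its own part.
        return {
            "role": "user",
            "content": [{"type": "image"}, {"type": "text", "text": prompt}],
        }
    # Fallback: the convention generate.py used before this change,
    # an "<image>" placeholder followed by the raw prompt.
    return {"role": "user", "content": f"<image>\n{prompt}"}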
9 changes: 9 additions & 0 deletions mlx_vlm/models/idefics2/__init__.py
@@ -0,0 +1,9 @@
from .idefics2 import (
LanguageModel,
Model,
ModelConfig,
PerceiverConfig,
TextConfig,
VisionConfig,
VisionModel,
)
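
For reference, the new model can also be driven outside the CLI by mirroring the chat_template branch of main() above. A sketch, assuming an MLX-converted idefics2 checkpoint (the repo id and image path below are placeholders) and a processor that exposes apply_chat_template directly:

from mlx_vlm.prompt_utils import get_message_json
from mlx_vlm.utils import (
    generate,
    get_model_path,
    load,
    load_config,
    load_image_processor,
)

# Placeholder checkpoint id; substitute any MLX-converted idefics2 model.
model_path = get_model_path("mlx-community/idefics2-8b-4bit")
config = load_config(model_path)
model, processor = load(model_path, {"trust_remote_code": True})
image_processor = load_image_processor(model_path)

prompt = processor.apply_chat_template(
    [get_message_json(config["model_type"], "Describe this image.")],
    tokenize=False,
    add_generation_prompt=True,
)

# Same positional arguments as the generate(...) call in generate.py above:
# image path, formatted prompt, image processor, temperature, max tokens, verbose.
output = generate(
    model,
    processor,
    "example.jpg",  # placeholder image path
    prompt,
    image_processor,
    0.3,
    100,
    False,
)
print(output)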