diff --git a/examples/python/phi3v.py b/examples/python/phi3v.py
index df5dddcfb..fd92dbd93 100644
--- a/examples/python/phi3v.py
+++ b/examples/python/phi3v.py
@@ -8,7 +8,6 @@
 import onnxruntime_genai as og
 
-
 def _complete(text, state):
     return (glob.glob(text + "*") + [None])[state]
 
@@ -29,9 +28,10 @@ def run(args: argparse.Namespace):
                 "Image Path (comma separated; leave empty if no image): "
             ).split(",")
         ]
+        image_paths = [image_path for image_path in image_paths if len(image_path)]
        print(image_paths)
 
-        image = None
+        images = None
         prompt = "<|user|>\n"
         if len(image_paths) == 0:
            print("No image provided")
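
The example now accepts several comma-separated image paths, drops empty entries, and hands all of the images to the processor in a single call. A minimal usage sketch, assuming the `og.Images` / `og.GeneratorParams` API that phi3v.py already relies on (model path, image names, and prompt text below are illustrative, not taken from the change):

```python
import onnxruntime_genai as og

model = og.Model("path/to/phi-3-vision")        # hypothetical model folder
processor = model.create_multimodal_processor()

image_paths = ["scene_a.jpg", "scene_b.jpg"]    # hypothetical images
prompt = "<|user|>\n"
for i in range(len(image_paths)):
    prompt += f"<|image_{i + 1}|>\n"            # one tag per image, 1-based
prompt += "What changed between these two images?<|end|>\n<|assistant|>\n"

images = og.Images.open(*image_paths)           # several images in one request
inputs = processor(prompt, images=images)

params = og.GeneratorParams(model)
params.set_inputs(inputs)
generator = og.Generator(model, params)
```

The filtering line added above keeps an empty input (just pressing Enter) from producing a list containing one empty string, which would otherwise be treated as an image path.
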
diff --git a/src/config.cpp b/src/config.cpp
index 00708faae..ec426e5ce 100644
--- a/src/config.cpp
+++ b/src/config.cpp
@@ -137,10 +137,6 @@ struct Inputs_Element : JSON::Element {
       v_.position_ids = value;
     } else if (name == "attention_mask") {
       v_.attention_mask = value;
-    } else if (name == "seqlens_k") {
-      v_.seqlens_k = value;
-    } else if (name == "total_seq_len") {
-      v_.total_sequence_length = value;
     } else if (name == "past_key_names") {
       v_.past_key_names = value;
     } else if (name == "past_value_names") {
@@ -248,8 +244,8 @@ struct VisionOutputs_Element : JSON::Element {
   explicit VisionOutputs_Element(Config::Model::Vision::Outputs& v) : v_{v} {}
 
   void OnString(std::string_view name, std::string_view value) override {
-    if (name == "visual_features") {
-      v_.visual_features = value;
+    if (name == "image_features") {
+      v_.image_features = value;
     } else
       throw JSON::unknown_value_error{};
   }
@@ -312,6 +308,8 @@ struct EmbeddingInputs_Element : JSON::Element {
   void OnString(std::string_view name, std::string_view value) override {
     if (name == "input_ids") {
       v_.input_ids = value;
+    } else if (name == "image_features") {
+      v_.image_features = value;
     } else
       throw JSON::unknown_value_error{};
   }
diff --git a/src/config.h b/src/config.h
index 7263dbda3..f511c0885 100644
--- a/src/config.h
+++ b/src/config.h
@@ -12,6 +12,7 @@ struct Config {
     static constexpr std::string_view InputIdsName = "input_ids";
     static constexpr std::string_view PixelValuesName = "pixel_values";
     static constexpr std::string_view ImageSizesName = "image_sizes";
+    static constexpr std::string_view ImageFeaturesName = "image_features";
   };
 
   fs::path config_path;  // Path of the config directory
@@ -62,6 +63,7 @@ struct Config {
 
     struct Inputs {
       std::string input_ids{Defaults::InputIdsName};
+      std::string image_features{Defaults::ImageFeaturesName};
     } inputs;
 
     struct Outputs {
@@ -78,7 +80,7 @@ struct Config {
     } inputs;
 
     struct Outputs {
-      std::string visual_features{"visual_features"};
+      std::string image_features{Defaults::ImageFeaturesName};
     } outputs;
   } vision;
 
@@ -97,8 +99,6 @@ struct Config {
       std::string embeddings{"inputs_embeds"};
       std::string position_ids{"position_ids"};
       std::string attention_mask{"attention_mask"};
-      std::string seqlens_k{"seqlens_k"};
-      std::string total_sequence_length{"total_seq_len"};
       std::string past_key_names{"past_key_values.%d.key"}, past_value_names{"past_key_values.%d.value"};
       std::string past_names;  // When key/value pairs are combined
      std::string cross_past_key_names, cross_past_value_names;
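
The renamed `image_features` tensor is now configurable in two places: as the vision model's output and as the embedding model's input. A hedged sketch of the corresponding genai_config.json fragment, showing only the keys this change touches (filenames and all other fields are omitted, and the nesting mirrors the parsers above rather than a verified config file):

```json
{
  "model": {
    "vision": {
      "inputs": { "pixel_values": "pixel_values", "image_sizes": "image_sizes" },
      "outputs": { "image_features": "image_features" }
    },
    "embedding": {
      "inputs": { "input_ids": "input_ids", "image_features": "image_features" }
    }
  }
}
```
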
diff --git a/src/models/embeddings.cpp b/src/models/embeddings.cpp
index 10508b89f..3accd9136 100644
--- a/src/models/embeddings.cpp
+++ b/src/models/embeddings.cpp
@@ -17,6 +17,7 @@ Embeddings::Embeddings(const Model& model, State& state, Embeddings::Mode mode,
                     : model_.session_info_->GetOutputDataType(name)},
       mode_{mode},
       name_{name} {
+  // Embeddings are only transient inputs and outputs.
   // They are never the user provided/requested model inputs/outputs
   // So only create the transient output and reuse that ortvalue for subsequent
diff --git a/src/models/image_features.cpp b/src/models/image_features.cpp
new file mode 100644
index 000000000..4a479ba7a
--- /dev/null
+++ b/src/models/image_features.cpp
@@ -0,0 +1,67 @@
+#include "../generators.h"
+#include "model.h"
+#include "image_features.h"
+
+namespace Generators {
+
+ImageFeatures::ImageFeatures(const Model& model, State& state, ImageFeatures::Mode mode, const std::string& name, int64_t num_image_tokens)
+    : model_{model},
+      state_{state},
+      shape_{num_image_tokens, state_.params_->hidden_size},
+      type_{mode == ImageFeatures::Mode::Input
+                ? model_.session_info_->GetInputDataType(name)
+                : model_.session_info_->GetOutputDataType(name)},
+      mode_{mode},
+      name_{name} {
+  // There are four cases for ImageFeatures:
+  // 1) Created as an output for vision model (num_image_tokens > 0)
+  //    The tensor needs to be pre-allocated to store the output.
+  //    It will be transferred to an input for the embedding model.
+  // 2) Created as an output for vision model (num_image_tokens = 0)
+  //    The tensor will be pre-allocated to store the empty output.
+  //    It will be transferred to an input for the embedding model.
+  // 3) Created as an input for embedding model (num_image_tokens > 0)
+  //    The tensor does not need to be pre-allocated because it will be created during (1).
+  // 4) Created as an input for embedding model (num_image_tokens = 0)
+  //    The tensor does not need to be pre-allocated because it will be created during (2).
+  if (mode == ImageFeatures::Mode::Output) {
+    image_features_ = OrtValue::CreateTensor(*model_.allocator_device_, shape_, type_);
+  }
+}
+
+void ImageFeatures::Add() {
+  if (mode_ == ImageFeatures::Mode::Input) {
+    // In case the input_features are an input to a model, they are added
+    // as a nullptr to reserve a slot in the inputs. The input_features
+    // input will be overwritten when ReuseImageFeaturesBuffer is invoked.
+    index_ = state_.inputs_.size();
+    state_.inputs_.push_back(nullptr);
+    state_.input_names_.push_back(name_.c_str());
+  } else {
+    index_ = state_.outputs_.size();
+    state_.outputs_.push_back(image_features_.get());
+    state_.output_names_.push_back(name_.c_str());
+  }
+}
+
+void ImageFeatures::Update() {
+  // Initialize empty image_features tensor for no-image or after-prompt input scenarios
+  if (shape_[0] > 0) {  // if num_image_tokens > 0
+    shape_[0] = 0;
+    image_features_ = OrtValue::CreateTensor(*model_.allocator_device_, shape_, type_);
+    state_.inputs_[index_] = image_features_.get();
+  }
+}
+
+void ImageFeatures::ReuseImageFeaturesBuffer(ImageFeatures& other) {
+  if (mode_ == ImageFeatures::Mode::Output || other.mode_ == ImageFeatures::Mode::Input) {
+    throw std::runtime_error("Incorrect usage of the ImageFeatures inputs and outputs.");
+  }
+
+  // Share the output ImageFeatures OrtValue* from other with the input ImageFeatures for this.
+  image_features_ = std::move(other.image_features_);
+  state_.inputs_[index_] = other.state_.outputs_[other.index_];
+}
+
+}  // namespace Generators
diff --git a/src/models/image_features.h b/src/models/image_features.h
new file mode 100644
index 000000000..06bdd9787
--- /dev/null
+++ b/src/models/image_features.h
@@ -0,0 +1,38 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+namespace Generators {
+
+struct ImageFeatures {
+  enum struct Mode {
+    Input = 0,
+    Output
+  };
+
+  ImageFeatures(const Model& model, State& state, ImageFeatures::Mode mode, const std::string& name, int64_t num_image_tokens);
+  ImageFeatures(const ImageFeatures&) = delete;
+  ImageFeatures& operator=(const ImageFeatures&) = delete;
+
+  void Add();
+  void Update();
+  void ReuseImageFeaturesBuffer(ImageFeatures& other);
+
+  auto& GetShape() const { return shape_; }
+  OrtValue* Get() { return image_features_.get(); }
+
+ private:
+  const Model& model_;
+  State& state_;
+  size_t index_{~0U};
+
+  const Mode mode_{};
+  const std::string name_;
+
+  std::array<int64_t, 2> shape_{};  // [num_image_tokens, hidden_size]
+  ONNXTensorElementDataType type_;
+  std::unique_ptr<OrtValue> image_features_;
+};
+
+}  // namespace Generators
diff --git a/src/models/multi_modal_vision_model.cpp b/src/models/multi_modal_vision_model.cpp
index 6237a25c9..d5e371040 100644
--- a/src/models/multi_modal_vision_model.cpp
+++ b/src/models/multi_modal_vision_model.cpp
@@ -12,90 +12,16 @@ RoamingArray<float> MakeDummy() {
   return RoamingArray<float>();
 }
 
-#pragma warning(push)
-#pragma warning(disable : 4189)  // local variable is initialized but not referenced
-
-void Select(const Model& model, std::span<const int32_t> input_ids, OrtValue* hidden_states,
-            OrtValue* visual_features, int32_t num_img_tokens, int32_t hidden_size, DeviceType device_type,
-            cudaStream_t cuda_stream) {
-  // Assme batch_size = 1
-  constexpr int32_t min_input_id = -1000000000;
-  constexpr int64_t expected_batch_size = 1;
-
-  // Find the position in the input_ids that corresponds to the start of the image tokens.
-  // Image tokens are represented by negative values in the input_ids.
-  const int64_t sequence_length = input_ids.size();
-  int32_t image_position_start{};
-  for (int64_t idx = 0; idx < sequence_length; ++idx) {
-    if (input_ids[idx] < 0 && input_ids[idx] > min_input_id) {
-      image_position_start = static_cast<int32_t>(idx);
-      break;
-    }
-  }
-
-  // Replace the positions in the hidden_states tensor that correspond to the image tokens
-  // with the visual features tensor.
-  const int32_t start_pos = image_position_start * hidden_size;
-  const int32_t element_count = num_img_tokens * hidden_size;
-  const int32_t hidden_states_element_count = static_cast<int32_t>(sequence_length) * hidden_size;
-
-  switch (device_type) {
-    case DeviceType::CPU: {
-      auto target = cpu_span<uint16_t>(hidden_states->GetTensorMutableData<uint16_t>(), hidden_states_element_count)
-                        .subspan(start_pos, element_count);
-      auto source = cpu_span<uint16_t>(visual_features->GetTensorMutableData<uint16_t>(), element_count);
-      std::copy(source.begin(), source.end(), target.begin());
-      break;
-    }
-#if USE_CUDA
-    case DeviceType::CUDA: {
-      auto target = gpu_span<uint16_t>(hidden_states->GetTensorMutableData<uint16_t>(), hidden_states_element_count)
-                        .subspan(start_pos, element_count);
-      auto source = gpu_span<uint16_t>(visual_features->GetTensorMutableData<uint16_t>(), element_count);
-      CudaCheck() == cudaMemcpyAsync(target.data(), source.data(), source.size_bytes(),
-                                     cudaMemcpyDeviceToDevice, cuda_stream);
-      break;
-    }
-#endif
-
-#if USE_DML
-    case DeviceType::DML: {
-      ComPtr<ID3D12Resource> source_resource;
-      Ort::ThrowOnError(model.GetOrtDmlApi()->GetD3D12ResourceFromAllocation(model.allocator_device_, visual_features->GetTensorMutableRawData(), &source_resource));
-
-      ComPtr<ID3D12Resource> target_resource;
-      Ort::ThrowOnError(model.GetOrtDmlApi()->GetD3D12ResourceFromAllocation(model.allocator_device_, hidden_states->GetTensorMutableRawData(), &target_resource));
-
-      model.GetDmlExecutionContext()->CopyBufferRegion(
-          target_resource.Get(),
-          start_pos * sizeof(uint16_t),
-          D3D12_RESOURCE_STATE_UNORDERED_ACCESS,
-          source_resource.Get(),
-          0,
-          D3D12_RESOURCE_STATE_UNORDERED_ACCESS,
-          element_count * sizeof(uint16_t));
-
-      // Execute the cached command list
-      ComPtr<ID3D12Fence> fence;
-      uint64_t completion_value;
-      model.GetDmlExecutionContext()->ExecuteCommandList(nullptr, &fence, &completion_value);
-      break;
-    }
-#endif
-    default:
-      throw std::runtime_error("Unsupported device type for Select.");
-  }
-}
-
-#pragma warning(pop)
-
 int64_t GetNumImageTokens(const std::vector<GeneratorParams::Input>& extra_inputs,
+                          const std::string& pixel_values_name,
                           const std::string& image_sizes_name) {
+  std::shared_ptr<Tensor> pixel_values;
   std::shared_ptr<Tensor> image_sizes;
   for (size_t i = 0; i < extra_inputs.size(); ++i) {
-    if (extra_inputs[i].name == image_sizes_name) {
+    if (extra_inputs[i].name == pixel_values_name) {
+      pixel_values = extra_inputs[i].tensor;
+    } else if (extra_inputs[i].name == image_sizes_name) {
       image_sizes = extra_inputs[i].tensor;
-      break;
     }
   }
@@ -104,40 +30,19 @@ int64_t GetNumImageTokens(const std::vector<GeneratorParams::Input>& extra_input
     return 0;
   }
 
-  if (image_sizes->ort_tensor_->GetTensorTypeAndShapeInfo()->GetShape() != std::vector<int64_t>{1, 2}) {
-    throw std::runtime_error("image_sizes tensor must have 2 elements");
+  auto num_images = pixel_values->ort_tensor_->GetTensorTypeAndShapeInfo()->GetShape()[0];
+  if (image_sizes->ort_tensor_->GetTensorTypeAndShapeInfo()->GetShape() != std::vector<int64_t>{num_images, 2}) {
+    throw std::runtime_error("image_sizes tensor must be of shape (num_images, 2)");
   }
 
   auto image_sizes_data = image_sizes->ort_tensor_->GetTensorMutableData<int64_t>();
-  const int64_t h = image_sizes_data[0] / 336;
-  const int64_t w = image_sizes_data[1] / 336;
-  return ((h * w + 1) * 144) + 1 + ((h + 1) * 12);
-}
-
-std::unique_ptr<OrtValue> GetVisualFeatures(OrtAllocator& device_allocator, const SessionInfo& session_info,
-                                            const std::string& visual_features_name, int32_t hidden_size,
-                                            int64_t num_image_tokens) {
-  if (!session_info.HasOutput(visual_features_name)) {
-    throw std::runtime_error("Visual features output not found in the model");
+  int64_t num_image_tokens = 0;
+  for (int i = 0; i < num_images; i++) {
+    int64_t h = image_sizes_data[i * 2] / 336;
+    int64_t w = image_sizes_data[i * 2 + 1] / 336;
+    num_image_tokens += static_cast<int64_t>((h * w + 1) * 144) + 1 + ((h + 1) * 12);
   }
-
-  auto type = session_info.GetOutputDataType(visual_features_name);
-
-  std::vector<int64_t> shape = {num_image_tokens, hidden_size};
-  std::unique_ptr<OrtValue> visual_features;
-
-  switch (type) {
-    case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT:
-      visual_features = OrtValue::CreateTensor<float>(device_allocator, shape);
-      break;
-    case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16:
-      visual_features = OrtValue::CreateTensor<Ort::Float16_t>(device_allocator, shape);
-      break;
-    default:
-      throw std::runtime_error("Unsupported data type for visual features: " + std::to_string(type));
-  }
-
-  return visual_features;
+  return num_image_tokens;
 }
 
 }  // namespace
@@ -164,16 +69,19 @@ std::unique_ptr<State> MultiModalVisionModel::CreateState(RoamingArray<int32_t>
   return std::make_unique<MultiModalPipelineState>(*this, sequence_lengths, params);
 }
 
-EmbeddingState::EmbeddingState(const MultiModalVisionModel& model, const GeneratorParams& params, const CapturedGraphInfo* captured_graph_info)
+EmbeddingState::EmbeddingState(const MultiModalVisionModel& model, const GeneratorParams& params, const CapturedGraphInfo* captured_graph_info, const int64_t num_image_tokens)
     : State{params, model},
       model_{model},
-      captured_graph_info_{captured_graph_info} {
+      captured_graph_info_{captured_graph_info},
+      num_image_tokens_{num_image_tokens} {
   input_ids_.Add();
+  image_features_.Add();
   inputs_embeds_.Add();
 }
 
 void EmbeddingState::UpdateInputsAndOutputs(RoamingArray<int32_t> next_tokens) {
   input_ids_.Update(next_tokens);
+  image_features_.Update();
   inputs_embeds_.UpdateSequenceLength();
 }
 
@@ -184,22 +92,17 @@ RoamingArray<float> EmbeddingState::Run(int current_length, RoamingArray<int32_t
   return MakeDummy();
 }
 
-VisionState::VisionState(const MultiModalVisionModel& model, const GeneratorParams& params)
+VisionState::VisionState(const MultiModalVisionModel& model, const GeneratorParams& params, const int64_t num_image_tokens)
     : State{params, model},
-      model_{model} {
+      model_{model},
+      num_image_tokens_{num_image_tokens} {
   extra_inputs_.Add();
-  num_image_tokens_ = static_cast<int32_t>(GetNumImageTokens(params_->extra_inputs, model_.config_->model.vision.inputs.image_sizes));
-  if (num_image_tokens_ > 0) {
-    visual_features_ = GetVisualFeatures(*model_.allocator_device_, *model_.session_info_,
-                                         model_.config_->model.vision.outputs.visual_features,
-                                         params_->hidden_size, num_image_tokens_);
-    output_names_.push_back(model_.config_->model.vision.outputs.visual_features.c_str());
-    outputs_.push_back(visual_features_.get());
-  }
+  image_features_.Add();
 }
 
 RoamingArray<float> VisionState::Run(int current_length, RoamingArray<int32_t> next_tokens, RoamingArray<int32_t> next_indices) {
-  State::Run(*model_.vision_session_, *model_.run_options_, 1);
+  int num_images = static_cast<int>(inputs_[0]->GetTensorTypeAndShapeInfo()->GetShape()[0]);
+  State::Run(*model_.vision_session_, *model_.run_options_, num_images);
   return MakeDummy();
 }
 
@@ -232,33 +135,29 @@ MultiModalPipelineState::MultiModalPipelineState(const MultiModalVisionModel& mo
                                                  const GeneratorParams& params)
     : State{params, model},
       model_{model},
-      captured_graph_info_{model.GetCapturedGraphPool()->ReserveCapturedGraph(model, params)},
-      embedding_state_{std::make_unique<EmbeddingState>(model, params, captured_graph_info_.get())},
-      vision_state_{std::make_unique<VisionState>(model_, params)},
-      decoder_state_{std::make_unique<DecoderState>(model_, sequence_lengths_unk, params, captured_graph_info_.get())} {
+      num_image_tokens_{GetNumImageTokens(params_->extra_inputs, model_.config_->model.vision.inputs.pixel_values, model_.config_->model.vision.inputs.image_sizes)},
+      captured_graph_info_{model.GetCapturedGraphPool()->ReserveCapturedGraph(model, params)} {
+  embedding_state_ = std::make_unique<EmbeddingState>(model, params, captured_graph_info_.get(), num_image_tokens_);
+  vision_state_ = std::make_unique<VisionState>(model_, params, num_image_tokens_);
+  decoder_state_ = std::make_unique<DecoderState>(model_, sequence_lengths_unk, params, captured_graph_info_.get());
 }
 
 RoamingArray<float> MultiModalPipelineState::Run(int current_length, RoamingArray<int32_t> next_tokens, RoamingArray<int32_t> next_indices) {
   // Pipeline state defines the pipeline of the execution of the models
   // Prompt stage:
-  //   - input_ids -> |embeddings_model| -> |inputs_embeds|
-  //   - pixel_values, img_sizes -> |vision_model| -> |inputs_embeds|
-  //   - inputs_embeds, visual_features -> |Select| -> |inputs_embeds|
-  //   - inputs_embeds -> |decoder_model| -> |logits|
+  //   - pixel_values, image_sizes -> |vision_model| -> image_features
+  //   - input_ids, image_features -> |embeddings_model| -> inputs_embeds
+  //   - inputs_embeds -> |decoder_model| -> logits
   // Generation stage:
-  //   - input_ids -> |embeddings_model| -> |inputs_embeds|
-  //   - inputs_embeds -> |decoder_model| -> |logits|
+  //   - input_ids, image_features -> |embeddings_model| -> inputs_embeds
+  //   - inputs_embeds -> |decoder_model| -> logits
   if (is_prompt_) {
-    embedding_state_->Run(current_length, next_tokens, next_indices);
-    if (vision_state_->num_image_tokens_ > 0) {
+    if (num_image_tokens_ > 0) {
       vision_state_->Run(current_length, next_tokens, next_indices);
-
-      // Run the select logic
-      Select(model_, params_->input_ids, embedding_state_->inputs_embeds_.Get(),
-             vision_state_->visual_features_.get(), vision_state_->num_image_tokens_,
-             params_->hidden_size, params_->device_type, params_->cuda_stream);
     }
+    embedding_state_->image_features_.ReuseImageFeaturesBuffer(vision_state_->image_features_);
+    embedding_state_->Run(current_length, next_tokens, next_indices);
     decoder_state_->inputs_embeds_.ReuseEmbeddingsBuffer(embedding_state_->inputs_embeds_);
     auto logits = decoder_state_->Run(current_length, next_tokens, next_indices);
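
The per-image token count that GetNumImageTokens accumulates above follows the Phi-3 vision preprocessing: with h and w the padded height and width from image_sizes divided by 336, each image contributes (h * w + 1) * 144 + 1 + (h + 1) * 12 feature tokens, and the sum over all rows of image_sizes becomes the first dimension of the image_features tensor. A small sketch of the arithmetic (the input size is illustrative, not taken from the change):

```python
def phi3v_image_tokens(height: int, width: int) -> int:
    # Mirrors the formula in GetNumImageTokens; height/width are the padded
    # image_sizes entries, which are multiples of 336.
    h, w = height // 336, width // 336
    return (h * w + 1) * 144 + 1 + (h + 1) * 12

# e.g. a 1344x1344 image: h = w = 4 -> 17 * 144 + 1 + 60 = 2509 tokens
print(phi3v_image_tokens(1344, 1344))
```
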
diff --git a/src/models/multi_modal_vision_model.h b/src/models/multi_modal_vision_model.h
index 9b3e62646..eb8b47392 100644
--- a/src/models/multi_modal_vision_model.h
+++ b/src/models/multi_modal_vision_model.h
@@ -4,6 +4,7 @@
 #pragma once
 #include "model.h"
 #include "input_ids.h"
+#include "image_features.h"
 #include "embeddings.h"
 #include "extra_inputs.h"
 #include "logits.h"
@@ -20,13 +21,13 @@ struct MultiModalVisionModel : Model {
   std::unique_ptr<State> CreateState(RoamingArray<int32_t> sequence_lengths, const GeneratorParams& params) const override;
 
-  std::unique_ptr<OrtSession> embedding_session_;  // input_ids -> inputs_embeds
-  std::unique_ptr<OrtSession> vision_session_;     // pixel_values, img_sizes -> visual_features
+  std::unique_ptr<OrtSession> vision_session_;     // pixel_values, image_sizes -> image_features
+  std::unique_ptr<OrtSession> embedding_session_;  // input_ids, image_features -> inputs_embeds
   std::unique_ptr<OrtSession> decoder_session_;    // inputs_embeds, attention_mask, kv_cache -> logits
 };
 
 struct EmbeddingState : State {
-  EmbeddingState(const MultiModalVisionModel& model, const GeneratorParams& params, const CapturedGraphInfo* captured_graph_info);
+  EmbeddingState(const MultiModalVisionModel& model, const GeneratorParams& params, const CapturedGraphInfo* captured_graph_info, const int64_t num_image_tokens);
   EmbeddingState(const EmbeddingState&) = delete;
   EmbeddingState& operator=(const EmbeddingState&) = delete;
@@ -42,13 +43,18 @@ struct EmbeddingState : State {
   const MultiModalVisionModel& model_;
   const CapturedGraphInfo* captured_graph_info_;
 
-  InputIDs input_ids_{model_, *this};                                  // Model input
-  Embeddings inputs_embeds_{model_, *this, Embeddings::Mode::Output,  // Model output
+  int64_t num_image_tokens_;
+
+  InputIDs input_ids_{model_, *this};                                        // Model input
+  ImageFeatures image_features_{model_, *this, ImageFeatures::Mode::Input,   // Optional model input
+                                model_.config_->model.embedding.inputs.image_features,
+                                num_image_tokens_};
+  Embeddings inputs_embeds_{model_, *this, Embeddings::Mode::Output,         // Model output
                             model_.config_->model.embedding.outputs.embeddings};
 };
 
 struct VisionState : State {
-  VisionState(const MultiModalVisionModel& model, const GeneratorParams& params);
+  VisionState(const MultiModalVisionModel& model, const GeneratorParams& params, const int64_t num_image_tokens);
   VisionState(const VisionState&) = delete;
   VisionState& operator=(const VisionState&) = delete;
@@ -59,9 +65,11 @@ struct VisionState : State {
   friend struct MultiModalPipelineState;
 
   const MultiModalVisionModel& model_;
-  ExtraInputs extra_inputs_{model_, *this};  // Model inputs
-  std::unique_ptr<OrtValue> visual_features_;  // Model output
-  int32_t num_image_tokens_{};
+  int64_t num_image_tokens_;
+  ExtraInputs extra_inputs_{model_, *this};                                   // Model inputs
+  ImageFeatures image_features_{model_, *this, ImageFeatures::Mode::Output,   // Model output
+                                model_.config_->model.vision.outputs.image_features,
+                                num_image_tokens_};
 };
 
 struct DecoderState : State {
@@ -108,6 +116,7 @@ struct MultiModalPipelineState : State {
   std::unique_ptr<VisionState> vision_state_;
   std::unique_ptr<DecoderState> decoder_state_;
   bool is_prompt_{true};
+  int64_t num_image_tokens_{0};
 };
 
 }  // namespace Generators
diff --git a/src/models/prompt_image_processor.cpp b/src/models/prompt_image_processor.cpp
index 8b791af1e..a56a66362 100644
--- a/src/models/prompt_image_processor.cpp
+++ b/src/models/prompt_image_processor.cpp
@@ -15,7 +15,7 @@ std::unique_ptr<OrtValue> ProcessImagePrompt(const Generators::Tokenizer& tokeni
   const size_t num_images = num_img_tokens ? num_img_tokens->NumberOfElement() : 0U;
   auto* num_img_tokens_data = num_img_tokens ? num_img_tokens->Data() : nullptr;
 
-  // Split the prompt string based on the occurrences of the pattern "<|image_<image_id>|>" 
+  // Split the prompt string based on the occurrences of the pattern "<|image_<image_id>|>"
   // Here the <image_id> represents the image id.
   const std::regex pattern("<\\|image_\\d+\\|>");
   const std::vector<std::string> prompt_chunks(
diff --git a/test/test_models/images/10809054.jpg b/test/test_models/images/10809054.jpg
new file mode 100644
index 000000000..117ca64be
Binary files /dev/null and b/test/test_models/images/10809054.jpg differ