Skip to content

Commit

Permalink
Add support for Phi-3.5 vision
Browse files Browse the repository at this point in the history
  • Loading branch information
kunal-vaishnavi committed Sep 6, 2024
1 parent bb0affc commit ccc5c12
Show file tree
Hide file tree
Showing 10 changed files with 172 additions and 160 deletions.
4 changes: 2 additions & 2 deletions examples/python/phi3v.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@

import onnxruntime_genai as og


def _complete(text, state):
return (glob.glob(text + "*") + [None])[state]

Expand All @@ -29,9 +28,10 @@ def run(args: argparse.Namespace):
"Image Path (comma separated; leave empty if no image): "
).split(",")
]
image_paths = [image_path for image_path in image_paths if len(image_path)]
print(image_paths)

image = None
images = None
prompt = "<|user|>\n"
if len(image_paths) == 0:
print("No image provided")
Expand Down
10 changes: 4 additions & 6 deletions src/config.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -137,10 +137,6 @@ struct Inputs_Element : JSON::Element {
v_.position_ids = value;
} else if (name == "attention_mask") {
v_.attention_mask = value;
} else if (name == "seqlens_k") {
v_.seqlens_k = value;
} else if (name == "total_seq_len") {
v_.total_sequence_length = value;
} else if (name == "past_key_names") {
v_.past_key_names = value;
} else if (name == "past_value_names") {
Expand Down Expand Up @@ -248,8 +244,8 @@ struct VisionOutputs_Element : JSON::Element {
explicit VisionOutputs_Element(Config::Model::Vision::Outputs& v) : v_{v} {}

void OnString(std::string_view name, std::string_view value) override {
if (name == "visual_features") {
v_.visual_features = value;
if (name == "image_features") {
v_.image_features = value;
} else
throw JSON::unknown_value_error{};
}
Expand Down Expand Up @@ -312,6 +308,8 @@ struct EmbeddingInputs_Element : JSON::Element {
void OnString(std::string_view name, std::string_view value) override {
if (name == "input_ids") {
v_.input_ids = value;
} else if (name == "image_features") {
v_.image_features = value;
} else
throw JSON::unknown_value_error{};
}
Expand Down
6 changes: 3 additions & 3 deletions src/config.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ struct Config {
static constexpr std::string_view InputIdsName = "input_ids";
static constexpr std::string_view PixelValuesName = "pixel_values";
static constexpr std::string_view ImageSizesName = "image_sizes";
static constexpr std::string_view ImageFeaturesName = "image_features";
};

fs::path config_path; // Path of the config directory
Expand Down Expand Up @@ -62,6 +63,7 @@ struct Config {

struct Inputs {
std::string input_ids{Defaults::InputIdsName};
std::string image_features{Defaults::ImageFeaturesName};
} inputs;

struct Outputs {
Expand All @@ -78,7 +80,7 @@ struct Config {
} inputs;

struct Outputs {
std::string visual_features{"visual_features"};
std::string image_features{Defaults::ImageFeaturesName};
} outputs;
} vision;

Expand All @@ -97,8 +99,6 @@ struct Config {
std::string embeddings{"inputs_embeds"};
std::string position_ids{"position_ids"};
std::string attention_mask{"attention_mask"};
std::string seqlens_k{"seqlens_k"};
std::string total_sequence_length{"total_seq_len"};
std::string past_key_names{"past_key_values.%d.key"}, past_value_names{"past_key_values.%d.value"};
std::string past_names; // When key/value pairs are combined
std::string cross_past_key_names, cross_past_value_names;
Expand Down
1 change: 1 addition & 0 deletions src/models/embeddings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ Embeddings::Embeddings(const Model& model, State& state, Embeddings::Mode mode,
: model_.session_info_->GetOutputDataType(name)},
mode_{mode},
name_{name} {

// Embeddings are only transient inputs and outputs.
// They are never the user provided/requested model inputs/outputs
// So only create the transient output and reuse that ortvalue for subsequent
Expand Down
67 changes: 67 additions & 0 deletions src/models/image_features.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
#include "../generators.h"
#include "model.h"
#include "image_features.h"

namespace Generators {

ImageFeatures::ImageFeatures(const Model& model, State& state, ImageFeatures::Mode mode, const std::string& name, int64_t num_image_tokens)
: model_{model},
state_{state},
shape_{num_image_tokens, state_.params_->hidden_size},
type_{mode == ImageFeatures::Mode::Input

Check failure on line 11 in src/models/image_features.cpp

View workflow job for this annotation

GitHub Actions / windows-cuda-x64-build

the following warning is treated as an error

Check failure on line 11 in src/models/image_features.cpp

View workflow job for this annotation

GitHub Actions / windows-cpu-x64-build

the following warning is treated as an error

Check warning on line 11 in src/models/image_features.cpp

View workflow job for this annotation

GitHub Actions / windows-cpu-x64-build

data member 'Generators::ImageFeatures::type_' will be initialized after data member 'Generators::ImageFeatures::mode_'
? model_.session_info_->GetInputDataType(name)
: model_.session_info_->GetOutputDataType(name)},
mode_{mode},
name_{name} {

// There are four cases for ImageFeatures:
// 1) Created as an output for vision model (num_image_tokens > 0)
// The tensor needs to be pre-allocated to store the output.
// It will be transferred to an input for the embedding model.
// 2) Created as an output for vision model (num_image_tokens = 0)
// The tensor will be pre-allocated to store the empty output.
// It will be transferred to an input for the embedding model.
// 3) Created as an input for embedding model (num_image_tokens > 0)
// The tensor does not need to be pre-allocated because it will be created during (1).
// 4) Created as an input for embedding model (num_image_tokens = 0)
// The tensor does not need to be pre-allocated because it will be created during (2).
if (mode == ImageFeatures::Mode::Output) {
image_features_ = OrtValue::CreateTensor(*model_.allocator_device_, shape_, type_);
}
}

void ImageFeatures::Add() {
if (mode_ == ImageFeatures::Mode::Input) {
// In case the input_features are an input to a model, they are added
// as a nullptr to reserve a slot in the inputs. The input_features
// input will be overwritten when ReuseImageFeaturesBuffer is invoked.
index_ = state_.inputs_.size();
state_.inputs_.push_back(nullptr);
state_.input_names_.push_back(name_.c_str());
} else {
index_ = state_.outputs_.size();
state_.outputs_.push_back(image_features_.get());
state_.output_names_.push_back(name_.c_str());
}
}

void ImageFeatures::Update() {
// Initialize empty image_features tensor for no-image or after-prompt input scenarios
if (shape_[0] > 0) { // if num_image_tokens > 0
shape_[0] = 0;
image_features_ = OrtValue::CreateTensor(*model_.allocator_device_, shape_, type_);
state_.inputs_[index_] = image_features_.get();
}
}

void ImageFeatures::ReuseImageFeaturesBuffer(ImageFeatures& other) {
if (mode_ == ImageFeatures::Mode::Output || other.mode_ == ImageFeatures::Mode::Input) {
throw std::runtime_error("Incorrect usage of the ImageFeatures inputs and outputs.");
}

// Share the output ImageFeatures OrtValue* from other with the input ImageFeatures for this.
image_features_ = std::move(other.image_features_);
state_.inputs_[index_] = other.state_.outputs_[other.index_];
}

} // namespace Generators
38 changes: 38 additions & 0 deletions src/models/image_features.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

namespace Generators {

struct ImageFeatures {
enum struct Mode {
Input = 0,
Output
};

ImageFeatures(const Model& model, State& state, ImageFeatures::Mode mode, const std::string& name, int64_t num_image_tokens);
ImageFeatures(const ImageFeatures&) = delete;
ImageFeatures& operator=(const ImageFeatures&) = delete;

void Add();
void Update();
void ReuseImageFeaturesBuffer(ImageFeatures& other);

auto& GetShape() const { return shape_; }
OrtValue* Get() { return image_features_.get(); }

private:
const Model& model_;
State& state_;
size_t index_{~0U};

const Mode mode_{};
const std::string name_;

std::array<int64_t, 2> shape_{}; // [num_image_tokens, hidden_size]
ONNXTensorElementDataType type_;
std::unique_ptr<OrtValue> image_features_;
};

} // namespace Generators
Loading

0 comments on commit ccc5c12

Please sign in to comment.