-
Notifications
You must be signed in to change notification settings - Fork 110
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
bb0affc
commit ccc5c12
Showing
10 changed files
with
172 additions
and
160 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
#include "../generators.h" | ||
#include "model.h" | ||
#include "image_features.h" | ||
|
||
namespace Generators { | ||
|
||
// Builds an ImageFeatures wrapper bound to `model` and `state`. | ||
// - shape_ is initialized to [num_image_tokens, state_.params_->hidden_size]. | ||
// - type_ comes from model_.session_info_: GetInputDataType(name) when `mode` is | ||
//   Mode::Input, otherwise GetOutputDataType(name). | ||
// - The tensor itself is only created here for Mode::Output (see the cases below). | ||
// NOTE(review): the initializer list names shape_/type_ before mode_/name_, while the | ||
// struct declares mode_/name_ first. Members are initialized in declaration order; the | ||
// initializers here read only the constructor parameters, model_ and state_, so the | ||
// result looks unaffected, but some compilers warn about the mismatch -- confirm. | ||
ImageFeatures::ImageFeatures(const Model& model, State& state, ImageFeatures::Mode mode, const std::string& name, int64_t num_image_tokens) | ||
: model_{model}, | ||
state_{state}, | ||
shape_{num_image_tokens, state_.params_->hidden_size}, | ||
type_{mode == ImageFeatures::Mode::Input | ||
Check failure on line 11 in src/models/image_features.cpp GitHub Actions / windows-cuda-x64-build
Check failure on line 11 in src/models/image_features.cpp GitHub Actions / windows-cpu-x64-build
|
||
? model_.session_info_->GetInputDataType(name) | ||
: model_.session_info_->GetOutputDataType(name)}, | ||
mode_{mode}, | ||
name_{name} { | ||
|
||
// There are four cases for ImageFeatures: | ||
// 1) Created as an output for vision model (num_image_tokens > 0) | ||
// The tensor needs to be pre-allocated to store the output. | ||
// It will be transferred to an input for the embedding model. | ||
// 2) Created as an output for vision model (num_image_tokens = 0) | ||
// The tensor will be pre-allocated to store the empty output. | ||
// It will be transferred to an input for the embedding model. | ||
// 3) Created as an input for embedding model (num_image_tokens > 0) | ||
// The tensor does not need to be pre-allocated because it will be created during (1). | ||
// 4) Created as an input for embedding model (num_image_tokens = 0) | ||
// The tensor does not need to be pre-allocated because it will be created during (2). | ||
// Only the output side owns a tensor at construction time; in Input mode | ||
// image_features_ stays null until ReuseImageFeaturesBuffer() or Update() sets it. | ||
if (mode == ImageFeatures::Mode::Output) { | ||
image_features_ = OrtValue::CreateTensor(*model_.allocator_device_, shape_, type_); | ||
} | ||
} | ||
|
||
// Registers this object's slot and name with the owning State and records the | ||
// slot position in index_. | ||
// - Input mode: appends a nullptr placeholder to state_.inputs_ and name_ to | ||
//   state_.input_names_. | ||
// - Output mode: appends the pre-allocated tensor to state_.outputs_ and name_ to | ||
//   state_.output_names_. | ||
// The name vectors receive name_.c_str(), i.e. a pointer into this object's name_ | ||
// member, so this object has to stay alive while those vectors are used. | ||
void ImageFeatures::Add() { | ||
if (mode_ == ImageFeatures::Mode::Input) { | ||
// When the image_features are an input to a model, they are added | ||
// as a nullptr to reserve a slot in the inputs. The image_features | ||
// input will be overwritten when ReuseImageFeaturesBuffer is invoked. | ||
index_ = state_.inputs_.size(); | ||
state_.inputs_.push_back(nullptr); | ||
state_.input_names_.push_back(name_.c_str()); | ||
} else { | ||
// Output mode: image_features_ was created in the constructor. | ||
index_ = state_.outputs_.size(); | ||
state_.outputs_.push_back(image_features_.get()); | ||
state_.output_names_.push_back(name_.c_str()); | ||
} | ||
} | ||
|
||
// Swaps in an empty tensor while the recorded shape still reports image tokens. | ||
// When shape_[0] > 0 this sets shape_[0] to 0, creates a new tensor with that shape via | ||
// OrtValue::CreateTensor (replacing whatever image_features_ held), and stores it in | ||
// state_.inputs_[index_]. Once shape_[0] is 0 the guard fails, so later calls do nothing. | ||
// NOTE(review): this writes to state_.inputs_ regardless of mode_ and uses index_ | ||
// unchecked -- presumably it is only called on Input-mode objects after Add(); confirm | ||
// against the callers. | ||
void ImageFeatures::Update() { | ||
// Initialize empty image_features tensor for no-image or after-prompt input scenarios | ||
if (shape_[0] > 0) { // if num_image_tokens > 0 | ||
shape_[0] = 0; | ||
image_features_ = OrtValue::CreateTensor(*model_.allocator_device_, shape_, type_); | ||
state_.inputs_[index_] = image_features_.get(); | ||
} | ||
} | ||
|
||
// Hands the tensor of an Output-mode ImageFeatures (`other`) to this Input-mode object. | ||
// Throws std::runtime_error if this object is in Output mode or `other` is in Input mode. | ||
// After the call this object owns the tensor and other.image_features_ is empty | ||
// (unique_ptr move); other.state_.outputs_ is left untouched. | ||
// NOTE(review): index_ and other.index_ are used unchecked -- presumably Add() has been | ||
// called on both objects beforehand; confirm against the callers. | ||
void ImageFeatures::ReuseImageFeaturesBuffer(ImageFeatures& other) { | ||
if (mode_ == ImageFeatures::Mode::Output || other.mode_ == ImageFeatures::Mode::Input) { | ||
throw std::runtime_error("Incorrect usage of the ImageFeatures inputs and outputs."); | ||
} | ||
|
||
// Share the output ImageFeatures OrtValue* from other with the input ImageFeatures for this. | ||
// The raw pointer is read back from other's output slot rather than from | ||
// other.image_features_, which is empty once the move below has happened. | ||
image_features_ = std::move(other.image_features_); | ||
state_.inputs_[index_] = other.state_.outputs_[other.index_]; | ||
} | ||
|
||
} // namespace Generators |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
// Copyright (c) Microsoft Corporation. All rights reserved. | ||
// Licensed under the MIT License. | ||
|
||
#pragma once | ||
|
||
namespace Generators { | ||
|
||
// Wraps the image-features tensor of a model together with the State slot it occupies. | ||
// An instance is either an input or an output (see Mode); copying is disabled. | ||
// NOTE(review): the visible part of this header declares no includes, yet uses Model, | ||
// State, OrtValue, std::string, std::array and std::unique_ptr -- it appears to rely | ||
// on the including file to provide them first; confirm. | ||
struct ImageFeatures { | ||
// Whether the tensor is fed into a model (Input) or produced by one (Output). | ||
enum struct Mode { | ||
Input = 0, | ||
Output | ||
}; | ||
|
||
// `name` is the input/output name registered with the State in Add(); | ||
// `num_image_tokens` becomes the first dimension of the shape. | ||
ImageFeatures(const Model& model, State& state, ImageFeatures::Mode mode, const std::string& name, int64_t num_image_tokens); | ||
ImageFeatures(const ImageFeatures&) = delete; | ||
ImageFeatures& operator=(const ImageFeatures&) = delete; | ||
|
||
// Appends this object's slot and name to the State's inputs or outputs, depending on the mode. | ||
void Add(); | ||
// Replaces the tensor with an empty one if the shape still reports image tokens. | ||
void Update(); | ||
// Takes over the tensor of an Output-mode `other`; throws std::runtime_error on a mode mismatch. | ||
void ReuseImageFeaturesBuffer(ImageFeatures& other); | ||
|
||
// Returns a const reference to the current [num_image_tokens, hidden_size] shape. | ||
auto& GetShape() const { return shape_; } | ||
// Returns the held tensor without transferring ownership; null when no tensor is held. | ||
OrtValue* Get() { return image_features_.get(); } | ||
|
||
private: | ||
const Model& model_; | ||
State& state_; | ||
// Position of this object's slot in state_.inputs_ or state_.outputs_, set by Add(). | ||
// NOTE(review): ~0U is an unsigned int; where size_t is wider this is not the | ||
// largest size_t value. Nothing in the visible code compares against it -- confirm | ||
// whether it is meant as a "not added yet" sentinel. | ||
size_t index_{~0U}; | ||
|
||
const Mode mode_{}; | ||
const std::string name_; | ||
|
||
std::array<int64_t, 2> shape_{}; // [num_image_tokens, hidden_size] | ||
ONNXTensorElementDataType type_; // element type looked up from the session info by name | ||
std::unique_ptr<OrtValue> image_features_; // owned tensor; may be null | ||
}; | ||
|
||
} // namespace Generators |
Oops, something went wrong.