Skip to content

Commit

Permalink
get decoder prompt ids
Browse files Browse the repository at this point in the history
  • Loading branch information
baijumeswani committed Aug 1, 2024
1 parent 6173b39 commit 82b7148
Show file tree
Hide file tree
Showing 10 changed files with 66 additions and 16 deletions.
4 changes: 1 addition & 3 deletions examples/c/src/whisper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,16 +33,14 @@ void CXX_API(const char* model_path, int32_t num_beams) {
std::unique_ptr<OgaAudios> audios = OgaAudios::Load(audio_path.c_str());

std::cout << "Processing audio..." << std::endl;
auto input_tensors = processor->ProcessAudios(audios.get());
auto input_tensors = processor->ProcessAudios(audios.get(), "english", "transcribe", 1);

std::cout << "Generating response..." << std::endl;
auto params = OgaGeneratorParams::Create(*model);
params->SetSearchOption("max_length", 256);
params->SetSearchOption("num_beams", num_beams);
params->SetSearchOption("num_return_sequences", 4);
params->SetInputs(*input_tensors);
const std::array<int32_t, 3> input_ids = {50258, 50259, 50359};
params->SetInputIDs(input_ids.data(), input_ids.size(), input_ids.size(), 1);

auto generator = OgaGenerator::Create(*model, *params);

Expand Down
5 changes: 1 addition & 4 deletions examples/python/whisper.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def run(args: argparse.Namespace):
audio = og.Audios.open(audio_path)

print("Processing audio...")
inputs = processor(audios=audio)
inputs = processor(audios=audio, lang="en", task="transcribe")

params = og.GeneratorParams(model)
params.set_search_options(
Expand All @@ -50,9 +50,6 @@ def run(args: argparse.Namespace):

batch_size = 1
params.set_inputs(inputs)
params.input_ids = np.array(
[[50258, 50259, 50359]] * batch_size, dtype=np.int32
)

generator = og.Generator(model, params)

Expand Down
16 changes: 15 additions & 1 deletion src/models/audio_processor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,9 @@ AudioProcessor::AudioProcessor(Config& config, const SessionInfo& session_info)
config.AddMapping(std::string(Config::Defaults::InputFeaturesName), config.model.encoder_decoder_init.inputs.input_features);
}

std::unique_ptr<NamedTensors> AudioProcessor::Process(const Audios* audios) const {
std::unique_ptr<NamedTensors> AudioProcessor::Process(const Tokenizer& tokenizer, const Audios* audios,
const std::string& language, const std::string& task,
int32_t no_timestamps) const {
if (!audios || !audios->audios_) {
throw std::runtime_error("No audios provided to process.");
}
Expand All @@ -77,6 +79,18 @@ std::unique_ptr<NamedTensors> AudioProcessor::Process(const Audios* audios) cons
named_tensors->emplace(std::string(Config::Defaults::InputFeaturesName),
std::make_shared<Tensor>(ProcessMel(mel, input_features_type_, allocator)));

  // TODO: This needs the start of transcription token to be added to the prompt.
// It makes sense to add that to the tokenizer if possible.
const auto prompt_token_ids = tokenizer.GetDecoderPromptIds(audios->num_audios_, language, task, no_timestamps);

const std::array<int64_t, 2> shape{static_cast<int64_t>(audios->num_audios_),
static_cast<int64_t>(prompt_token_ids.size())};
auto decoder_input_ids = OrtValue::CreateTensor<int32_t>(allocator, shape);
std::copy(prompt_token_ids.begin(), prompt_token_ids.end(), decoder_input_ids->GetTensorMutableData<int32_t>());

named_tensors->emplace(std::string(Config::Defaults::InputIdsName),
std::make_shared<Tensor>(std::move(decoder_input_ids)));

return named_tensors;
}

Expand Down
4 changes: 3 additions & 1 deletion src/models/audio_processor.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,9 @@ struct AudioProcessor {
AudioProcessor(const AudioProcessor&) = delete;
AudioProcessor& operator=(const AudioProcessor&) = delete;

std::unique_ptr<NamedTensors> Process(const Audios* audios) const;
std::unique_ptr<NamedTensors> Process(const Tokenizer& tokenizer, const Audios* audios,
const std::string& language, const std::string& task,
int32_t no_timestamps) const;

private:
ort_extensions::OrtxObjectPtr<OrtxFeatureExtractor> processor_;
Expand Down
19 changes: 19 additions & 0 deletions src/models/model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,25 @@ std::vector<std::string> Tokenizer::DecodeBatch(std::span<const int32_t> sequenc
return strings;
}

std::vector<int32_t> Tokenizer::GetDecoderPromptIds(size_t batch_size, const std::string& language,
const std::string& task, int32_t no_timestamps) const {
ort_extensions::OrtxObjectPtr<OrtxTokenId2DArray> prompt_ids;
CheckResult(OrtxGetDecoderPromptIds(tokenizer_, batch_size, language.c_str(),
task.c_str(), no_timestamps, ort_extensions::ptr(prompt_ids)));

std::vector<std::vector<int32_t>> tokens_vector;
std::vector<std::span<const int32_t>> span_sequences;
for (size_t i = 0; i < batch_size; i++) {
const extTokenId_t* tokens = nullptr;
size_t token_count = 0;
CheckResult(OrtxTokenId2DArrayGetItem(prompt_ids.get(), i, &tokens, &token_count));
tokens_vector.emplace_back(tokens, tokens + token_count);
span_sequences.emplace_back(tokens_vector.back());
}

return PadInputs(span_sequences, pad_token_id_);
}

#if USE_CUDA
// Since Python/Others can and will hold onto a generator object past the model object's lifetime we need to ensure
// the allocator used is not destroyed until last. This keeps the allocator around until exit, after all other memory
Expand Down
3 changes: 3 additions & 0 deletions src/models/model.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,9 @@ struct Tokenizer : std::enable_shared_from_this<Tokenizer> {
std::vector<int32_t> EncodeBatch(std::span<const std::string> strings) const;
std::vector<std::string> DecodeBatch(std::span<const int32_t> sequences, size_t count) const;

std::vector<int32_t> GetDecoderPromptIds(size_t batch_size, const std::string& language,
const std::string& task, int32_t no_timestamps) const;

OrtxPtr<OrtxTokenizer> tokenizer_;
std::shared_ptr<Tokenizer> external_owner_; // Set to 'this' when created by the C API to preserve lifetime

Expand Down
5 changes: 3 additions & 2 deletions src/ort_genai.h
Original file line number Diff line number Diff line change
Expand Up @@ -315,9 +315,10 @@ struct OgaMultiModalProcessor : OgaAbstract {
return std::unique_ptr<OgaNamedTensors>(p);
}

std::unique_ptr<OgaNamedTensors> ProcessAudios(const OgaAudios* audios) const {
std::unique_ptr<OgaNamedTensors> ProcessAudios(const OgaAudios* audios, const std::string& language,
const std::string& task, int32_t no_timestamps) const {
OgaNamedTensors* p;
OgaCheckResult(OgaProcessorProcessAudios(this, audios, &p));
OgaCheckResult(OgaProcessorProcessAudios(this, audios, language.c_str(), task.c_str(), no_timestamps, &p));
return std::unique_ptr<OgaNamedTensors>(p);
}

Expand Down
6 changes: 4 additions & 2 deletions src/ort_genai_c.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -386,15 +386,17 @@ OgaResult* OGA_API_CALL OgaProcessorProcessImages(const OgaMultiModalProcessor*
OGA_CATCH
}

OgaResult* OGA_API_CALL OgaProcessorProcessAudios(const OgaMultiModalProcessor* p, const OgaAudios* audios_p, OgaNamedTensors** input_tensors) {
OgaResult* OGA_API_CALL OgaProcessorProcessAudios(const OgaMultiModalProcessor* p, const OgaAudios* audios_p, const char* language,
const char* task, int32_t no_timestamp, OgaNamedTensors** input_tensors) {
OGA_TRY
auto& processor = *reinterpret_cast<const Generators::MultiModalProcessor*>(p);
auto* audios = reinterpret_cast<const Generators::Audios*>(audios_p);

if (!processor.audio_processor_)
throw std::runtime_error("Audio processor not available for this model.");

auto named_tensors = processor.audio_processor_->Process(audios);
auto named_tensors = processor.audio_processor_->Process(*processor.tokenizer_, audios,
language, task, no_timestamp);
*input_tensors = reinterpret_cast<OgaNamedTensors*>(named_tensors.release());

return nullptr;
Expand Down
2 changes: 1 addition & 1 deletion src/ort_genai_c.h
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,7 @@ OGA_EXPORT OgaResult* OGA_API_CALL OgaTokenizerEncode(const OgaTokenizer*, const

OGA_EXPORT OgaResult* OGA_API_CALL OgaProcessorProcessImages(const OgaMultiModalProcessor*, const char* prompt, const OgaImages* images, OgaNamedTensors** input_tensors);

OGA_EXPORT OgaResult* OGA_API_CALL OgaProcessorProcessAudios(const OgaMultiModalProcessor*, const OgaAudios* audios, OgaNamedTensors** input_tensors);
OGA_EXPORT OgaResult* OGA_API_CALL OgaProcessorProcessAudios(const OgaMultiModalProcessor*, const OgaAudios* audios, const char* language, const char* task, int32_t no_timestamp, OgaNamedTensors** input_tensors);

/* Decode a single token sequence and returns a null terminated utf8 string. out_string must be freed with OgaDestroyString
*/
Expand Down
18 changes: 16 additions & 2 deletions src/python/python.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -483,10 +483,24 @@ PYBIND11_MODULE(onnxruntime_genai, m) {
if (!prompt.has_value()) {
throw std::runtime_error("Prompt is required for processing the image.");
}
return std::make_unique<PyNamedTensors>(processor.image_processor_->Process(*processor.tokenizer_, *prompt, images));
return std::make_unique<PyNamedTensors>(
processor.image_processor_->Process(*processor.tokenizer_, *prompt, images));
} else if (kwargs.contains("audios")) {
const Audios* audios = kwargs["audios"].cast<const Audios*>();
return std::make_unique<PyNamedTensors>(processor.audio_processor_->Process(audios));
std::string language = "en";
if (kwargs.contains("language")) {
language = kwargs["lang"].cast<std::string>();
}
std::string task = "transcribe";
if (kwargs.contains("task")) {
task = kwargs["task"].cast<std::string>();
}
int32_t no_timestamps = 1;
if (kwargs.contains("no_timestamps")) {
no_timestamps = kwargs["no_timestamps"].cast<int32_t>();
}
return std::make_unique<PyNamedTensors>(
processor.audio_processor_->Process(*processor.tokenizer_, audios, language, task, no_timestamps));
} else {
throw std::runtime_error("Nothing to process.");
}
Expand Down

0 comments on commit 82b7148

Please sign in to comment.