Skip to content

Commit

Permalink
Support loading multiple audio files
Browse files Browse the repository at this point in the history
  • Loading branch information
baijumeswani committed Sep 11, 2024
1 parent f937c15 commit e06c46f
Show file tree
Hide file tree
Showing 5 changed files with 103 additions and 35 deletions.
1 change: 1 addition & 0 deletions examples/c/src/phi3v.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,7 @@ void C_API(const char* model_path) {
OgaStringArray* image_paths_string_array;
CheckResult(OgaCreateStringArrayFromStrings(image_paths_c.data(), image_paths_c.size(), &image_paths_string_array));
CheckResult(OgaLoadImages(image_paths_string_array, &images));
OgaDestroyStringArray(image_paths_string_array);
}

std::string text;
Expand Down
107 changes: 74 additions & 33 deletions examples/c/src/whisper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,15 @@ bool FileExists(const char* path) {
return static_cast<bool>(std::ifstream(path));
}

std::string trim(const std::string& str) {
const size_t first = str.find_first_not_of(' ');
if (std::string::npos == first) {
return str;
}
const size_t last = str.find_last_not_of(' ');
return str.substr(first, (last - first + 1));
}

// C++ API Example

void CXX_API(const char* model_path, int32_t num_beams) {
Expand All @@ -23,24 +32,39 @@ void CXX_API(const char* model_path, int32_t num_beams) {
auto tokenizer = OgaTokenizer::Create(*model);

while (true) {
std::string audio_path;
std::cout << "Audio Path:" << std::endl;
std::getline(std::cin, audio_path);
std::cout << "Loading audio..." << std::endl;
if (audio_path.empty()) {
throw std::runtime_error("Audio file not provided.");
} else if (!FileExists(audio_path.c_str())) {
throw std::runtime_error(std::string("Audio file not found: ") + audio_path);
std::string audio_paths_str;
std::cout << "Audio Paths (comma separated):" << std::endl;
std::getline(std::cin, audio_paths_str);
std::unique_ptr<OgaAudios> audios;
std::vector<std::string> audio_paths;
for (size_t start = 0, end = 0; end < audio_paths_str.size(); start = end + 1) {
end = audio_paths_str.find(',', start);
audio_paths.push_back(trim(audio_paths_str.substr(start, end - start)));
}
if (audio_paths.empty()) {
throw std::runtime_error("No audio file provided.");
} else {
std::cout << "Loading audios..." << std::endl;
for (const auto& audio_path : audio_paths) {
if (!FileExists(audio_path.c_str())) {
throw std::runtime_error(std::string("Audio file not found: ") + audio_path);
}
}
std::vector<const char*> audio_paths_c;
for (const auto& audio_path : audio_paths) audio_paths_c.push_back(audio_path.c_str());
audios = OgaAudios::Load(audio_paths_c);
}
std::unique_ptr<OgaAudios> audios = OgaAudios::Load(audio_path.c_str());

std::cout << "Processing audio..." << std::endl;
auto mel = processor->ProcessAudios(audios.get());
const std::array<const char*, 4> prompt_tokens = {"<|startoftranscript|>", "<|en|>", "<|transcribe|>",
"<|notimestamps|>"};
auto input_ids = OgaSequences::Create();
for (const auto& token : prompt_tokens) {
tokenizer->ToTokenId(token, *input_ids, 0);
const size_t batch_size = audio_paths.size();
for (size_t i = 0; i < batch_size; ++i) {
for (const auto& token : prompt_tokens) {
tokenizer->ToTokenId(token, *input_ids, i);
}
}

std::cout << "Generating response..." << std::endl;
Expand All @@ -59,11 +83,11 @@ void CXX_API(const char* model_path, int32_t num_beams) {
generator->GenerateNextToken();
}

std::cout << "Transcription:" << std::endl;
for (size_t beam = 0; beam < static_cast<size_t>(num_beams); ++beam) {
std::cout << " Beam " << beam << ":";
const auto num_tokens = generator->GetSequenceCount(beam);
const auto tokens = generator->GetSequenceData(beam);
for (size_t i = 0; i < static_cast<size_t>(num_beams * batch_size); ++i) {
std::cout << "Transcription:" << std::endl;
std::cout << " batch " << i / num_beams << ", beam " << i % num_beams << ":";
const auto num_tokens = generator->GetSequenceCount(i);
const auto tokens = generator->GetSequenceData(i);
std::cout << processor->Decode(tokens, num_tokens) << std::endl;
}

Expand Down Expand Up @@ -95,17 +119,31 @@ void C_API(const char* model_path, int32_t num_beams) {
CheckResult(OgaCreateTokenizer(model, &tokenizer));

while (true) {
std::string audio_path;
std::cout << "Audio Path:" << std::endl;
std::getline(std::cin, audio_path);
std::cout << "Loading audio..." << std::endl;
if (audio_path.empty()) {
throw std::runtime_error("Audio file not provided.");
} else if (!FileExists(audio_path.c_str())) {
throw std::runtime_error(std::string("Audio file not found: ") + audio_path);
}
std::string audio_paths_str;
std::cout << "Audio Paths (comma separated):" << std::endl;
std::getline(std::cin, audio_paths_str);
OgaAudios* audios = nullptr;
CheckResult(OgaLoadAudio(audio_path.c_str(), &audios));
std::vector<std::string> audio_paths;
for (size_t start = 0, end = 0; end < audio_paths_str.size(); start = end + 1) {
end = audio_paths_str.find(',', start);
audio_paths.push_back(trim(audio_paths_str.substr(start, end - start)));
}
if (audio_paths.empty()) {
throw std::runtime_error("No audio file provided.");
} else {
std::cout << "Loading audios..." << std::endl;
for (const auto& audio_path : audio_paths) {
if (!FileExists(audio_path.c_str())) {
throw std::runtime_error(std::string("Audio file not found: ") + audio_path);
}
std::vector<const char*> audio_paths_c;
for (const auto& audio_path : audio_paths) audio_paths_c.push_back(audio_path.c_str());
OgaStringArray* audio_paths_string_array;
CheckResult(OgaCreateStringArrayFromStrings(audio_paths_c.data(), audio_paths_c.size(), &audio_paths_string_array));
CheckResult(OgaLoadAudios(audio_paths_string_array, &audios));
OgaDestroyStringArray(audio_paths_string_array);
}
}

std::cout << "Processing audio..." << std::endl;
OgaNamedTensors* mel;
Expand All @@ -114,8 +152,11 @@ void C_API(const char* model_path, int32_t num_beams) {
"<|notimestamps|>"};
OgaSequences* input_ids;
CheckResult(OgaCreateSequences(&input_ids));
for (const auto& token : prompt_tokens) {
CheckResult(OgaTokenizerToTokenId(tokenizer, token, input_ids, 0));
const size_t batch_size = audio_paths.size();
for (size_t i = 0; i < batch_size; ++i) {
for (const auto& token : prompt_tokens) {
CheckResult(OgaTokenizerToTokenId(tokenizer, token, input_ids, i));
}
}

std::cout << "Generating response..." << std::endl;
Expand All @@ -136,11 +177,11 @@ void C_API(const char* model_path, int32_t num_beams) {
CheckResult(OgaGenerator_GenerateNextToken(generator));
}

std::cout << "Transcription:" << std::endl;
for (size_t beam = 0; beam < static_cast<size_t>(num_beams); ++beam) {
std::cout << " Beam " << beam << ":";
const int32_t num_tokens = OgaGenerator_GetSequenceCount(generator, beam);
const int32_t* tokens = OgaGenerator_GetSequenceData(generator, beam);
for (size_t i = 0; i < static_cast<size_t>(num_beams * batch_size); ++i) {
std::cout << "Transcription:" << std::endl;
std::cout << " batch " << i / num_beams << ", beam " << i % num_beams << ":";
const int32_t num_tokens = OgaGenerator_GetSequenceCount(generator, i);
const int32_t* tokens = OgaGenerator_GetSequenceData(generator, i);

const char* str;
CheckResult(OgaProcessorDecode(processor, tokens, num_tokens, &str));
Expand Down
18 changes: 16 additions & 2 deletions src/ort_genai.h
Original file line number Diff line number Diff line change
Expand Up @@ -322,11 +322,25 @@ struct OgaImages : OgaAbstract {
};

struct OgaAudios : OgaAbstract {
static std::unique_ptr<OgaAudios> Load(const char* audio_path) {
static std::unique_ptr<OgaAudios> Load(const std::vector<const char*>& audio_paths) {
OgaAudios* p;
OgaCheckResult(OgaLoadAudio(audio_path, &p));
OgaStringArray* strs;
OgaCheckResult(OgaCreateStringArrayFromStrings(audio_paths.data(), audio_paths.size(), &strs));
OgaCheckResult(OgaLoadAudios(strs, &p));
OgaDestroyStringArray(strs);
return std::unique_ptr<OgaAudios>(p);
}

#if __cplusplus >= 202002L
static std::unique_ptr<OgaAudios> Load(std::span<const char* const> audio_paths) {
OgaAudios* p;
OgaStringArray* strs;
OgaCheckResult(OgaCreateStringArrayFromStrings(audio_paths.data(), audio_paths.size(), &strs));
OgaCheckResult(OgaLoadAudios(strs, &p));
OgaDestroyStringArray(strs);
return std::unique_ptr<OgaAudios>(p);
}
#endif

static void operator delete(void* p) { OgaDestroyAudios(reinterpret_cast<OgaAudios*>(p)); }
};
Expand Down
10 changes: 10 additions & 0 deletions src/ort_genai_c.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,16 @@ OgaResult* OGA_API_CALL OgaLoadAudio(const char* audio_path, OgaAudios** audios)
OGA_CATCH
}

OgaResult* OGA_API_CALL OgaLoadAudios(const OgaStringArray* audio_paths, OgaAudios** audios) {
OGA_TRY
const auto& audio_paths_vector = *reinterpret_cast<const std::vector<std::string>*>(audio_paths);
std::vector<const char*> audio_paths_vector_c;
for (const auto& audio_path : audio_paths_vector) audio_paths_vector_c.push_back(audio_path.c_str());
*audios = reinterpret_cast<OgaAudios*>(Generators::LoadAudios(audio_paths_vector_c).release());
return nullptr;
OGA_CATCH
}

OgaResult* OGA_API_CALL OgaCreateModel(const char* config_path, OgaModel** out) {
OGA_TRY
auto model = Generators::CreateModel(Generators::GetOrtEnv(), config_path);
Expand Down
2 changes: 2 additions & 0 deletions src/ort_genai_c.h
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,8 @@ OGA_EXPORT void OGA_API_CALL OgaDestroyImages(OgaImages* images);

OGA_EXPORT OgaResult* OGA_API_CALL OgaLoadAudio(const char* audio_path, OgaAudios** audios);

OGA_EXPORT OgaResult* OGA_API_CALL OgaLoadAudios(const OgaStringArray* audio_paths, OgaAudios** audios);

OGA_EXPORT void OGA_API_CALL OgaDestroyAudios(OgaAudios* audios);

/*
Expand Down

0 comments on commit e06c46f

Please sign in to comment.