Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Chore/issue127: embeddings status endpoint #132

Merged
merged 2 commits into from
Mar 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions crates/edgen_server/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,9 @@ async fn run_server(args: &cli::Serve) -> Result<bool, types::EdgenError> {
)
.await;

status::set_embeddings_active_model(&SETTINGS.read().await.read().await.embeddings_model_name)
.await;

let http_app = routes::routes()
.layer(CorsLayer::permissive())
.layer(DefaultBodyLimit::max(
Expand Down Expand Up @@ -295,6 +298,10 @@ async fn run_server(args: &cli::Serve) -> Result<bool, types::EdgenError> {
.audio_transcriptions_model_name,
)
.await;
status::set_embeddings_active_model(
&SETTINGS.read().await.read().await.embeddings_model_name,
)
.await;
});
});

Expand Down
9 changes: 6 additions & 3 deletions crates/edgen_server/src/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -332,15 +332,15 @@ async fn observe_download(
Endpoint::AudioTranscriptions => {
status::observe_audio_transcriptions_progress(dir, size, download).await
}
Endpoint::Embeddings => todo!(),
Endpoint::Embeddings => status::observe_embeddings_progress(dir, size, download).await,
}
}

async fn report_start_of_download(ep: Endpoint) {
match ep {
Endpoint::ChatCompletions => status::set_chat_completions_download(true).await,
Endpoint::AudioTranscriptions => status::set_audio_transcriptions_download(true).await,
Endpoint::Embeddings => todo!(),
Endpoint::Embeddings => status::set_embeddings_download(true).await,
}
}

Expand All @@ -354,7 +354,10 @@ async fn report_end_of_download(ep: Endpoint) {
status::set_audio_transcriptions_progress(100).await;
status::set_audio_transcriptions_download(false).await;
}
Endpoint::Embeddings => todo!(),
Endpoint::Embeddings => {
status::set_embeddings_progress(100).await;
status::set_embeddings_download(false).await;
}
}
}

Expand Down
3 changes: 3 additions & 0 deletions crates/edgen_server/src/routes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,9 @@ pub fn routes() -> Router {
"/v1/audio/transcriptions/status",
get(status::audio_transcriptions_status),
)
// ---- Embeddings -----------------------------------------------------
.route("/v1/embeddings/status", get(status::embeddings_status))
// -- Model Manager ----------------------------------------------------
// -- Model Manager ----------------------------------------------------
.route("/v1/models", get(model_man::list_models))
.route("/v1/models/:model", get(model_man::retrieve_model))
Expand Down
102 changes: 102 additions & 0 deletions crates/edgen_server/src/status.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,15 @@ pub async fn audio_transcriptions_status() -> Response {
Json(state.clone()).into_response()
}

/// GET `/v1/embeddings`: returns the current status of the /embeddings endpoint.
///
/// The status is returned as json value AIStatus.
/// For any error, the version endpoint returns "internal server error".
pub async fn embeddings_status() -> Response {
let state = get_embeddings_status().read().await;
Json(state.clone()).into_response()
}

/// Current Endpoint status.
#[derive(ToSchema, Deserialize, Serialize, Clone, Debug, PartialEq, Eq)]
pub struct AIStatus {
Expand Down Expand Up @@ -72,6 +81,7 @@ static AISTATES: Lazy<AIStates> = Lazy::new(Default::default);

const EP_CHAT_COMPLETIONS: usize = 0;
const EP_AUDIO_TRANSCRIPTIONS: usize = 1;
const EP_EMBEDDINGS: usize = 2;

const MAX_ERRORS: usize = 32;

Expand All @@ -87,6 +97,12 @@ pub fn get_audio_transcriptions_status() -> &'static RwLock<AIStatus> {
get_status(EP_AUDIO_TRANSCRIPTIONS)
}

/// Get a protected embeddings status.
/// Call read() or write() on the returned value to get either read or write access.
pub fn get_embeddings_status() -> &'static RwLock<AIStatus> {
get_status(EP_EMBEDDINGS)
}

fn get_status(idx: usize) -> &'static RwLock<AIStatus> {
&AISTATES.endpoints[idx]
}
Expand All @@ -101,6 +117,11 @@ pub async fn reset_audio_transcriptions_status() {
reset_status(EP_AUDIO_TRANSCRIPTIONS).await;
}

/// Reset the embeddings status to its defaults
pub async fn reset_embeddings_status() {
reset_status(EP_EMBEDDINGS).await;
}

async fn reset_status(idx: usize) {
let mut status = get_status(idx).write().await;
*status = AIStatus::default();
Expand All @@ -116,6 +137,11 @@ pub async fn set_audio_transcriptions_active_model(model: &str) {
set_active_model(EP_AUDIO_TRANSCRIPTIONS, model).await;
}

/// Set embeddings active model
pub async fn set_embeddings_active_model(model: &str) {
set_active_model(EP_EMBEDDINGS, model).await;
}

async fn set_active_model(idx: usize, model: &str) {
let mut state = get_status(idx).write().await;
state.active_model = model.to_string();
Expand All @@ -141,6 +167,16 @@ pub async fn set_audio_transcriptions_download(ongoing: bool) {
set_download(EP_AUDIO_TRANSCRIPTIONS, ongoing).await;
}

/// Set embeddings download ongoing
pub async fn set_embeddings_download(ongoing: bool) {
if ongoing {
info!("starting embeddings model download");
} else {
info!("embeddings model download finished");
};
set_download(EP_EMBEDDINGS, ongoing).await;
}

async fn set_download(idx: usize, ongoing: bool) {
let mut state = get_status(idx).write().await;
state.download_ongoing = ongoing;
Expand All @@ -156,6 +192,11 @@ pub async fn set_audio_transcriptions_progress(progress: u64) {
set_progress(EP_AUDIO_TRANSCRIPTIONS, progress).await;
}

/// Set embeddings download progress
pub async fn set_embeddings_progress(progress: u64) {
set_progress(EP_EMBEDDINGS, progress).await;
}

async fn set_progress(idx: usize, progress: u64) {
let mut state = get_status(idx).write().await;
state.download_progress = progress;
Expand All @@ -179,6 +220,15 @@ pub async fn observe_audio_transcriptions_progress(
observe_progress(EP_AUDIO_TRANSCRIPTIONS, datadir, size, download).await
}

/// Observe embeddings download progress
pub async fn observe_embeddings_progress(
datadir: &PathBuf,
size: Option<u64>,
download: bool,
) -> tokio::task::JoinHandle<()> {
observe_progress(EP_EMBEDDINGS, datadir, size, download).await
}

/// Add an error to the last errors in chat completions
pub async fn add_chat_completions_error<E>(e: E)
where
Expand Down Expand Up @@ -217,6 +267,7 @@ impl Default for AIStates {
endpoints: vec![
RwLock::new(Default::default()),
RwLock::new(Default::default()),
RwLock::new(Default::default()),
],
}
}
Expand Down Expand Up @@ -603,4 +654,55 @@ mod tests {
assert!(response.text().len() > 0);
assert_eq!(response.json::<AIStatus>().active_model, model);
}

#[tokio::test]
async fn test_embeddings_status() {
reset_embeddings_status().await;

// default
let mut expected = AIStatus::default();

{
let status = get_embeddings_status().read().await;
assert_eq!(*status, AIStatus::default());
}

// download ongoing
expected.download_ongoing = true;
set_embeddings_download(true).await;

{
let status = get_embeddings_status().read().await;
assert_eq!(*status, expected);
}

// download progress
expected.download_progress = 42;
set_embeddings_progress(42).await;

{
let status = get_embeddings_status().read().await;
assert_eq!(*status, expected);
}

// axum router
let router = Router::new().route("/v1/embeddings/status", get(embeddings_status));

let server = TestServer::new(router).expect("cannot instantiate TestServer");

let response = server.get("/v1/embeddings/status").await;

response.assert_status_ok();
assert!(response.text().len() > 0);
assert_eq!(response.json::<AIStatus>().active_model, "unknown");

let model = "shes-a-model-and-shes-looking-good".to_string();
set_embeddings_active_model(&model).await;

let response = server.get("/v1/embeddings/status").await;

response.assert_status_ok();
assert!(response.text().len() > 0);
assert_eq!(response.json::<AIStatus>().active_model, model);
}
}
Loading