Skip to content

Commit

Permalink
[chore/issue127] embeddings status
Browse files Browse the repository at this point in the history
  • Loading branch information
toschoo committed Mar 26, 2024
1 parent e9a3cc0 commit 3d4c749
Show file tree
Hide file tree
Showing 3 changed files with 108 additions and 0 deletions.
3 changes: 3 additions & 0 deletions crates/edgen_server/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,9 @@ async fn run_server(args: &cli::Serve) -> Result<bool, types::EdgenError> {
)
.await;

status::set_embeddings_active_model(&SETTINGS.read().await.read().await.embeddings_model_name)
.await;

let http_app = routes::routes()
.layer(CorsLayer::permissive())
.layer(DefaultBodyLimit::max(
Expand Down
3 changes: 3 additions & 0 deletions crates/edgen_server/src/routes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,9 @@ pub fn routes() -> Router {
"/v1/audio/transcriptions/status",
get(status::audio_transcriptions_status),
)
// ---- Embeddings -----------------------------------------------------
.route("/v1/embeddings/status", get(status::embeddings_status))
// -- Model Manager ----------------------------------------------------
// -- Model Manager ----------------------------------------------------
.route("/v1/models", get(model_man::list_models))
.route("/v1/models/:model", get(model_man::retrieve_model))
Expand Down
102 changes: 102 additions & 0 deletions crates/edgen_server/src/status.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,15 @@ pub async fn audio_transcriptions_status() -> Response {
Json(state.clone()).into_response()
}

/// GET `/v1/embeddings`: returns the current status of the /embeddings endpoint.
///
/// The status is returned as json value AIStatus.
/// For any error, the version endpoint returns "internal server error".
pub async fn embeddings_status() -> Response {
let state = get_embeddings_status().read().await;
Json(state.clone()).into_response()
}

/// Current Endpoint status.
#[derive(ToSchema, Deserialize, Serialize, Clone, Debug, PartialEq, Eq)]
pub struct AIStatus {
Expand Down Expand Up @@ -72,6 +81,7 @@ static AISTATES: Lazy<AIStates> = Lazy::new(Default::default);

const EP_CHAT_COMPLETIONS: usize = 0;
const EP_AUDIO_TRANSCRIPTIONS: usize = 1;
const EP_EMBEDDINGS: usize = 2;

const MAX_ERRORS: usize = 32;

Expand All @@ -87,6 +97,12 @@ pub fn get_audio_transcriptions_status() -> &'static RwLock<AIStatus> {
get_status(EP_AUDIO_TRANSCRIPTIONS)
}

/// Get a protected embeddings status.
/// Call read() or write() on the returned value to get either read or write access.
pub fn get_embeddings_status() -> &'static RwLock<AIStatus> {
get_status(EP_EMBEDDINGS)
}

fn get_status(idx: usize) -> &'static RwLock<AIStatus> {
&AISTATES.endpoints[idx]
}
Expand All @@ -101,6 +117,11 @@ pub async fn reset_audio_transcriptions_status() {
reset_status(EP_AUDIO_TRANSCRIPTIONS).await;
}

/// Reset the embeddings status to its defaults
pub async fn reset_embeddings_status() {
reset_status(EP_EMBEDDINGS).await;
}

async fn reset_status(idx: usize) {
let mut status = get_status(idx).write().await;
*status = AIStatus::default();
Expand All @@ -116,6 +137,11 @@ pub async fn set_audio_transcriptions_active_model(model: &str) {
set_active_model(EP_AUDIO_TRANSCRIPTIONS, model).await;
}

/// Set embeddings active model
pub async fn set_embeddings_active_model(model: &str) {
set_active_model(EP_EMBEDDINGS, model).await;
}

async fn set_active_model(idx: usize, model: &str) {
let mut state = get_status(idx).write().await;
state.active_model = model.to_string();
Expand All @@ -141,6 +167,16 @@ pub async fn set_audio_transcriptions_download(ongoing: bool) {
set_download(EP_AUDIO_TRANSCRIPTIONS, ongoing).await;
}

/// Set embeddings download ongoing
pub async fn set_embeddings_download(ongoing: bool) {
if ongoing {
info!("starting embeddings model download");
} else {
info!("embeddings model download finished");
};
set_download(EP_EMBEDDINGS, ongoing).await;
}

async fn set_download(idx: usize, ongoing: bool) {
let mut state = get_status(idx).write().await;
state.download_ongoing = ongoing;
Expand All @@ -156,6 +192,11 @@ pub async fn set_audio_transcriptions_progress(progress: u64) {
set_progress(EP_AUDIO_TRANSCRIPTIONS, progress).await;
}

/// Set embeddings download progress
pub async fn set_embeddings_progress(progress: u64) {
set_progress(EP_EMBEDDINGS, progress).await;
}

async fn set_progress(idx: usize, progress: u64) {
let mut state = get_status(idx).write().await;
state.download_progress = progress;
Expand All @@ -179,6 +220,15 @@ pub async fn observe_audio_transcriptions_progress(
observe_progress(EP_AUDIO_TRANSCRIPTIONS, datadir, size, download).await
}

/// Observe embeddings download progress
pub async fn observe_embeddings_progress(
datadir: &PathBuf,
size: Option<u64>,
download: bool,
) -> tokio::task::JoinHandle<()> {
observe_progress(EP_EMBEDDINGS, datadir, size, download).await
}

/// Add an error to the last errors in chat completions
pub async fn add_chat_completions_error<E>(e: E)
where
Expand Down Expand Up @@ -217,6 +267,7 @@ impl Default for AIStates {
endpoints: vec![
RwLock::new(Default::default()),
RwLock::new(Default::default()),
RwLock::new(Default::default()),
],
}
}
Expand Down Expand Up @@ -603,4 +654,55 @@ mod tests {
assert!(response.text().len() > 0);
assert_eq!(response.json::<AIStatus>().active_model, model);
}

#[tokio::test]
async fn test_embeddings_status() {
reset_embeddings_status().await;

// default
let mut expected = AIStatus::default();

{
let status = get_embeddings_status().read().await;
assert_eq!(*status, AIStatus::default());
}

// download ongoing
expected.download_ongoing = true;
set_embeddings_download(true).await;

{
let status = get_embeddings_status().read().await;
assert_eq!(*status, expected);
}

// download progress
expected.download_progress = 42;
set_embeddings_progress(42).await;

{
let status = get_embeddings_status().read().await;
assert_eq!(*status, expected);
}

// axum router
let router = Router::new().route("/v1/embeddings/status", get(embeddings_status));

let server = TestServer::new(router).expect("cannot instantiate TestServer");

let response = server.get("/v1/embeddings/status").await;

response.assert_status_ok();
assert!(response.text().len() > 0);
assert_eq!(response.json::<AIStatus>().active_model, "unknown");

let model = "shes-a-model-and-shes-looking-good".to_string();
set_embeddings_active_model(&model).await;

let response = server.get("/v1/embeddings/status").await;

response.assert_status_ok();
assert!(response.text().len() > 0);
assert_eq!(response.json::<AIStatus>().active_model, model);
}
}

0 comments on commit 3d4c749

Please sign in to comment.