Skip to content

Commit

Permalink
Merge pull request #25 from superfly/api-res-format
Browse files Browse the repository at this point in the history
New query / watch stream response format
  • Loading branch information
jeromegn authored Aug 18, 2023
2 parents 7edfd9e + dcd54ae commit 111aa57
Show file tree
Hide file tree
Showing 7 changed files with 198 additions and 124 deletions.
61 changes: 32 additions & 29 deletions crates/corro-agent/src/api/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ use corro_types::{
api::{QueryEvent, RqliteResponse, RqliteResult, Statement},
broadcast::{ChangeV1, Changeset, Timestamp},
change::{row_to_change, SqliteValue},
pubsub::ChangeType,
schema::{make_schema_inner, parse_sql},
sqlite::SqlitePoolError,
};
Expand Down Expand Up @@ -353,7 +352,7 @@ async fn build_query_rows_response(
agent: &Agent,
data_tx: mpsc::Sender<QueryEvent>,
stmt: Statement,
) -> Option<(StatusCode, RqliteResult)> {
) -> Result<(), (StatusCode, RqliteResult)> {
let (res_tx, res_rx) = oneshot::channel();

let pool = agent.pool().clone();
Expand All @@ -362,7 +361,7 @@ async fn build_query_rows_response(
let conn = match pool.read().await {
Ok(conn) => conn,
Err(e) => {
_ = res_tx.send(Some((
_ = res_tx.send(Err((
StatusCode::INTERNAL_SERVER_ERROR,
RqliteResult::Error {
error: e.to_string(),
Expand All @@ -381,7 +380,7 @@ async fn build_query_rows_response(
let mut prepped = match prepped_res {
Ok(prepped) => prepped,
Err(e) => {
_ = res_tx.send(Some((
_ = res_tx.send(Err((
StatusCode::BAD_REQUEST,
RqliteResult::Error {
error: e.to_string(),
Expand All @@ -405,10 +404,12 @@ async fn build_query_rows_response(
return;
}

let start = Instant::now();

let mut rows = match prepped.query(()) {
Ok(rows) => rows,
Err(e) => {
_ = res_tx.send(Some((
_ = res_tx.send(Err((
StatusCode::INTERNAL_SERVER_ERROR,
RqliteResult::Error {
error: e.to_string(),
Expand All @@ -417,8 +418,9 @@ async fn build_query_rows_response(
return;
}
};
let elapsed = start.elapsed();

if let Err(_e) = res_tx.send(None) {
if let Err(_e) = res_tx.send(Ok(())) {
error!("could not send back response through oneshot channel, aborting");
return;
}
Expand All @@ -433,11 +435,8 @@ async fn build_query_rows_response(
.collect::<rusqlite::Result<Vec<_>>>()
{
Ok(cells) => {
if let Err(e) = data_tx.blocking_send(QueryEvent::Row {
change_type: ChangeType::Upsert,
rowid,
cells,
}) {
if let Err(e) = data_tx.blocking_send(QueryEvent::Row(rowid, cells))
{
error!("could not send back row: {e}");
return;
}
Expand All @@ -459,12 +458,16 @@ async fn build_query_rows_response(
}
}
}

_ = data_tx.blocking_send(QueryEvent::EndOfQuery {
time: elapsed.as_secs_f64(),
});
});
});

match res_rx.await {
Ok(res) => res,
Err(e) => Some((
Err(e) => Err((
StatusCode::INTERNAL_SERVER_ERROR,
RqliteResult::Error {
error: e.to_string(),
Expand Down Expand Up @@ -513,7 +516,13 @@ pub async fn api_v1_queries(
});

match build_query_rows_response(&agent, data_tx, stmt).await {
Some((status, res)) => {
Ok(_) => {
return hyper::Response::builder()
.status(StatusCode::OK)
.body(body)
.expect("could not build query response body");
}
Err((status, res)) => {
return hyper::Response::builder()
.status(status)
.body(
Expand All @@ -523,12 +532,6 @@ pub async fn api_v1_queries(
)
.expect("could not build query response body");
}
None => {
return hyper::Response::builder()
.status(StatusCode::OK)
.body(body)
.expect("could not build query response body");
}
}
}

Expand Down Expand Up @@ -842,11 +845,7 @@ mod tests {

assert_eq!(
row,
QueryEvent::Row {
rowid: 1,
change_type: ChangeType::Upsert,
cells: vec!["service-id".into(), "service-name".into()]
}
QueryEvent::Row(1, vec!["service-id".into(), "service-name".into()])
);

buf.extend_from_slice(&body.data().await.unwrap()?);
Expand All @@ -857,13 +856,17 @@ mod tests {

assert_eq!(
row,
QueryEvent::Row {
rowid: 2,
change_type: ChangeType::Upsert,
cells: vec!["service-id-2".into(), "service-name-2".into()]
}
QueryEvent::Row(2, vec!["service-id-2".into(), "service-name-2".into()])
);

buf.extend_from_slice(&body.data().await.unwrap()?);

let s = lines.decode(&mut buf).unwrap().unwrap();

let query_evt: QueryEvent = serde_json::from_str(&s).unwrap();

assert!(matches!(query_evt, QueryEvent::EndOfQuery { .. }));

assert!(body.data().await.is_none());

Ok(())
Expand Down
64 changes: 32 additions & 32 deletions crates/corro-agent/src/api/pubsub.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
use std::{collections::HashMap, io::Write, sync::Arc, time::Duration};
use std::{
collections::HashMap,
io::Write,
sync::Arc,
time::{Duration, Instant},
};

use axum::{http::StatusCode, response::IntoResponse, Extension};
use bytes::{BufMut, BytesMut};
Expand All @@ -7,7 +12,7 @@ use corro_types::{
agent::Agent,
api::{QueryEvent, Statement},
change::SqliteValue,
pubsub::{normalize_sql, ChangeType, Matcher, MatcherCmd},
pubsub::{normalize_sql, Matcher, MatcherCmd},
};
use futures::future::poll_fn;
use rusqlite::Connection;
Expand Down Expand Up @@ -101,7 +106,9 @@ async fn watch_by_id(agent: Agent, id: Uuid) -> hyper::Response<hyper::Body> {

init_tx.blocking_send(QueryEvent::Columns(matcher.0.col_names.clone()))?;

let start = Instant::now();
let mut rows = prepped.query(())?;
let elapsed = start.elapsed();

loop {
let row = match rows.next()? {
Expand All @@ -114,21 +121,19 @@ async fn watch_by_id(agent: Agent, id: Uuid) -> hyper::Response<hyper::Body> {
.map(|i| row.get::<_, SqliteValue>(i))
.collect::<rusqlite::Result<Vec<_>>>()?;

init_tx.blocking_send(QueryEvent::Row {
rowid,
change_type: ChangeType::Upsert,
cells,
})?;
init_tx.blocking_send(QueryEvent::Row(rowid, cells))?;
}

_ = init_tx.blocking_send(QueryEvent::EndOfQuery {
time: elapsed.as_secs_f64(),
})?;

Ok::<_, QueryTempError>(())
});

if let Err(QueryTempError::Sqlite(e)) = res {
_ = init_tx.send(QueryEvent::Error(e.to_compact_string())).await;
}

_ = init_tx.send(QueryEvent::EndOfQuery).await;
});

hyper::Response::builder()
Expand Down Expand Up @@ -189,7 +194,7 @@ async fn process_watch_channel(
let mut query_evt = query_evt;

loop {
if matches!(query_evt, QueryEvent::EndOfQuery) {
if matches!(query_evt, QueryEvent::EndOfQuery { .. }) {
init_done = true;
}

Expand Down Expand Up @@ -459,7 +464,7 @@ pub async fn api_v1_watches(
#[cfg(test)]
mod tests {
use arc_swap::ArcSwap;
use corro_types::{actor::ActorId, agent::SplitPool, config::Config};
use corro_types::{actor::ActorId, agent::SplitPool, config::Config, pubsub::ChangeType};
use http_body::Body;
use tokio_util::codec::{Decoder, LinesCodec};
use tripwire::Tripwire;
Expand Down Expand Up @@ -578,31 +583,26 @@ mod tests {

assert_eq!(
rows.recv().await.unwrap().unwrap(),
QueryEvent::Row {
rowid: 1,
change_type: ChangeType::Upsert,
cells: vec!["service-id".into(), "service-name".into()]
}
QueryEvent::Row(1, vec!["service-id".into(), "service-name".into()])
);

assert_eq!(
rows.recv().await.unwrap().unwrap(),
QueryEvent::Row {
rowid: 2,
change_type: ChangeType::Upsert,
cells: vec!["service-id-2".into(), "service-name-2".into()]
}
QueryEvent::Row(2, vec!["service-id-2".into(), "service-name-2".into()])
);

assert_eq!(rows.recv().await.unwrap().unwrap(), QueryEvent::EndOfQuery,);
assert!(matches!(
rows.recv().await.unwrap().unwrap(),
QueryEvent::EndOfQuery { .. }
));

assert_eq!(
rows.recv().await.unwrap().unwrap(),
QueryEvent::Row {
rowid: 3,
change_type: ChangeType::Upsert,
cells: vec!["service-id-3".into(), "service-name-3".into()]
}
QueryEvent::Change(
ChangeType::Insert,
3,
vec!["service-id-3".into(), "service-name-3".into()]
)
);

let (status_code, _) = api_v1_transactions(
Expand All @@ -618,11 +618,11 @@ mod tests {

assert_eq!(
rows.recv().await.unwrap().unwrap(),
QueryEvent::Row {
rowid: 4,
change_type: ChangeType::Upsert,
cells: vec!["service-id-4".into(), "service-name-4".into()]
}
QueryEvent::Change(
ChangeType::Insert,
4,
vec!["service-id-4".into(), "service-name-4".into()]
)
);

Ok(())
Expand Down
14 changes: 8 additions & 6 deletions crates/corro-api-types/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,20 @@ use sqlite::ChangeType;
pub mod sqlite;

#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "snake_case", tag = "event", content = "data")]
#[serde(rename_all = "snake_case")]
pub enum QueryEvent {
Columns(Vec<CompactString>),
Row {
rowid: i64,
change_type: ChangeType,
cells: Vec<SqliteValue>,
Row(i64, Vec<SqliteValue>),
#[serde(rename = "eoq")]
EndOfQuery {
time: f64,
},
EndOfQuery,
Change(ChangeType, i64, Vec<SqliteValue>),
Error(CompactString),
}

pub type RowIdCells = (i64, Vec<SqliteValue>);

#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum Statement {
Expand Down
3 changes: 2 additions & 1 deletion crates/corro-api-types/src/sqlite.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@ use serde::{Deserialize, Serialize};
#[derive(Debug, Copy, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum ChangeType {
Upsert,
Insert,
Update,
Delete,
}

Expand Down
35 changes: 19 additions & 16 deletions crates/corro-tpl/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -231,24 +231,27 @@ impl QueryResponseIter {
match res {
Some(Ok(evt)) => match evt {
QueryEvent::Columns(cols) => self.columns = Some(Arc::new(cols)),
QueryEvent::Row { rowid, cells, .. } => match self.columns.as_ref() {
Some(columns) => {
return Some(Ok(Row {
id: rowid,
columns: columns.clone(),
cells,
}));
}
None => {
self.done = true;
return Some(Err(Box::new(EvalAltResult::from(
"did not receive columns data",
))));
}
},
QueryEvent::EndOfQuery => {
QueryEvent::EndOfQuery { .. } => {
return None;
}
QueryEvent::Row(rowid, cells) | QueryEvent::Change(_, rowid, cells) => {
println!("got a row (rowid: {rowid}) or a change...");
match self.columns.as_ref() {
Some(columns) => {
return Some(Ok(Row {
id: rowid,
columns: columns.clone(),
cells,
}));
}
None => {
self.done = true;
return Some(Err(Box::new(EvalAltResult::from(
"did not receive columns data",
))));
}
}
}
QueryEvent::Error(e) => {
self.done = true;
return Some(Err(Box::new(EvalAltResult::from(e))));
Expand Down
Loading

0 comments on commit 111aa57

Please sign in to comment.