diff --git a/src/query.rs b/src/query.rs index eba1915..fd5c3fb 100644 --- a/src/query.rs +++ b/src/query.rs @@ -3,6 +3,8 @@ use serde::{Deserialize, Serialize}; use std::fmt::Display; use url::Url; +#[cfg(feature = "watch")] +use crate::watch; use crate::{ error::{Error, Result}, headers::with_request_headers, @@ -90,6 +92,60 @@ impl Query { Ok(RowCursor::new(response)) } + /// Executes the query, returning a [`watch::RowJsonCursor`] to obtain results. + #[cfg(feature = "watch")] + pub fn fetch_json(mut self) -> Result> { + self.sql.append(" FORMAT JSONEachRowWithProgress"); + + let response = self.do_execute(true)?; + Ok(watch::RowJsonCursor::new(response)) + } + + /// Executes the query and returns just a single row. + /// + /// Note that `T` must be owned. + #[cfg(feature = "watch")] + pub async fn fetch_json_one(self) -> Result + where + T: for<'b> Deserialize<'b>, + { + match self.fetch_json()?.next().await { + Ok(Some(row)) => Ok(row), + Ok(None) => Err(Error::RowNotFound), + Err(err) => Err(err), + } + } + + /// Executes the query and returns at most one row. + /// + /// Note that `T` must be owned. + #[cfg(feature = "watch")] + pub async fn fetch_json_optional(self) -> Result> + where + T: for<'b> Deserialize<'b>, + { + self.fetch_json()?.next().await + } + + /// Executes the query and returns all the generated results, + /// collected into a [`Vec`]. + /// + /// Note that `T` must be owned. + #[cfg(feature = "watch")] + pub async fn fetch_json_all(self) -> Result> + where + T: for<'b> Deserialize<'b>, + { + let mut result = Vec::new(); + let mut cursor = self.fetch_json::()?; + + while let Some(row) = cursor.next().await? { + result.push(row); + } + + Ok(result) + } + /// Executes the query and returns just a single row. /// /// Note that `T` must be owned. diff --git a/src/watch.rs b/src/watch.rs index 1914399..656c9e4 100644 --- a/src/watch.rs +++ b/src/watch.rs @@ -6,6 +6,7 @@ use sha1::{Digest, Sha1}; use crate::{ cursor::JsonCursor, error::{Error, Result}, + response::Response, row::Row, sql::{Bind, SqlBuilder}, Client, Compression, @@ -165,6 +166,23 @@ impl EventCursor { } } +/// A cursor that emits rows in JSON format. +pub struct RowJsonCursor(JsonCursor); + +impl RowJsonCursor { + pub(crate) fn new(response: Response) -> Self { + Self(JsonCursor::new(response)) + } + + /// Emits the next row. + pub async fn next<'a, 'b: 'a>(&'a mut self) -> Result> + where + T: Deserialize<'b>, + { + self.0.next().await + } +} + // === RowCursor === /// A cursor that emits `(Version, T)`. diff --git a/tests/it/query.rs b/tests/it/query.rs index 195297e..4e5e3de 100644 --- a/tests/it/query.rs +++ b/tests/it/query.rs @@ -263,3 +263,98 @@ async fn prints_query() { "SELECT ?fields FROM test WHERE a = ? AND b < ?" ); } + +#[cfg(feature = "watch")] +#[tokio::test] +async fn fetches_json_row() { + let client = prepare_database!(); + + let value = client + .query("SELECT 1,2,3") + .fetch_json_one::() + .await + .unwrap(); + + assert_eq!(value, serde_json::json!({ "1": 1, "2": 2, "3": 3})); + + let value = client + .query("SELECT (1,2,3) as data") + .fetch_json_one::() + .await + .unwrap(); + + assert_eq!(value, serde_json::json!({ "data": [1,2,3]})); +} + +#[cfg(feature = "watch")] +#[tokio::test] +async fn fetches_json_struct() { + let client = prepare_database!(); + + #[derive(Debug, Deserialize, PartialEq)] + struct Row { + one: i8, + two: String, + three: f32, + four: bool, + } + + let value = client + .query("SELECT -1 as one, '2' as two, 3.0 as three, false as four") + .fetch_json_one::() + .await + .unwrap(); + + assert_eq!( + value, + Row { + one: -1, + two: "2".to_owned(), + three: 3.0, + four: false, + } + ); +} + +#[cfg(feature = "watch")] +#[tokio::test] +async fn describes_table() { + let client = prepare_database!(); + + let columns = client + .query("DESCRIBE TABLE system.users") + .fetch_json_all::() + .await + .unwrap(); + for c in &columns { + println!("{c}"); + } + let columns = columns + .into_iter() + .map(|row| { + let column_name = row + .as_object() + .expect("JSONEachRow") + .get("name") + .expect("`system.users` must contain the `name` column"); + (column_name.as_str().unwrap().to_owned(), row) + }) + .collect::>(); + dbg!(&columns); + + let name_column = columns + .get("name") + .expect("`system.users` must contain the `name` column"); + assert_eq!( + name_column.as_object().unwrap().get("type").unwrap(), + &serde_json::json!("String") + ); + + let id_column = columns + .get("id") + .expect("`system.users` must contain the `id` column"); + assert_eq!( + id_column.as_object().unwrap().get("type").unwrap(), + &serde_json::json!("UUID") + ); +}