Skip to content

Commit

Permalink
feat(remote databases): execute_query for snowflake (#7573)
Browse files Browse the repository at this point in the history
* feat(remote databases): execute_query for snowflake

* refactor

* add comment

* fix support nulls

---------

Co-authored-by: Henry Fontanier <henry@dust.tt>
  • Loading branch information
fontanierh and Henry Fontanier authored Sep 24, 2024
1 parent 6fc623e commit e81a040
Show file tree
Hide file tree
Showing 11 changed files with 377 additions and 210 deletions.
47 changes: 28 additions & 19 deletions core/bin/dust_api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,23 @@ use axum::{
routing::{delete, get, patch, post},
Router,
};
use futures::future::try_join_all;
use hyper::http::StatusCode;
use parking_lot::Mutex;
use serde_json::{json, Value};
use std::collections::{HashMap, HashSet};
use std::convert::Infallible;
use std::sync::Arc;
use tokio::{
net::TcpListener,
signal::unix::{signal, SignalKind},
sync::mpsc::unbounded_channel,
};
use tokio_stream::Stream;
use tower_http::trace::{self, TraceLayer};
use tracing::{error, info, Level};
use tracing_bunyan_formatter::{BunyanFormattingLayer, JsonStorageLayer};
use tracing_subscriber::prelude::*;

use dust::{
api_keys::validate_api_key,
Expand All @@ -19,7 +36,10 @@ use dust::{
data_source::{self, Section},
qdrant::QdrantClients,
},
databases::database::{query_database, QueryDatabaseError, Row, Table},
databases::{
database::{QueryDatabaseError, Row, Table},
transient_database::execute_query_on_transient_database,
},
databases_store::store::{self as databases_store, DatabasesStore},
dataset,
deno::js_executor::JSExecutor,
Expand All @@ -31,23 +51,6 @@ use dust::{
stores::{postgres, store},
utils::{self, error_response, APIError, APIResponse, CoreRequestMakeSpan},
};
use futures::future::try_join_all;
use hyper::http::StatusCode;
use parking_lot::Mutex;
use serde_json::{json, Value};
use std::collections::{HashMap, HashSet};
use std::convert::Infallible;
use std::sync::Arc;
use tokio::{
net::TcpListener,
signal::unix::{signal, SignalKind},
sync::mpsc::unbounded_channel,
};
use tokio_stream::Stream;
use tower_http::trace::{self, TraceLayer};
use tracing::{error, info, Level};
use tracing_bunyan_formatter::{BunyanFormattingLayer, JsonStorageLayer};
use tracing_subscriber::prelude::*;

/// API State
Expand Down Expand Up @@ -2576,7 +2579,13 @@ async fn databases_query_run(
)
}
Some(tables) => {
match query_database(&tables, state.store.clone(), &payload.query).await {
match execute_query_on_transient_database(
&tables,
state.store.clone(),
&payload.query,
)
.await
{
Err(QueryDatabaseError::TooManyResultRows) => error_response(
StatusCode::BAD_REQUEST,
"too_many_result_rows",
Expand Down
18 changes: 11 additions & 7 deletions core/src/blocks/database.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,19 @@
use crate::blocks::block::{parse_pair, Block, BlockResult, BlockType, Env};
use crate::databases::database::{query_database, QueryDatabaseError};
use crate::Rule;
use anyhow::{anyhow, Result};
use async_trait::async_trait;

use pest::iterators::Pair;
use serde_json::{json, Value};
use tokio::sync::mpsc::UnboundedSender;

use super::block::replace_variables_in_string;
use super::database_schema::load_tables_from_identifiers;
use crate::{
blocks::{
block::{parse_pair, replace_variables_in_string, Block, BlockResult, BlockType, Env},
database_schema::load_tables_from_identifiers,
},
databases::{
database::QueryDatabaseError, transient_database::execute_query_on_transient_database,
},
Rule,
};

#[derive(Clone)]
pub struct Database {
Expand Down Expand Up @@ -103,7 +107,7 @@ impl Block for Database {
let query = replace_variables_in_string(&self.query, "query", env)?;
let tables = load_tables_from_identifiers(&table_identifiers, env).await?;

match query_database(&tables, env.store.clone(), &query).await {
match execute_query_on_transient_database(&tables, env.store.clone(), &query).await {
Ok((results, schema)) => Ok(BlockResult {
value: json!({
"results": results,
Expand Down
18 changes: 13 additions & 5 deletions core/src/blocks/database_schema.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,22 @@
use super::helpers::get_data_source_project_and_view_filter;
use crate::blocks::block::{Block, BlockResult, BlockType, Env};
use crate::databases::database::{get_unique_table_names_for_database, Table};
use crate::Rule;
use anyhow::{anyhow, Ok, Result};
use async_trait::async_trait;
use futures::future::try_join_all;
use itertools::Itertools;
use pest::iterators::Pair;
use serde_json::{json, Value};
use tokio::sync::mpsc::UnboundedSender;

use crate::{
blocks::{
block::{Block, BlockResult, BlockType, Env},
helpers::get_data_source_project_and_view_filter,
},
databases::{
database::Table, transient_database::get_unique_table_names_for_transient_database,
},
Rule,
};

#[derive(Clone)]
pub struct DatabaseSchema {}

Expand Down Expand Up @@ -74,7 +82,7 @@ impl Block for DatabaseSchema {
let mut tables = load_tables_from_identifiers(&table_identifiers, env).await?;

// Compute the unique table names for each table.
let unique_table_names = get_unique_table_names_for_database(&tables);
let unique_table_names = get_unique_table_names_for_transient_database(&tables);

// Load the schema for each table.
// If the schema cache is stale, this will update it in place.
Expand Down
138 changes: 16 additions & 122 deletions core/src/databases/database.rs
Original file line number Diff line number Diff line change
@@ -1,21 +1,17 @@
use std::collections::HashMap;

use super::table_schema::TableSchema;
use crate::{
databases_store::store::DatabasesStore,
project::Project,
search_filter::{Filterable, SearchFilter},
sqlite_workers::client::{SqliteWorker, SqliteWorkerError, HEARTBEAT_INTERVAL_MS},
sqlite_workers::client::HEARTBEAT_INTERVAL_MS,
stores::store::Store,
utils,
};
use anyhow::{anyhow, Result};
use futures::future::try_join_all;
use itertools::Itertools;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use thiserror::Error;
use tracing::info;

#[derive(Debug, Clone, Copy, Serialize, PartialEq, Deserialize)]
#[serde(rename_all = "lowercase")]
Expand Down Expand Up @@ -43,99 +39,6 @@ pub enum QueryDatabaseError {
ExecutionError(String),
}

impl From<SqliteWorkerError> for QueryDatabaseError {
fn from(e: SqliteWorkerError) -> Self {
match &e {
SqliteWorkerError::TooManyResultRows => QueryDatabaseError::TooManyResultRows,
SqliteWorkerError::QueryExecutionError(msg) => {
QueryDatabaseError::ExecutionError(msg.clone())
}
_ => QueryDatabaseError::GenericError(e.into()),
}
}
}

pub async fn query_database(
tables: &Vec<Table>,
store: Box<dyn Store + Sync + Send>,
query: &str,
) -> Result<(Vec<QueryResult>, TableSchema), QueryDatabaseError> {
let table_ids_hash = tables.iter().map(|t| t.unique_id()).sorted().join("/");
let database = store
.upsert_database(&table_ids_hash, HEARTBEAT_INTERVAL_MS)
.await?;

let time_query_start = utils::now();

let result_rows = match database.sqlite_worker() {
Some(sqlite_worker) => {
let result_rows = sqlite_worker
.execute_query(&table_ids_hash, tables, query)
.await?;
result_rows
}
None => Err(anyhow!(
"No live SQLite worker found for database {}",
database.table_ids_hash
))?,
};

info!(
duration = utils::now() - time_query_start,
"DSSTRUCTSTAT Finished executing user query on worker"
);

let infer_result_schema_start = utils::now();
let table_schema = TableSchema::from_rows(&result_rows)?;

info!(
duration = utils::now() - infer_result_schema_start,
"DSSTRUCTSTAT Finished inferring schema"
);
info!(
duration = utils::now() - time_query_start,
"DSSTRUCTSTAT Finished query database"
);

Ok((result_rows, table_schema))
}

pub async fn invalidate_database(db: Database, store: Box<dyn Store + Sync + Send>) -> Result<()> {
if let Some(worker) = db.sqlite_worker() {
worker.invalidate_database(db.unique_id()).await?;
} else {
// If the worker is not alive, we delete the database row in case the worker becomes alive again.
store.delete_database(&db.table_ids_hash).await?;
}

Ok(())
}

#[derive(Debug, Serialize, Clone)]
pub struct Database {
created: u64,
table_ids_hash: String,
sqlite_worker: Option<SqliteWorker>,
}

impl Database {
pub fn new(created: u64, table_ids_hash: &str, sqlite_worker: &Option<SqliteWorker>) -> Self {
Database {
created,
table_ids_hash: table_ids_hash.to_string(),
sqlite_worker: sqlite_worker.clone(),
}
}

pub fn sqlite_worker(&self) -> &Option<SqliteWorker> {
&self.sqlite_worker
}

pub fn unique_id(&self) -> &str {
&self.table_ids_hash
}
}

#[derive(Debug, Serialize, Clone, Deserialize)]
pub struct Table {
project: Project,
Expand All @@ -157,6 +60,7 @@ pub fn get_table_unique_id(project: &Project, data_source_id: &str, table_id: &s
format!("{}__{}__{}", project.project_id(), data_source_id, table_id)
}

// TODO(@fontanierh): Support for remote DBs.
impl Table {
pub fn new(
project: &Project,
Expand Down Expand Up @@ -273,7 +177,13 @@ impl Table {
)
.await?)
.into_iter()
.map(|db| invalidate_database(db, store.clone())),
.map(|db| {
let store = store.clone();
async move {
db.invalidate(store).await?;
Ok::<_, anyhow::Error>(())
}
}),
)
.await?;

Expand Down Expand Up @@ -365,7 +275,13 @@ impl Table {
)
.await?)
.into_iter()
.map(|db| invalidate_database(db, store.clone())),
.map(|db| {
let store = store.clone();
async move {
db.invalidate(store).await?;
Ok::<_, anyhow::Error>(())
}
}),
)
.await?;

Expand Down Expand Up @@ -507,28 +423,6 @@ impl HasValue for QueryResult {
}
}

pub fn get_unique_table_names_for_database(tables: &[Table]) -> HashMap<String, String> {
let mut name_count: HashMap<&str, usize> = HashMap::new();

tables
.iter()
.sorted_by_key(|table| table.unique_id())
.map(|table| {
let base_name = table.name();
let count = name_count.entry(base_name).or_insert(0);
*count += 1;

(
table.unique_id(),
match *count {
1 => base_name.to_string(),
_ => format!("{}_{}", base_name, *count - 1),
},
)
})
.collect()
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down
10 changes: 9 additions & 1 deletion core/src/databases/remote_databases/remote_database.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,16 @@
use anyhow::Result;

use async_trait::async_trait;

use crate::databases::{
database::{QueryDatabaseError, QueryResult},
table_schema::TableSchema,
};

#[async_trait]
pub trait RemoteDatabase {
async fn get_tables_used_by_query(&self, query: &str) -> Result<Vec<String>>;
async fn execute_query(
&self,
query: &str,
) -> Result<(Vec<QueryResult>, TableSchema), QueryDatabaseError>;
}
Loading

0 comments on commit e81a040

Please sign in to comment.