diff --git a/examples/Cargo.toml b/examples/Cargo.toml index 5d61238..8babd31 100644 --- a/examples/Cargo.toml +++ b/examples/Cargo.toml @@ -16,5 +16,9 @@ datafusion-federation-flight-sql.path = "../sources/flight-sql" connectorx = { git = "https://github.com/sfu-db/connector-x.git", rev = "fa0fc7bc", features = [ "dst_arrow", "src_sqlite", + "src_postgres", ] } tonic = "0.10.2" + +[dependencies] +async-std = "1.12.0" diff --git a/examples/examples/postgres-partial.rs b/examples/examples/postgres-partial.rs new file mode 100644 index 0000000..873dd40 --- /dev/null +++ b/examples/examples/postgres-partial.rs @@ -0,0 +1,70 @@ +use std::sync::Arc; +use tokio::task; + +use datafusion::{ + catalog::schema::SchemaProvider, + error::Result, + execution::context::{SessionContext, SessionState}, +}; +use datafusion_federation::{FederatedQueryPlanner, FederationAnalyzerRule}; +use datafusion_federation_sql::connectorx::CXExecutor; +use datafusion_federation_sql::{MultiSchemaProvider, SQLFederationProvider, SQLSchemaProvider}; + +#[tokio::main] +async fn main() -> Result<()> { + let state = SessionContext::new().state(); + // Register FederationAnalyzer + // TODO: Interaction with other analyzers & optimizers. + let state = state + .add_analyzer_rule(Arc::new(FederationAnalyzerRule::new())) + .with_query_planner(Arc::new(FederatedQueryPlanner::new())); + + let df = task::spawn_blocking(move || { + // Register schema + let pg_provider_1 = async_std::task::block_on(create_postgres_provider(vec!["class"], "conn1")).unwrap(); + let pg_provider_2 = async_std::task::block_on(create_postgres_provider(vec!["teacher"], "conn2")).unwrap(); + let provider = MultiSchemaProvider::new(vec![ + pg_provider_1, + pg_provider_2, + ]); + + overwrite_default_schema(&state, Arc::new(provider)).unwrap(); + + // Run query + let ctx = SessionContext::new_with_state(state); + let query = r#"SELECT class.name AS classname, teacher.name AS teachername FROM class JOIN teacher ON class.id = teacher.class_id"#; + let df = async_std::task::block_on(ctx.sql(query)).unwrap(); + + df + }).await.unwrap(); + + task::spawn_blocking(move || async_std::task::block_on(df.show())) + .await + .unwrap() +} + +async fn create_postgres_provider( + known_tables: Vec<&str>, + context: &str, +) -> Result> { + let dsn = "postgresql://:@localhost:/".to_string(); + let known_tables: Vec = known_tables.iter().map(|&x| x.into()).collect(); + let mut executor = CXExecutor::new(dsn)?; + executor.context(context.to_string()); + let provider = Arc::new(SQLFederationProvider::new(Arc::new(executor))); + Ok(Arc::new( + SQLSchemaProvider::new_with_tables(provider, known_tables).await?, + )) +} + +fn overwrite_default_schema(state: &SessionState, schema: Arc) -> Result<()> { + let options = &state.config().options().catalog; + let catalog = state + .catalog_list() + .catalog(options.default_catalog.as_str()) + .unwrap(); + + catalog.register_schema(options.default_schema.as_str(), schema)?; + + Ok(()) +} diff --git a/sources/flight-sql/src/executor/mod.rs b/sources/flight-sql/src/executor/mod.rs index 6da17e7..a5c5a38 100644 --- a/sources/flight-sql/src/executor/mod.rs +++ b/sources/flight-sql/src/executor/mod.rs @@ -4,6 +4,7 @@ use async_trait::async_trait; use datafusion::{ error::{DataFusionError, Result}, physical_plan::{stream::RecordBatchStreamAdapter, SendableRecordBatchStream}, + sql::sqlparser::dialect::{Dialect, GenericDialect}, }; use datafusion_federation_sql::SQLExecutor; use futures::TryStreamExt; @@ -93,6 +94,10 @@ impl SQLExecutor for FlightSQLExecutor { let schema = flight_info.try_decode_schema().map_err(arrow_error_to_df)?; Ok(Arc::new(schema)) } + + fn dialect(&self) -> Arc { + Arc::new(GenericDialect {}) + } } fn arrow_error_to_df(err: ArrowError) -> DataFusionError { diff --git a/sources/sql/src/connectorx/executor.rs b/sources/sql/src/connectorx/executor.rs index 1cefb04..a66dfac 100644 --- a/sources/sql/src/connectorx/executor.rs +++ b/sources/sql/src/connectorx/executor.rs @@ -2,7 +2,7 @@ use async_trait::async_trait; use connectorx::{ destinations::arrow::ArrowDestinationError, errors::{ConnectorXError, ConnectorXOutError}, - prelude::{get_arrow, CXQuery, SourceConn}, + prelude::{get_arrow, CXQuery, SourceConn, SourceType}, }; use datafusion::{ arrow::datatypes::{Field, Schema, SchemaRef}, @@ -10,8 +10,11 @@ use datafusion::{ physical_plan::{ stream::RecordBatchStreamAdapter, EmptyRecordBatchStream, SendableRecordBatchStream, }, + sql::sqlparser::dialect::{Dialect, GenericDialect, PostgreSqlDialect, SQLiteDialect}, }; +use futures::executor::block_on; use std::sync::Arc; +use tokio::task; use crate::executor::SQLExecutor; @@ -54,7 +57,10 @@ impl SQLExecutor for CXExecutor { let conn = self.conn.clone(); let query: CXQuery = sql.into(); - let mut dst = get_arrow(&conn, None, &[query.clone()]).map_err(cx_out_error_to_df)?; + let mut dst = block_on(task::spawn_blocking(move || -> Result<_, _> { + get_arrow(&conn, None, &[query.clone()]).map_err(cx_out_error_to_df) + })) + .map_err(|err| DataFusionError::External(err.to_string().into()))??; let stream = if let Some(batch) = dst.record_batch().map_err(cx_dst_error_to_df)? { futures::stream::once(async move { Ok(batch) }) } else { @@ -84,6 +90,14 @@ impl SQLExecutor for CXExecutor { let schema = schema_to_lowercase(dst.arrow_schema()); Ok(schema) } + + fn dialect(&self) -> Arc { + match &self.conn.ty { + SourceType::Postgres => Arc::new(PostgreSqlDialect {}), + SourceType::SQLite => Arc::new(SQLiteDialect {}), + _ => Arc::new(GenericDialect {}), + } + } } fn cx_dst_error_to_df(err: ArrowDestinationError) -> DataFusionError { diff --git a/sources/sql/src/executor.rs b/sources/sql/src/executor.rs index c03dca4..7f05910 100644 --- a/sources/sql/src/executor.rs +++ b/sources/sql/src/executor.rs @@ -2,6 +2,7 @@ use async_trait::async_trait; use core::fmt; use datafusion::{ arrow::datatypes::SchemaRef, error::Result, physical_plan::SendableRecordBatchStream, + sql::sqlparser::dialect::Dialect, }; use std::sync::Arc; @@ -16,6 +17,9 @@ pub trait SQLExecutor: Sync + Send { /// such as authorization or active database. fn compute_context(&self) -> Option; + // The specific SQL dialect (currently supports 'sqlite', 'postgres', 'flight') + fn dialect(&self) -> Arc; + // Execution /// Execute a SQL query fn execute(&self, query: &str, schema: SchemaRef) -> Result; diff --git a/sources/sql/src/lib.rs b/sources/sql/src/lib.rs index 4b09ce6..ed0f1ca 100644 --- a/sources/sql/src/lib.rs +++ b/sources/sql/src/lib.rs @@ -168,7 +168,7 @@ impl ExecutionPlan for VirtualExecutionPlan { _partition: usize, _context: Arc, ) -> Result { - let ast = query_to_sql(&self.plan)?; + let ast = query_to_sql(&self.plan, self.executor.dialect())?; let query = format!("{ast}"); self.executor.execute(query.as_str(), self.schema()) diff --git a/sources/sql/src/producer.rs b/sources/sql/src/producer.rs index 045a4aa..5c9f365 100644 --- a/sources/sql/src/producer.rs +++ b/sources/sql/src/producer.rs @@ -17,13 +17,16 @@ use datafusion::logical_expr::expr::{ }; use datafusion::logical_expr::{Between, LogicalPlan, Operator}; use datafusion::prelude::Expr; +use datafusion::sql::sqlparser::dialect::{ + Dialect, GenericDialect, PostgreSqlDialect, SQLiteDialect, +}; use crate::ast_builder::{ BuilderError, QueryBuilder, RelationBuilder, SelectBuilder, TableRelationBuilder, TableWithJoinsBuilder, }; -pub fn query_to_sql(plan: &LogicalPlan) -> Result { +pub fn query_to_sql(plan: &LogicalPlan, dialect: Arc) -> Result { match plan { LogicalPlan::Projection(_) | LogicalPlan::Filter(_) @@ -51,6 +54,7 @@ pub fn query_to_sql(plan: &LogicalPlan) -> Result { &mut query_builder, &mut select_builder, &mut relation_builder, + dialect, )?; let mut twj = select_builder.pop_from().unwrap(); @@ -86,12 +90,14 @@ fn select_to_sql( query: &mut QueryBuilder, select: &mut SelectBuilder, relation: &mut RelationBuilder, + dialect: Arc, ) -> Result<()> { match plan { LogicalPlan::TableScan(scan) => { let mut builder = TableRelationBuilder::default(); builder.name(ast::ObjectName(vec![new_ident( scan.table_name.table().to_string(), + dialect, )])); relation.table(builder); @@ -101,18 +107,25 @@ fn select_to_sql( let items = p .expr .iter() - .map(|e| select_item_to_sql(e, p.input.schema(), 0).unwrap()) + .map(|e| select_item_to_sql(e, p.input.schema(), 0, dialect.clone()).unwrap()) .collect::>(); select.projection(items); - select_to_sql(p.input.as_ref(), query, select, relation) + select_to_sql(p.input.as_ref(), query, select, relation, dialect.clone()) } LogicalPlan::Filter(filter) => { - let filter_expr = expr_to_sql(&filter.predicate, filter.input.schema(), 0)?; + let filter_expr = + expr_to_sql(&filter.predicate, filter.input.schema(), 0, dialect.clone())?; select.selection(Some(filter_expr)); - select_to_sql(filter.input.as_ref(), query, select, relation) + select_to_sql( + filter.input.as_ref(), + query, + select, + relation, + dialect.clone(), + ) } LogicalPlan::Limit(limit) => { if let Some(fetch) = limit.fetch { @@ -122,12 +135,29 @@ fn select_to_sql( )))); } - select_to_sql(limit.input.as_ref(), query, select, relation) + select_to_sql( + limit.input.as_ref(), + query, + select, + relation, + dialect.clone(), + ) } LogicalPlan::Sort(sort) => { - query.order_by(sort_to_sql(sort.expr.clone(), sort.input.schema(), 0)?); + query.order_by(sort_to_sql( + sort.expr.clone(), + sort.input.schema(), + 0, + dialect.clone(), + )?); - select_to_sql(sort.input.as_ref(), query, select, relation) + select_to_sql( + sort.input.as_ref(), + query, + select, + relation, + dialect.clone(), + ) } LogicalPlan::Aggregate(_agg) => { not_impl_err!("Unsupported operator: {plan:?}") @@ -146,14 +176,24 @@ fn select_to_sql( // parse filter if exists let in_join_schema = join.left.schema().join(join.right.schema())?; let join_filter = match &join.filter { - Some(filter) => Some(expr_to_sql(filter, &Arc::new(in_join_schema), 0)?), + Some(filter) => Some(expr_to_sql( + filter, + &Arc::new(in_join_schema), + 0, + dialect.clone(), + )?), None => None, }; // map join.on to `l.a = r.a AND l.b = r.b AND ...` let eq_op = ast::BinaryOperator::Eq; - let join_on = - join_conditions_to_sql(&join.on, eq_op, join.left.schema(), join.right.schema())?; + let join_on = join_conditions_to_sql( + &join.on, + eq_op, + join.left.schema(), + join.right.schema(), + dialect.clone(), + )?; // Merge `join_on` and `join_filter` let join_expr = match (join_filter, join_on) { @@ -169,8 +209,14 @@ fn select_to_sql( let mut right_relation = RelationBuilder::default(); - select_to_sql(join.left.as_ref(), query, select, relation)?; - select_to_sql(join.right.as_ref(), query, select, &mut right_relation)?; + select_to_sql(join.left.as_ref(), query, select, relation, dialect.clone())?; + select_to_sql( + join.right.as_ref(), + query, + select, + &mut right_relation, + dialect.clone(), + )?; let ast_join = ast::Join { relation: right_relation.build().map_err(builder_error_to_df)?, @@ -184,9 +230,18 @@ fn select_to_sql( } LogicalPlan::SubqueryAlias(plan_alias) => { // Handle bottom-up to allocate relation - select_to_sql(plan_alias.input.as_ref(), query, select, relation)?; + select_to_sql( + plan_alias.input.as_ref(), + query, + select, + relation, + dialect.clone(), + )?; - relation.alias(Some(new_table_alias(plan_alias.alias.table().to_string()))); + relation.alias(Some(new_table_alias( + plan_alias.alias.table().to_string(), + dialect.clone(), + ))); Ok(()) } @@ -205,25 +260,31 @@ fn select_item_to_sql( expr: &Expr, schema: &DFSchemaRef, col_ref_offset: usize, + dialect: Arc, ) -> Result { match expr { Expr::Alias(Alias { expr, name, .. }) => { - let inner = expr_to_sql(expr, schema, col_ref_offset)?; + let inner = expr_to_sql(expr, schema, col_ref_offset, dialect.clone())?; Ok(ast::SelectItem::ExprWithAlias { expr: inner, - alias: new_ident(name.to_string()), + alias: new_ident(name.to_string(), dialect.clone()), }) } _ => { - let inner = expr_to_sql(expr, schema, col_ref_offset)?; + let inner = expr_to_sql(expr, schema, col_ref_offset, dialect.clone())?; Ok(ast::SelectItem::UnnamedExpr(inner)) } } } -fn expr_to_sql(expr: &Expr, _schema: &DFSchemaRef, _col_ref_offset: usize) -> Result { +fn expr_to_sql( + expr: &Expr, + _schema: &DFSchemaRef, + _col_ref_offset: usize, + dialect: Arc, +) -> Result { match expr { Expr::InList(InList { expr, @@ -243,10 +304,10 @@ fn expr_to_sql(expr: &Expr, _schema: &DFSchemaRef, _col_ref_offset: usize) -> Re }) => { not_impl_err!("Unsupported expression: {expr:?}") } - Expr::Column(col) => col_to_sql(col), + Expr::Column(col) => col_to_sql(col, dialect.clone()), Expr::BinaryExpr(BinaryExpr { left, op, right }) => { - let l = expr_to_sql(left.as_ref(), _schema, 0)?; - let r = expr_to_sql(right.as_ref(), _schema, 0)?; + let l = expr_to_sql(left.as_ref(), _schema, 0, dialect.clone())?; + let r = expr_to_sql(right.as_ref(), _schema, 0, dialect.clone())?; let op = op_to_sql(op)?; Ok(binary_op_to_sql(l, r, op)) @@ -262,7 +323,9 @@ fn expr_to_sql(expr: &Expr, _schema: &DFSchemaRef, _col_ref_offset: usize) -> Re not_impl_err!("Unsupported expression: {expr:?}") } Expr::Literal(value) => Ok(ast::Expr::Value(scalar_to_sql(value)?)), - Expr::Alias(Alias { expr, name: _, .. }) => expr_to_sql(expr, _schema, _col_ref_offset), + Expr::Alias(Alias { expr, name: _, .. }) => { + expr_to_sql(expr, _schema, _col_ref_offset, dialect.clone()) + } Expr::WindowFunction(WindowFunction { fun: _, args: _, @@ -289,12 +352,13 @@ fn sort_to_sql( sort_exprs: Vec, _schema: &DFSchemaRef, _col_ref_offset: usize, + dialect: Arc, ) -> Result> { sort_exprs .iter() .map(|expr: &Expr| match expr { Expr::Sort(sort_expr) => { - let col = expr_to_sql(&sort_expr.expr, _schema, _col_ref_offset)?; + let col = expr_to_sql(&sort_expr.expr, _schema, _col_ref_offset, dialect.clone())?; Ok(OrderByExpr { asc: Some(sort_expr.asc), expr: col, @@ -425,14 +489,14 @@ fn scalar_to_sql(v: &ScalarValue) -> Result { } } -fn col_to_sql(col: &Column) -> Result { +fn col_to_sql(col: &Column, dialect: Arc) -> Result { Ok(ast::Expr::CompoundIdentifier( [ col.relation.as_ref().unwrap().table().to_string(), col.name.to_string(), ] .iter() - .map(|i| new_ident(i.to_string())) + .map(|i| new_ident(i.to_string(), dialect.clone())) .collect(), )) } @@ -455,17 +519,19 @@ fn join_conditions_to_sql( eq_op: ast::BinaryOperator, left_schema: &DFSchemaRef, right_schema: &DFSchemaRef, + dialect: Arc, ) -> Result> { // Only support AND conjunction for each binary expression in join conditions let mut exprs: Vec = vec![]; for (left, right) in join_conditions { // Parse left - let l = expr_to_sql(left, left_schema, 0)?; + let l = expr_to_sql(left, left_schema, 0, dialect.clone())?; // Parse right let r = expr_to_sql( right, right_schema, left_schema.fields().len(), // offset to return the correct index + dialect.clone(), )?; // AND with existing expression exprs.push(binary_op_to_sql(l, r, eq_op.clone())); @@ -486,17 +552,23 @@ pub fn binary_op_to_sql(lhs: SQLExpr, rhs: SQLExpr, op: ast::BinaryOperator) -> } } -fn new_table_alias(alias: String) -> ast::TableAlias { +fn new_table_alias(alias: String, dialect: Arc) -> ast::TableAlias { ast::TableAlias { - name: new_ident(alias), + name: new_ident(alias, dialect.clone()), columns: Vec::new(), } } -fn new_ident(str: String) -> ast::Ident { +fn new_ident(str: String, dialect: Arc) -> ast::Ident { ast::Ident { value: str, - quote_style: Some('`'), + quote_style: if dialect.is::() { + Some('"') + } else if dialect.is::() || dialect.is::() { + Some('`') + } else { + todo!() + }, } } diff --git a/sources/sql/src/schema.rs b/sources/sql/src/schema.rs index 7e90cfa..4d69ebe 100644 --- a/sources/sql/src/schema.rs +++ b/sources/sql/src/schema.rs @@ -73,6 +73,40 @@ impl SchemaProvider for SQLSchemaProvider { } } +pub struct MultiSchemaProvider { + children: Vec>, +} + +impl MultiSchemaProvider { + pub fn new(children: Vec>) -> Self { + Self { children } + } +} + +#[async_trait] +impl SchemaProvider for MultiSchemaProvider { + fn as_any(&self) -> &dyn Any { + self + } + + fn table_names(&self) -> Vec { + self.children.iter().flat_map(|p| p.table_names()).collect() + } + + async fn table(&self, name: &str) -> Option> { + for child in &self.children { + if let Some(table) = child.table(name).await { + return Some(table); + } + } + None + } + + fn table_exist(&self, name: &str) -> bool { + self.children.iter().any(|p| p.table_exist(name)) + } +} + pub struct SQLTableSource { provider: Arc, table_name: String,