diff --git a/src/common/config.rs b/src/common/config.rs index 9c3b60a5..746e2f17 100644 --- a/src/common/config.rs +++ b/src/common/config.rs @@ -230,7 +230,7 @@ impl Config { /// -- @db: postgres /// SELECT * FROM some_table; /// - /// The method figures out the correct database to connect in order to validate the SQL query + /// The method figures out the connection name to connect in order to validate the SQL query /// /// If you pass down a query with a annotation to specify a DB /// e.g. @@ -243,31 +243,18 @@ impl Config { /// e.g. /// SELECT * FROM some_table; /// - /// It should return the default connection configured by your configuration settings - pub fn get_correct_db_connection(&self, raw_sql: &str) -> DbConnectionConfig { + /// It should return the connection name that is available based on your connection configurations + pub fn get_correct_db_connection(&self, raw_sql: &str) -> String { let re = Regex::new(r"(/*|//|--) @db: (?P[\w]+)( */){0,}").unwrap(); let found_matches = re.captures(raw_sql); if let Some(found_match) = &found_matches { let detected_conn_name = &found_match[2]; - return self - .connections - .get(detected_conn_name) - .unwrap_or_else(|| { - panic!("Failed to find a matching connection type - connection name: {detected_conn_name}") - }) - .clone(); + + return detected_conn_name.to_string(); } - self.connections - .get("default") - .expect( - r"Failed to find the default connection configuration - check your configuration - CLI options: https://jasonshin.github.io/sqlx-ts/user-guide/2.1.cli-options.html - File based config: https://jasonshin.github.io/sqlx-ts/reference-guide/2.configs-file-based.html - ", - ) - .clone() + return "default".to_string(); } pub fn get_postgres_cred(&self, conn: &DbConnectionConfig) -> String { @@ -295,6 +282,7 @@ impl Config { .db_name(db_name.clone()) } + // TODO: update this to also factor in env variable pub fn get_log_level(file_config_path: &PathBuf) -> LogLevel { let file_based_config = fs::read_to_string(file_config_path); let file_based_config = &file_based_config.map(|f| serde_json::from_str::(f.as_str()).unwrap()); diff --git a/src/common/lazy.rs b/src/common/lazy.rs index 32926e2f..8efc9d75 100644 --- a/src/common/lazy.rs +++ b/src/common/lazy.rs @@ -1,18 +1,52 @@ use crate::common::cli::Cli; use crate::common::config::Config; +use crate::common::types::DatabaseType; +use crate::core::connection::{DBConn, DBConnections}; use crate::ts_generator::information_schema::DBSchema; use clap::Parser; use lazy_static::lazy_static; +use mysql::Conn as MySQLConn; +use postgres::{Client as PGClient, NoTls as PGNoTls}; +use std::collections::HashMap; +use std::sync::{Arc, Mutex}; // The file contains all implicitly dependent variables or state that files need for the logic // We have a lot of states that we need to drill down into each methods lazy_static! { - pub static ref SOME_INT: i32 = 5; - pub static ref CLI_ARGS: Cli = Cli::parse(); pub static ref CONFIG: Config = Config::new(); // This is a holder for shared DBSChema used to fetch information for information_schema table // By having a singleton, we can think about caching the result if we are fetching a query too many times - pub static ref DB_SCHEMA: DBSchema = DBSchema::new(); + pub static ref DB_SCHEMA: Mutex = Mutex::new(DBSchema::new()); + + // This variable holds database connections for each connection name that is defined in the config + // We are using lazy_static to initialize the connections once and use them throughout the application + static ref DB_CONN_CACHE: HashMap>> = { + let mut cache = HashMap::new(); + for connection in CONFIG.connections.keys() { + let connection_config = CONFIG.connections.get(connection).unwrap(); + let db_type = connection_config.db_type.to_owned(); + let conn = match db_type { + DatabaseType::Mysql => { + let opts = CONFIG.get_mysql_cred(&connection_config); + let mut conn = MySQLConn::new(opts).unwrap(); + DBConn::MySQLPooledConn(Mutex::new(conn)) + } + DatabaseType::Postgres => { + let postgres_cred = &CONFIG.get_postgres_cred(&connection_config); + DBConn::PostgresConn(Mutex::new(PGClient::connect(postgres_cred, PGNoTls).unwrap())) + } + }; + cache.insert(connection.to_owned(), Arc::new(Mutex::new(conn))); + }; + cache + }; + + // This variable holds a singleton of DBConnections that is used to get a DBConn from the cache + // DBConn is used to access the raw connection to the database or run `prepare` statement against each connection + pub static ref DB_CONNECTIONS: Mutex> = { + let db_connections = DBConnections::new(&DB_CONN_CACHE); + Mutex::new(db_connections) + }; } diff --git a/src/core/connection.rs b/src/core/connection.rs new file mode 100644 index 00000000..9acfaf1e --- /dev/null +++ b/src/core/connection.rs @@ -0,0 +1,55 @@ +use crate::common::lazy::CONFIG; +use crate::common::SQL; +use crate::core::mysql::prepare as mysql_explain; +use crate::core::postgres::prepare as postgres_explain; +use crate::ts_generator::types::ts_query::TsQuery; +use std::collections::HashMap; +use std::sync::Arc; +use std::sync::Mutex; + +use color_eyre::Result; +use mysql::Conn as MySQLConn; +use postgres::Client as PostgresConn; +use swc_common::errors::Handler; + +/// Enum to hold a specific database connection instance +pub enum DBConn { + MySQLPooledConn(Mutex), + PostgresConn(Mutex), +} + +impl DBConn { + pub fn prepare( + &self, + sql: &SQL, + should_generate_types: &bool, + handler: &Handler, + ) -> Result<(bool, Option)> { + let (explain_failed, ts_query) = match &self { + DBConn::MySQLPooledConn(_conn) => mysql_explain::prepare(&self, sql, should_generate_types, handler)?, + DBConn::PostgresConn(_conn) => postgres_explain::prepare(&self, sql, should_generate_types, handler)?, + }; + + Ok((explain_failed, ts_query)) + } +} + +pub struct DBConnections<'a> { + pub cache: &'a HashMap>>, +} + +impl<'a> DBConnections<'a> { + pub fn new(cache: &'a HashMap>>) -> Self { + Self { cache } + } + + pub fn get_connection(&mut self, raw_sql: &str) -> Arc> { + let db_conn_name = &CONFIG.get_correct_db_connection(raw_sql); + + let conn = self + .cache + .get(db_conn_name) + .expect("Failed to get the connection from cache"); + conn.to_owned() + } +} diff --git a/src/core/execute.rs b/src/core/execute.rs index 32293f36..17cad500 100644 --- a/src/core/execute.rs +++ b/src/core/execute.rs @@ -1,4 +1,5 @@ -use crate::common::lazy::{CLI_ARGS, CONFIG}; +use super::connection::DBConn; +use crate::common::lazy::{CLI_ARGS, CONFIG, DB_CONNECTIONS}; use crate::common::types::DatabaseType; use crate::common::SQL; use crate::core::mysql::prepare as mysql_explain; @@ -6,7 +7,9 @@ use crate::core::postgres::prepare as postgres_explain; use crate::ts_generator::generator::{write_colocated_ts_file, write_single_ts_file}; use color_eyre::eyre::Result; +use std::borrow::BorrowMut; use std::collections::HashMap; +use std::sync::Arc; use std::path::PathBuf; use swc_common::errors::Handler; @@ -22,19 +25,19 @@ pub fn execute(queries: &HashMap>, handler: &Handler) -> Resul for (file_path, sqls) in queries { let mut sqls_to_write: Vec = vec![]; for sql in sqls { - let connection = &CONFIG.get_correct_db_connection(&sql.query); + let mut connection = DB_CONNECTIONS.lock().unwrap(); + let connection = &connection.get_connection(&sql.query).clone(); + let connection = &connection.lock().unwrap(); - let (explain_failed, ts_query) = match connection.db_type { - DatabaseType::Postgres => postgres_explain::prepare(sql, should_generate_types, handler)?, - DatabaseType::Mysql => mysql_explain::prepare(sql, should_generate_types, handler)?, - }; + let (explain_failed, ts_query) = &connection.prepare(&sql, &should_generate_types, &handler)?; // If any prepare statement fails, we should set the failed flag as true - failed = explain_failed; + failed = explain_failed.clone(); if *should_generate_types { - let ts_query = ts_query.expect("Failed to generate types from query").to_string(); - sqls_to_write.push(ts_query); + let ts_query = &ts_query.clone().expect("Failed to generate types from query"); + let ts_query = &ts_query.to_string(); + sqls_to_write.push(ts_query.to_owned()); } } diff --git a/src/core/mod.rs b/src/core/mod.rs index 7ce4dc96..bc280965 100644 --- a/src/core/mod.rs +++ b/src/core/mod.rs @@ -1,3 +1,4 @@ +pub mod connection; pub mod execute; pub mod mysql; pub mod postgres; diff --git a/src/core/mysql/prepare.rs b/src/core/mysql/prepare.rs index 7d28ed03..34fdc3cb 100644 --- a/src/core/mysql/prepare.rs +++ b/src/core/mysql/prepare.rs @@ -1,29 +1,34 @@ -use crate::common::lazy::CONFIG; +use crate::common::lazy::{CONFIG, DB_CONNECTIONS}; use crate::common::SQL; +use crate::core::connection::DBConn; use crate::ts_generator::generator::generate_ts_interface; -use crate::ts_generator::types::db_conn::DBConn; use crate::ts_generator::types::ts_query::TsQuery; use color_eyre::eyre::Result; use mysql::prelude::*; use mysql::*; -use std::cell::RefCell; +use std::borrow::BorrowMut; use swc_common::errors::Handler; /// Runs the prepare statement on the input SQL. /// Validates the query is right by directly connecting to the configured database. /// It also processes ts interfaces if the configuration is set to generate_types = true -pub fn prepare(sql: &SQL, should_generate_types: &bool, handler: &Handler) -> Result<(bool, Option)> { - let connection_config = CONFIG.get_correct_db_connection(&sql.query); - let opts = CONFIG.get_mysql_cred(&connection_config); - let mut conn = Conn::new(opts)?; - +pub fn prepare( + db_conn: &DBConn, + sql: &SQL, + should_generate_types: &bool, + handler: &Handler, +) -> Result<(bool, Option)> { let mut failed = false; let span = sql.span.to_owned(); let explain_query = format!("PREPARE stmt FROM \"{}\"", sql.query); - let result: Result, _> = conn.query(explain_query); + let mut conn = match &db_conn { + DBConn::MySQLPooledConn(conn) => conn, + _ => panic!("Invalid connection type"), + }; + let result: Result, _> = conn.lock().unwrap().borrow_mut().query(explain_query); if let Err(err) = result { handler.span_bug_no_panic(span, err.to_string().as_str()); @@ -33,10 +38,7 @@ pub fn prepare(sql: &SQL, should_generate_types: &bool, handler: &Handler) -> Re let mut ts_query = None; if should_generate_types == &true { - ts_query = Some(generate_ts_interface( - sql, - &DBConn::MySQLPooledConn(&mut RefCell::new(&mut conn)), - )?); + ts_query = Some(generate_ts_interface(sql, &db_conn)?); } Ok((failed, ts_query)) diff --git a/src/core/postgres/prepare.rs b/src/core/postgres/prepare.rs index 5a6a2992..7b70a64c 100644 --- a/src/core/postgres/prepare.rs +++ b/src/core/postgres/prepare.rs @@ -1,42 +1,47 @@ -use crate::common::lazy::CONFIG; use crate::common::SQL; +use crate::core::connection::DBConn; use crate::ts_generator::generator::generate_ts_interface; -use crate::ts_generator::types::db_conn::DBConn; use crate::ts_generator::types::ts_query::TsQuery; use color_eyre::eyre::Result; -use postgres::{Client, NoTls}; -use std::cell::RefCell; +use std::borrow::BorrowMut; use swc_common::errors::Handler; /// Runs the prepare statement on the input SQL. Validates the query is right by directly connecting to the configured database. /// It also processes ts interfaces if the configuration is set to `generate_types = true` -pub fn prepare<'a>(sql: &SQL, should_generate_types: &bool, handler: &Handler) -> Result<(bool, Option)> { - let connection = &CONFIG.get_correct_db_connection(&sql.query); - let postgres_cred = &CONFIG.get_postgres_cred(connection); - let mut conn = Client::connect(postgres_cred, NoTls).unwrap(); - +pub fn prepare( + db_conn: &DBConn, + sql: &SQL, + should_generate_types: &bool, + handler: &Handler, +) -> Result<(bool, Option)> { let mut failed = false; + let mut conn = match &db_conn { + DBConn::PostgresConn(conn) => conn, + _ => panic!("Invalid connection type"), + }; let span = sql.span.to_owned(); let prepare_query = format!("PREPARE sqlx_stmt AS {}", sql.query); - let result = conn.query(prepare_query.as_str(), &[]); + let result = conn.lock().unwrap().borrow_mut().query(prepare_query.as_str(), &[]); if let Err(e) = result { handler.span_bug_no_panic(span, e.as_db_error().unwrap().message()); failed = true; } else { // We should only deallocate if the prepare statement was executed successfully - conn.query("DEALLOCATE sqlx_stmt", &[]).unwrap(); + let _ = &conn + .lock() + .unwrap() + .borrow_mut() + .query("DEALLOCATE sqlx_stmt", &[]) + .unwrap(); } let mut ts_query = None; if should_generate_types == &true { - ts_query = Some(generate_ts_interface( - sql, - &DBConn::PostgresConn(&mut RefCell::new(&mut conn)), - )?); + ts_query = Some(generate_ts_interface(sql, &db_conn)?); } Ok((failed, ts_query)) diff --git a/src/main.rs b/src/main.rs index a84d5565..84d16323 100644 --- a/src/main.rs +++ b/src/main.rs @@ -9,15 +9,13 @@ extern crate dotenv; use crate::core::execute::execute; -use dotenv::dotenv; use sqlx_ts::ts_generator::generator::clear_single_ts_file_if_exists; use std::env; -use std::io::{stderr, stdout, Write}; use crate::common::lazy::CLI_ARGS; use crate::common::logger::*; use crate::{parser::parse_source, scan_folder::scan_folder}; -use color_eyre::{eyre::eyre, eyre::Result}; +use color_eyre::eyre::Result; fn set_default_env_var() { if env::var("SQLX_TS_LOG").is_err() { diff --git a/src/ts_generator/generator.rs b/src/ts_generator/generator.rs index b60c711e..7807b3c1 100644 --- a/src/ts_generator/generator.rs +++ b/src/ts_generator/generator.rs @@ -6,10 +6,10 @@ use std::{ }; use super::annotations::extract_param_annotations; -use super::types::db_conn::DBConn; use crate::common::lazy::CONFIG; use crate::common::SQL; +use crate::core::connection::DBConn; use crate::ts_generator::annotations::extract_result_annotations; use crate::ts_generator::sql_parser::translate_stmt::translate_stmt; use crate::ts_generator::types::ts_query::TsQuery; diff --git a/src/ts_generator/information_schema.rs b/src/ts_generator/information_schema.rs index f55a4378..bb2a7152 100644 --- a/src/ts_generator/information_schema.rs +++ b/src/ts_generator/information_schema.rs @@ -1,10 +1,12 @@ use mysql; use mysql::prelude::Queryable; use postgres; -use std::cell::RefCell; +use std::borrow::BorrowMut; use std::collections::HashMap; +use std::sync::Mutex; + +use crate::core::connection::DBConn; -use super::types::db_conn::DBConn; use super::types::ts_query::TsFieldType; #[derive(Debug, Clone)] @@ -23,7 +25,6 @@ struct ColumnsQueryResultRow { } pub struct DBSchema { - /// tables cache tables_cache: HashMap, } @@ -41,14 +42,27 @@ impl DBSchema { /// /// # PostgreSQL Notes /// - TABLE_SCHEMA is PostgreSQL is basically 'public' by default. `database_name` is the name of the database itself - pub fn fetch_table(&self, table_name: &Vec<&str>, conn: &DBConn) -> Option { - match &conn { + pub fn fetch_table(&mut self, table_name: &Vec<&str>, conn: &DBConn) -> Option { + let table_key: String = table_name.join(","); + let cached_table_result = self.tables_cache.get(table_key.as_str()); + + if let Some(cached_table_result) = cached_table_result { + return Some(cached_table_result.clone()); + } + + let result = match &conn { DBConn::MySQLPooledConn(conn) => Self::mysql_fetch_table(self, table_name, conn), DBConn::PostgresConn(conn) => Self::postgres_fetch_table(self, table_name, conn), + }; + + if let Some(result) = &result { + let _ = &self.tables_cache.insert(table_key, result.clone()); } + + result } - fn postgres_fetch_table(&self, table_names: &Vec<&str>, conn: &RefCell<&mut postgres::Client>) -> Option { + fn postgres_fetch_table(&self, table_names: &Vec<&str>, conn: &Mutex) -> Option { let table_names = table_names .iter() .map(|x| format!("'{x}'")) @@ -69,7 +83,7 @@ impl DBSchema { ); let mut fields: HashMap = HashMap::new(); - let result = conn.borrow_mut().query(&query, &[]); + let result = conn.lock().unwrap().borrow_mut().query(&query, &[]); if let Ok(result) = result { for row in result { @@ -89,7 +103,7 @@ impl DBSchema { None } - fn mysql_fetch_table(&self, table_names: &Vec<&str>, conn: &RefCell<&mut mysql::Conn>) -> Option { + fn mysql_fetch_table(&self, table_names: &Vec<&str>, conn: &Mutex) -> Option { let table_names = table_names .iter() .map(|x| format!("'{x}'")) @@ -109,7 +123,7 @@ impl DBSchema { ); let mut fields: HashMap = HashMap::new(); - let result = conn.borrow_mut().query::(query); + let result = conn.lock().unwrap().borrow_mut().query::(query); if let Ok(result) = result { for row in result { diff --git a/src/ts_generator/sql_parser/expressions/translate_expr.rs b/src/ts_generator/sql_parser/expressions/translate_expr.rs index 0f44cf3c..99122c72 100644 --- a/src/ts_generator/sql_parser/expressions/translate_expr.rs +++ b/src/ts_generator/sql_parser/expressions/translate_expr.rs @@ -1,5 +1,6 @@ use crate::common::lazy::{CONFIG, DB_SCHEMA}; use crate::common::logger::warning; +use crate::core::connection::DBConn; use crate::ts_generator::errors::TsGeneratorError; use crate::ts_generator::sql_parser::expressions::translate_data_type::translate_value; use crate::ts_generator::sql_parser::expressions::translate_table_with_joins::translate_table_from_expr; @@ -7,7 +8,6 @@ use crate::ts_generator::sql_parser::expressions::{ functions::is_string_function, translate_data_type::translate_data_type, }; use crate::ts_generator::sql_parser::translate_query::translate_query; -use crate::ts_generator::types::db_conn::DBConn; use crate::ts_generator::types::ts_query::{TsFieldType, TsQuery}; use convert_case::{Case, Casing}; use regex::Regex; @@ -124,6 +124,8 @@ pub fn get_sql_query_param( let table_names = vec![table_name.as_str()]; let column_name = column_name.unwrap(); let columns = DB_SCHEMA + .lock() + .unwrap() .fetch_table(&table_names, db_conn) .unwrap_or_else(|| panic!("Failed to fetch columns for table {:?}", table_name)); @@ -155,7 +157,7 @@ pub fn translate_expr( Expr::Identifier(ident) => { let column_name = ident.value.to_string(); let table_name = single_table_name.expect("Missing table name for identifier"); - let table_details = &DB_SCHEMA.fetch_table(&vec![table_name], db_conn); + let table_details = &DB_SCHEMA.lock().unwrap().fetch_table(&vec![table_name], db_conn); // TODO: We can also memoize this method if let Some(table_details) = table_details { @@ -178,7 +180,10 @@ pub fn translate_expr( let table_name = translate_table_from_expr(table_with_joins, &expr) .ok_or_else(|| TsGeneratorError::IndentifierWithoutTable(expr.to_string()))?; - let table_details = &DB_SCHEMA.fetch_table(&vec![table_name.as_str()], db_conn); + let table_details = &DB_SCHEMA + .lock() + .unwrap() + .fetch_table(&vec![table_name.as_str()], db_conn); if let Some(table_details) = table_details { let field = table_details.get(&ident).unwrap(); @@ -516,7 +521,11 @@ pub fn translate_assignment( let value = get_expr_placeholder(&assignment.value); if value.is_some() { - let table_details = &DB_SCHEMA.fetch_table(&vec![table_name], db_conn).unwrap(); + let table_details = &DB_SCHEMA + .lock() + .unwrap() + .fetch_table(&vec![table_name], db_conn) + .unwrap(); let column_name = translate_column_name_assignment(assignment).unwrap(); let field = table_details .get(&column_name) diff --git a/src/ts_generator/sql_parser/expressions/translate_wildcard_expr.rs b/src/ts_generator/sql_parser/expressions/translate_wildcard_expr.rs index a4dccb4f..9bb666b6 100644 --- a/src/ts_generator/sql_parser/expressions/translate_wildcard_expr.rs +++ b/src/ts_generator/sql_parser/expressions/translate_wildcard_expr.rs @@ -1,8 +1,9 @@ use crate::common::lazy::DB_SCHEMA; use crate::common::logger::warning; +use crate::core::connection::DBConn; use crate::ts_generator::errors::TsGeneratorError; +use crate::ts_generator::types::ts_query::TsFieldType; use crate::ts_generator::types::ts_query::TsQuery; -use crate::ts_generator::types::{db_conn::DBConn, ts_query::TsFieldType}; use color_eyre::eyre::Result; use sqlparser::ast::{Join, Query, SetExpr, TableFactor, TableWithJoins}; @@ -61,7 +62,7 @@ pub fn translate_wildcard_expr( } let table_with_joins = table_with_joins.iter().map(|s| s.as_ref()).collect(); - let all_fields = DB_SCHEMA.fetch_table(&table_with_joins, db_conn); + let all_fields = DB_SCHEMA.lock().unwrap().fetch_table(&table_with_joins, db_conn); if let Some(all_fields) = all_fields { for key in all_fields.keys() { let field = all_fields.get(key).unwrap(); diff --git a/src/ts_generator/sql_parser/translate_delete.rs b/src/ts_generator/sql_parser/translate_delete.rs index 28a2d224..58998d63 100644 --- a/src/ts_generator/sql_parser/translate_delete.rs +++ b/src/ts_generator/sql_parser/translate_delete.rs @@ -1,6 +1,6 @@ +use crate::core::connection::DBConn; use crate::ts_generator::errors::TsGeneratorError; use crate::ts_generator::sql_parser::expressions::translate_expr::translate_expr; -use crate::ts_generator::types::db_conn::DBConn; use crate::ts_generator::types::ts_query::TsQuery; use sqlparser::ast::Expr; diff --git a/src/ts_generator/sql_parser/translate_insert.rs b/src/ts_generator/sql_parser/translate_insert.rs index 20296930..d34a4adf 100644 --- a/src/ts_generator/sql_parser/translate_insert.rs +++ b/src/ts_generator/sql_parser/translate_insert.rs @@ -1,8 +1,9 @@ +use crate::core::connection::DBConn; use crate::ts_generator::sql_parser::expressions::translate_expr::get_expr_placeholder; use sqlparser::ast::{Ident, Query, SetExpr}; use crate::common::lazy::DB_SCHEMA; -use crate::ts_generator::{errors::TsGeneratorError, types::db_conn::DBConn, types::ts_query::TsQuery}; +use crate::ts_generator::{errors::TsGeneratorError, types::ts_query::TsQuery}; pub fn translate_insert( ts_query: &mut TsQuery, @@ -11,7 +12,9 @@ pub fn translate_insert( table_name: &str, conn: &DBConn, ) -> Result<(), TsGeneratorError> { - let table_details = DB_SCHEMA + let table_details = &DB_SCHEMA + .lock() + .unwrap() .fetch_table(&vec![table_name], conn) // Nearly impossible to panic at this point as we've already validated queries with prepare statements .unwrap(); diff --git a/src/ts_generator/sql_parser/translate_query.rs b/src/ts_generator/sql_parser/translate_query.rs index 31995621..f795adba 100644 --- a/src/ts_generator/sql_parser/translate_query.rs +++ b/src/ts_generator/sql_parser/translate_query.rs @@ -1,8 +1,11 @@ use sqlparser::ast::{Query, SelectItem, SetExpr, TableWithJoins}; -use crate::ts_generator::{ - errors::TsGeneratorError, sql_parser::expressions::translate_table_with_joins::get_default_table, - types::db_conn::DBConn, types::ts_query::TsQuery, +use crate::{ + core::connection::DBConn, + ts_generator::{ + errors::TsGeneratorError, sql_parser::expressions::translate_table_with_joins::get_default_table, + types::ts_query::TsQuery, + }, }; use super::expressions::{ diff --git a/src/ts_generator/sql_parser/translate_stmt.rs b/src/ts_generator/sql_parser/translate_stmt.rs index 234be52b..23b2adcd 100644 --- a/src/ts_generator/sql_parser/translate_stmt.rs +++ b/src/ts_generator/sql_parser/translate_stmt.rs @@ -1,10 +1,10 @@ +use crate::core::connection::DBConn; use crate::ts_generator::errors::TsGeneratorError; use crate::ts_generator::sql_parser::translate_delete::translate_delete; use crate::ts_generator::sql_parser::translate_insert::translate_insert; use crate::ts_generator::sql_parser::translate_query::translate_query; use crate::ts_generator::sql_parser::translate_update::translate_update; -use crate::ts_generator::types::db_conn::DBConn; use crate::ts_generator::types::ts_query::TsQuery; use sqlparser::ast::Statement; diff --git a/src/ts_generator/sql_parser/translate_update.rs b/src/ts_generator/sql_parser/translate_update.rs index 631d9ec6..1005a410 100644 --- a/src/ts_generator/sql_parser/translate_update.rs +++ b/src/ts_generator/sql_parser/translate_update.rs @@ -1,8 +1,8 @@ use sqlparser::ast::{Assignment, Expr, TableWithJoins}; -use crate::ts_generator::{ - errors::TsGeneratorError, - types::{db_conn::DBConn, ts_query::TsQuery}, +use crate::{ + core::connection::DBConn, + ts_generator::{errors::TsGeneratorError, types::ts_query::TsQuery}, }; use super::expressions::{ diff --git a/src/ts_generator/types/db_conn.rs b/src/ts_generator/types/db_conn.rs index e0d02ce6..8b137891 100644 --- a/src/ts_generator/types/db_conn.rs +++ b/src/ts_generator/types/db_conn.rs @@ -1,11 +1 @@ -use std::cell::RefCell; -use mysql::Conn as MySQLConn; -use postgres::Client as PostgresConn; - -/// Enum to hold a specific database connection instance -pub enum DBConn<'a> { - // TODO: Maybe we can also pass down db_name through DBConn - MySQLPooledConn(&'a mut RefCell<&'a mut MySQLConn>), - PostgresConn(&'a mut RefCell<&'a mut PostgresConn>), -} diff --git a/src/ts_generator/types/ts_query.rs b/src/ts_generator/types/ts_query.rs index d3ef7fc9..69f3f3e3 100644 --- a/src/ts_generator/types/ts_query.rs +++ b/src/ts_generator/types/ts_query.rs @@ -153,7 +153,7 @@ impl TsFieldType { /// /// There are tests under `tests` folder that checks TsQuery generates the /// correct type definitions -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct TsQuery { pub name: String, param_order: i32, @@ -196,13 +196,6 @@ impl TsQuery { self.annotated_params = annotated_params; } - pub fn set_annotated_insert_params( - &mut self, - annotated_insert_params: BTreeMap>, - ) { - self.annotated_insert_params = annotated_insert_params; - } - pub fn format_column_name(&self, column_name: &str) -> String { let convert_to_camel_case_column_name = &CONFIG .generate_types_config diff --git a/tests/alias.rs b/tests/alias.rs index 5e55982c..ec6f4e48 100644 --- a/tests/alias.rs +++ b/tests/alias.rs @@ -1,13 +1,11 @@ #[cfg(test)] mod alias { use assert_cmd::prelude::*; - use pretty_assertions::assert_eq; - use std::env::current_dir; + use std::fs; use std::io::Write; use std::process::Command; use tempfile::tempdir; - use walkdir::WalkDir; #[test] fn should_warn_on_clashing_field_names_on_join() -> Result<(), Box> { @@ -21,8 +19,8 @@ JOIN tables ON items.table_id = tables.id // SETUP let dir = tempdir()?; let parent_path = dir.path(); - let file_path = parent_path.join(format!("index.ts")); - let mut temp_file = fs::File::create(&file_path)?; + let file_path = parent_path.join("index.ts".to_string()); + let mut temp_file = fs::File::create(file_path)?; writeln!(temp_file, "{}", ts_content)?; // EXECUTE @@ -56,8 +54,8 @@ JOIN tables ON items.table_id = tables.id // SETUP let dir = tempdir().unwrap(); let parent_path = dir.path(); - let file_path = parent_path.join(format!("index.ts")); - let mut temp_file = fs::File::create(&file_path).unwrap(); + let file_path = parent_path.join("index.ts".to_string()); + let mut temp_file = fs::File::create(file_path).unwrap(); writeln!(temp_file, "{}", ts_content).unwrap(); // EXECUTE @@ -89,8 +87,8 @@ JOIN tables ON items.table_id = tables.id // SETUP let dir = tempdir()?; let parent_path = dir.path(); - let file_path = parent_path.join(format!("index.ts")); - let mut temp_file = fs::File::create(&file_path)?; + let file_path = parent_path.join("index.ts".to_string()); + let mut temp_file = fs::File::create(file_path)?; writeln!(temp_file, "{}", ts_content)?; // EXECUTE