Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Override some PG field types from a parsed query #253

Merged
merged 2 commits into from
Aug 26, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
249 changes: 176 additions & 73 deletions crates/corro-pg/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@ use rusqlite::{
};
use spawn::spawn_counted;
use sqlite3_parser::ast::{
As, Cmd, ColumnDefinition, CreateTableBody, Expr, FromClause, Id, InsertBody, Limit, Name,
OneSelect, ResultColumn, Select, SelectTable, Stmt, With,
As, Cmd, ColumnDefinition, CreateTableBody, Expr, FromClause, Id, InsertBody, Limit, Literal,
Name, OneSelect, ResultColumn, Select, SelectBody, SelectTable, Stmt, With,
};
use sqlparser::ast::Statement as PgStatement;
use tokio::{
Expand Down Expand Up @@ -848,27 +848,18 @@ pub async fn start(
.collect();
}

let mut fields = vec![];
for col in prepped.columns() {
let col_type = match name_to_type(
col.decl_type().unwrap_or("text"),
) {
Ok(t) => t,
Err(e) => {
back_tx
.blocking_send((e.into(), true).into())?;
discard_until_sync = true;
continue 'outer;
}
};
fields.push(FieldInfo::new(
col.name().to_string(),
None,
None,
col_type,
FieldFormat::Text,
));
}
let fields = match field_types(
&prepped,
&parsed_cmd,
FieldFormats::All(FieldFormat::Text),
) {
Ok(fields) => fields,
Err(e) => {
back_tx.blocking_send((e.into(), true).into())?;
discard_until_sync = true;
continue 'outer;
}
};

prepared.insert(
name.into(),
Expand Down Expand Up @@ -985,38 +976,29 @@ pub async fn start(
Some(Portal::Parsed {
stmt,
result_formats,
cmd,
..
}) => {
let mut oids = vec![];
let mut fields = vec![];
for (i, col) in stmt.columns().into_iter().enumerate() {
let col_type =
match name_to_type(
col.decl_type().unwrap_or("text"),
) {
Ok(t) => t,
Err(e) => {
back_tx.blocking_send((
let fields = match field_types(
stmt,
cmd,
FieldFormats::Each(&result_formats),
) {
Ok(fields) => fields,
Err(e) => {
back_tx.blocking_send(
(
PgWireBackendMessage::ErrorResponse(
e.into(),
),
true,
).into())?;
continue 'outer;
}
};
oids.push(col_type.oid());
fields.push(FieldInfo::new(
col.name().to_string(),
None,
None,
col_type,
result_formats
.get(i)
.copied()
.unwrap_or(FieldFormat::Text),
));
}
)
.into(),
)?;
continue 'outer;
}
};

back_tx.blocking_send(
(
PgWireBackendMessage::RowDescription(
Expand Down Expand Up @@ -1808,17 +1790,7 @@ impl Session {
conn.prepare(&cmd.to_string())?
};

let mut fields = vec![];
for col in prepped.columns() {
let col_type = name_to_type(col.decl_type().unwrap_or("text"))?;
fields.push(FieldInfo::new(
col.name().to_string(),
None,
None,
col_type,
FieldFormat::Text,
));
}
let fields = field_types(&prepped, cmd, FieldFormats::All(FieldFormat::Text))?;

if send_row_desc {
back_tx
Expand Down Expand Up @@ -1914,19 +1886,7 @@ impl Session {
back_tx: &Sender<BackendResponse>,
) -> Result<(), QueryError> {
// TODO: maybe we don't need to recompute this...
let mut fields = vec![];
for (i, col) in prepped.columns().into_iter().enumerate() {
trace!("col decl_type: {:?}", col.decl_type());
let col_type = name_to_type(col.decl_type().unwrap_or("any"))?;

fields.push(FieldInfo::new(
col.name().to_string(),
None,
None,
col_type,
result_formats.get(i).copied().unwrap_or(FieldFormat::Text),
));
}
let fields = field_types(prepped, cmd, FieldFormats::Each(result_formats))?;

trace!("fields: {fields:?}");

Expand Down Expand Up @@ -3048,6 +3008,141 @@ fn parameter_types<'schema, 'stmt>(
params
}

enum FieldFormats<'a> {
All(FieldFormat),
Each(&'a [FieldFormat]),
}

impl<'a> FieldFormats<'a> {
fn get(&self, i: usize) -> FieldFormat {
match self {
FieldFormats::All(format) => *format,
FieldFormats::Each(formats) => formats.get(i).copied().unwrap_or(FieldFormat::Text),
}
}
}

fn field_types(
prepped: &Statement,
parsed_cmd: &ParsedCmd,
field_formats: FieldFormats<'_>,
) -> Result<Vec<FieldInfo>, UnsupportedSqliteToPostgresType> {
let mut field_type_overrides = HashMap::new();

match parsed_cmd {
ParsedCmd::Sqlite(Cmd::Stmt(stmt)) => match stmt {
Stmt::Select(Select {
body:
SelectBody {
select: OneSelect::Select { columns: cols, .. },
..
},
..
})
| Stmt::Delete {
returning: Some(cols),
..
}
| Stmt::Insert {
returning: Some(cols),
..
}
| Stmt::Update {
returning: Some(cols),
..
} => {
for (i, col) in cols.iter().enumerate() {
if let ResultColumn::Expr(expr, _as) = col {
let type_override = match expr {
Expr::Cast { type_name, .. } => Some(name_to_type(&type_name.name)?),
Expr::FunctionCall { name, .. }
| Expr::FunctionCallStar { name, .. } => {
match name.0.as_str().to_uppercase().as_ref() {
"COUNT" => Some(Type::INT8),
_ => None,
}
}
Expr::Literal(lit) => match lit {
Literal::Numeric(s) => Some(if s.contains('.') {
Type::FLOAT8
} else {
Type::INT8
}),
Literal::String(_) => Some(Type::TEXT),
Literal::Blob(_) => Some(Type::BYTEA),
Literal::Keyword(_) => None,
Literal::Null => None,
Literal::CurrentDate => Some(Type::DATE),
Literal::CurrentTime => Some(Type::TIME),
Literal::CurrentTimestamp => Some(Type::TIMESTAMP),
},
_ => None,
};
if let Some(type_override) = type_override {
match prepped.column_name(i) {
Ok(col_name) => {
field_type_overrides.insert(col_name, type_override);
}
Err(e) => {
error!("col index didn't exist at {i}, attempted to override type as: {type_override}: {e}");
}
}
}
} else {
break;
}
}
}
_ => {}
},
ParsedCmd::Postgres(_stmt) => {
// TODO: handle type overrides here too
// let cols = match stmt {
// PgStatement::Insert { returning, .. }
// | PgStatement::Update { returning, .. }
// | PgStatement::Delete { returning, .. } => {
// returning
// }
// PgStatement::Query(query) => {
// match *query.body {
// sqlparser::ast::SetExpr::Select(
// select,
// ) => Some(select.projection),
// _ => None,
// }
// }
// _ => None,
// };

// if let Some(cols) = cols {

// }
}
_ => {}
}

let mut fields = vec![];
for (i, col) in prepped.columns().iter().enumerate() {
let col_name = col.name();
let col_type = match field_type_overrides.remove(col_name) {
Some(t) => t,
None => match col.decl_type() {
None => Type::TEXT,
Some(decl_type) => name_to_type(decl_type)?,
},
};
fields.push(FieldInfo::new(
col_name.to_string(),
None,
None,
col_type,
field_formats.get(i),
));
}

Ok(fields)
}

#[cfg(test)]
mod tests {
use std::time::{Duration, Instant};
Expand Down Expand Up @@ -3253,6 +3348,14 @@ mod tests {
println!("updated_at: {updated_at:?}");

assert_eq!(future, updated_at);

let row = client
.query_one(
"SELECT COUNT(*) AS yep, COUNT(id) yeppers FROM kitchensink",
&[],
)
.await?;
println!("COUNT ROW: {row:?}");
}

tripwire_tx.send(()).await.ok();
Expand Down
Loading