Skip to content

Commit

Permalink
read_postgres
Browse files Browse the repository at this point in the history
  • Loading branch information
scsmithr committed Dec 15, 2024
1 parent e631142 commit 1402b0b
Showing 1 changed file with 51 additions and 103 deletions.
154 changes: 51 additions & 103 deletions crates/rayexec_postgres/src/read_postgres.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,24 @@
use std::collections::HashMap;
use std::sync::Arc;

use futures::future::BoxFuture;
use futures::FutureExt;
use rayexec_bullet::datatype::DataTypeId;
use rayexec_bullet::field::Schema;
use rayexec_error::{OptionExt, RayexecError, Result};
use rayexec_bullet::scalar::OwnedScalarValue;
use rayexec_error::{RayexecError, Result};
use rayexec_execution::database::DatabaseContext;
use rayexec_execution::functions::table::inputs::TableFunctionInputs;
use rayexec_execution::functions::table::{PlannedTableFunction2, TableFunction};
use rayexec_execution::expr;
use rayexec_execution::functions::table::{
PlannedTableFunction,
ScanPlanner,
TableFunction,
TableFunctionImpl,
TableFunctionPlanner,
};
use rayexec_execution::functions::{FunctionInfo, Signature};
use rayexec_execution::logical::statistics::StatisticsValue;
use rayexec_execution::runtime::Runtime;
use rayexec_execution::storage::table_storage::DataTable;
use rayexec_proto::packed::{PackedDecoder, PackedEncoder};
use rayexec_proto::ProtoConv;
use serde::{Deserialize, Serialize};

use crate::{PostgresClient, PostgresDataTable};

Expand All @@ -34,82 +42,43 @@ impl<R: Runtime> FunctionInfo for ReadPostgres<R> {
}

impl<R: Runtime> TableFunction for ReadPostgres<R> {
fn plan_and_initialize<'a>(
&self,
_context: &'a DatabaseContext,
args: TableFunctionInputs,
) -> BoxFuture<'a, Result<Box<dyn PlannedTableFunction2>>> {
Box::pin(ReadPostgresImpl::initialize(self.clone(), args))
fn planner(&self) -> TableFunctionPlanner {
TableFunctionPlanner::Scan(self)
}

fn decode_state(&self, state: &[u8]) -> Result<Box<dyn PlannedTableFunction2>> {
Ok(Box::new(ReadPostgresImpl {
func: self.clone(),
state: ReadPostgresState::decode(state)?,
client: None,
}))
}
}

#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
struct ReadPostgresState {
conn_str: String,
schema: String,
table: String,
table_schema: Schema,
}

impl ReadPostgresState {
fn encode(&self, buf: &mut Vec<u8>) -> Result<()> {
let mut packed = PackedEncoder::new(buf);
packed.encode_next(&self.conn_str)?;
packed.encode_next(&self.schema)?;
packed.encode_next(&self.table)?;
packed.encode_next(&self.table_schema.to_proto()?)?;
Ok(())
}

fn decode(buf: &[u8]) -> Result<Self> {
let mut packed = PackedDecoder::new(buf);
Ok(ReadPostgresState {
conn_str: packed.decode_next()?,
schema: packed.decode_next()?,
table: packed.decode_next()?,
table_schema: Schema::from_proto(packed.decode_next()?)?,
})
impl<R: Runtime> ScanPlanner for ReadPostgres<R> {
fn plan<'a>(
&self,
context: &'a DatabaseContext,
positional_inputs: Vec<OwnedScalarValue>,
named_inputs: HashMap<String, OwnedScalarValue>,
) -> BoxFuture<'a, Result<PlannedTableFunction>> {
Self::plan_inner(self.clone(), context, positional_inputs, named_inputs).boxed()
}
}

#[derive(Debug, Clone)]
struct ReadPostgresImpl<R: Runtime> {
func: ReadPostgres<R>,
state: ReadPostgresState,
client: Option<PostgresClient>,
}

impl<R> ReadPostgresImpl<R>
where
R: Runtime,
{
async fn initialize(
func: ReadPostgres<R>,
args: TableFunctionInputs,
) -> Result<Box<dyn PlannedTableFunction2>> {
if !args.named.is_empty() {
impl<R: Runtime> ReadPostgres<R> {
async fn plan_inner<'a>(
self: Self,
_context: &'a DatabaseContext,
positional_inputs: Vec<OwnedScalarValue>,
named_inputs: HashMap<String, OwnedScalarValue>,
) -> Result<PlannedTableFunction> {
if !named_inputs.is_empty() {
return Err(RayexecError::new(
"read_postgres does not accept named arguments",
));
}
if args.positional.len() != 3 {
if positional_inputs.len() != 3 {
return Err(RayexecError::new("read_postgres requires 3 arguments"));
}

let mut args = args.clone();
let table = args.positional.pop().unwrap().try_into_string()?;
let schema = args.positional.pop().unwrap().try_into_string()?;
let conn_str = args.positional.pop().unwrap().try_into_string()?;
let conn_str = positional_inputs.get(0).unwrap().try_as_str()?;
let schema = positional_inputs.get(1).unwrap().try_as_str()?;
let table = positional_inputs.get(2).unwrap().try_as_str()?;

let client = PostgresClient::connect(&conn_str, &func.runtime).await?;
let client = PostgresClient::connect(conn_str, &self.runtime).await?;

let fields = match client.get_fields_and_types(&schema, &table).await? {
Some((fields, _)) => fields,
Expand All @@ -118,40 +87,19 @@ where

let table_schema = Schema::new(fields);

Ok(Box::new(ReadPostgresImpl {
func,
state: ReadPostgresState {
conn_str,
schema,
table,
table_schema,
},
client: Some(client),
}))
}
}

impl<R> PlannedTableFunction2 for ReadPostgresImpl<R>
where
R: Runtime,
{
fn table_function(&self) -> &dyn TableFunction {
&self.func
}

fn schema(&self) -> Schema {
self.state.table_schema.clone()
}

fn encode_state(&self, state: &mut Vec<u8>) -> Result<()> {
self.state.encode(state)
}
let datatable = PostgresDataTable {
client,
schema: schema.to_string(),
table: table.to_string(),
};

fn datatable(&self) -> Result<Box<dyn DataTable>> {
Ok(Box::new(PostgresDataTable {
client: self.client.as_ref().required("postgres client")?.clone(),
schema: self.state.schema.clone(),
table: self.state.table.clone(),
}))
Ok(PlannedTableFunction {
function: Box::new(self),
positional_inputs: positional_inputs.into_iter().map(expr::lit).collect(),
named_inputs,
function_impl: TableFunctionImpl::Scan(Arc::new(datatable)),
cardinality: StatisticsValue::Unknown,
schema: table_schema,
})
}
}

0 comments on commit 1402b0b

Please sign in to comment.