Skip to content

Commit

Permalink
feat(wip): add plugin registration
Browse files Browse the repository at this point in the history
  • Loading branch information
mesejo committed Sep 3, 2024
1 parent fe0738a commit 88855de
Show file tree
Hide file tree
Showing 5 changed files with 220 additions and 4 deletions.
21 changes: 21 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ parking_lot = "0.12"
regex-syntax = "0.8"
syn = "2.0.68"
url = "2"
datafusion-plugin-core = { git = "https://github.com/mesejo/datafusion-plugin-core.git", rev = "233dc5b4a7749c254e17b9c168440240eb7e022a" }
libloading = "0.8.5"

[build-dependencies]
pyo3-build-config = "0.21"
Expand Down
60 changes: 60 additions & 0 deletions examples/register-extension.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
from pathlib import Path
import pyarrow as pa
import json

from datafusion.context import SessionContext


def register_plugin_json(context: SessionContext, path: Path, func: str):
# access the inner context because as of now the function register_plugin is not exposed
context.ctx.register_plugin(path, func)


ctx = SessionContext()
extension_path = (
Path(__file__).parent.resolve() / "extension" / "libdatafusion_plugin_json.so"
)
function_to_call = "json_functions"

# Sample data
data = [
{
"id": 1,
"name": "Alice",
"json_data": json.dumps({"age": 30, "city": "New York"}),
},
{
"id": 2,
"name": "Bob",
"json_data": json.dumps({"age": 25, "city": "San Francisco"}),
},
{
"id": 3,
"name": "Charlie",
"json_data": json.dumps({"age": 35, "city": "London"}),
},
]

# Define the schema
schema = pa.schema(
[
("id", pa.int64()),
("name", pa.string()),
("json_data", pa.string()), # JSON data stored as string
]
)

# Create arrays for each column
id_array = pa.array([row["id"] for row in data], type=pa.int64())
name_array = pa.array([row["name"] for row in data], type=pa.string())
json_array = pa.array([row["json_data"] for row in data], type=pa.string())

# Create the table
table = pa.Table.from_arrays([id_array, name_array, json_array], schema=schema)

ctx.register_record_batches("json_table", [table.to_batches()])
register_plugin_json(ctx, extension_path, function_to_call)
res = ctx.sql(
"select name, json_get_str(json_data, 'city') from json_table"
).to_pandas()
print(res)
135 changes: 135 additions & 0 deletions src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,11 @@
// under the License.

use std::collections::{HashMap, HashSet};
use std::ffi::OsStr;
use std::io;
use std::io::ErrorKind;
use std::path::PathBuf;
use std::rc::Rc;
use std::str::FromStr;
use std::sync::Arc;

Expand Down Expand Up @@ -62,6 +66,8 @@ use datafusion::prelude::{
AvroReadOptions, CsvReadOptions, DataFrame, NdJsonReadOptions, ParquetReadOptions,
};
use datafusion_common::ScalarValue;
use datafusion_plugin_core::{Function, InvocationError, PluginDeclaration};
use libloading::Library;
use pyo3::types::{PyDict, PyList, PyTuple};
use tokio::task::JoinHandle;

Expand Down Expand Up @@ -967,12 +973,36 @@ impl PySessionContext {
let stream = wait_for_future(py, fut).map_err(py_datafusion_err)?;
Ok(PyRecordBatchStream::new(stream?))
}

pub fn register_plugin(&mut self, path: PathBuf, name: &str, py: Python) -> PyResult<()> {
wait_for_future(py, self._register_plugin(path, name)).map_err(DataFusionError::from)?;
Ok(())
}
}

impl PySessionContext {
async fn _table(&self, name: &str) -> datafusion_common::Result<DataFrame> {
self.ctx.table(name).await
}

async fn _register_plugin(
&mut self,
path: PathBuf,
name: &str,
) -> datafusion_common::Result<()> {
let mut functions = ExternalFunctions::new();

unsafe {
functions.load(path).expect("Function loading failed");
}

// then call the function
functions
.call(name, &mut self.ctx)
.expect("Invocation failed");

Ok(())
}
}

pub fn convert_table_partition_cols(
Expand Down Expand Up @@ -1010,3 +1040,108 @@ impl From<SessionContext> for PySessionContext {
PySessionContext { ctx }
}
}

#[derive(Default)]
pub struct ExternalFunctions {
functions: HashMap<String, FunctionProxy>,
libraries: Vec<Rc<Library>>,
}

impl ExternalFunctions {
pub fn new() -> ExternalFunctions {
ExternalFunctions::default()
}

pub fn call(
&self,
function: &str,
argument: &mut SessionContext,
) -> Result<(), InvocationError> {
self.functions
.get(function)
.ok_or_else(|| format!("\"{}\" not found", function))?
.call(argument)
}

/// Load a plugin library and add all contained functions to the internal
/// function table.
///
/// # Safety
///
/// A plugin library **must** be implemented using the
/// [`plugins_core::plugin_declaration!()`] macro. Trying manually implement
/// a plugin without going through that macro will result in undefined
/// behaviour.
pub unsafe fn load<P: AsRef<OsStr>>(&mut self, library_path: P) -> io::Result<()> {
// load the library into memory
let library = Rc::new(
Library::new(library_path)
.map_err(|e| io::Error::new(ErrorKind::Other, e.to_string()))?,
);

// get a pointer to the plugin_declaration symbol.
let decl = library
.get::<*mut PluginDeclaration>(b"plugin_declaration\0")
.map_err(|e| io::Error::new(ErrorKind::Other, e.to_string()))?
.read();

// version checks to prevent accidental ABI incompatibilities
if decl.rustc_version != datafusion_plugin_core::RUSTC_VERSION
|| decl.core_version != datafusion_plugin_core::CORE_VERSION
{
return Err(io::Error::new(io::ErrorKind::Other, "Version mismatch"));
}

let mut registrar = PluginRegistrar::new(Rc::clone(&library));

(decl.register)(&mut registrar);

// add all loaded plugins to the functions map
self.functions.extend(registrar.functions);
// and make sure ExternalFunctions keeps a reference to the library
self.libraries.push(library);

Ok(())
}
}

struct PluginRegistrar {
functions: HashMap<String, FunctionProxy>,
lib: Rc<Library>,
}

impl PluginRegistrar {
fn new(lib: Rc<Library>) -> PluginRegistrar {
PluginRegistrar {
lib,
functions: HashMap::default(),
}
}
}

impl datafusion_plugin_core::PluginRegistrar for PluginRegistrar {
fn register_function(&mut self, name: &str, function: Box<dyn Function>) {
let proxy = FunctionProxy {
function,
_lib: Rc::clone(&self.lib),
};
self.functions.insert(name.to_string(), proxy);
}
}

/// A proxy object which wraps a [`Function`] and makes sure it can't outlive
/// the library it came from.
pub struct FunctionProxy {
function: Box<dyn Function>,
_lib: Rc<Library>,
}

impl Function for FunctionProxy {
fn call(&self, context: &mut SessionContext) -> Result<(), InvocationError> {
self.function.call(context)
}

fn help(&self) -> Option<&str> {
self.function.help()
}
}
6 changes: 2 additions & 4 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,8 @@
// specific language governing permissions and limitations
// under the License.

#[cfg(feature = "mimalloc")]
use mimalloc::MiMalloc;
use pyo3::prelude::*;
use std::alloc::System;

// Re-export Apache Arrow DataFusion dependencies
pub use datafusion;
Expand Down Expand Up @@ -59,9 +58,8 @@ mod udaf;
mod udf;
pub mod utils;

#[cfg(feature = "mimalloc")]
#[global_allocator]
static GLOBAL: MiMalloc = MiMalloc;
static ALLOCATOR: System = System;

// Used to define Tokio Runtime as a Python module attribute
#[pyclass]
Expand Down

0 comments on commit 88855de

Please sign in to comment.