Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Window/Aggregate function implementation -- please review #14

Open
wants to merge 14 commits into
base: main
Choose a base branch
from
Open
8 changes: 8 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -47,3 +47,11 @@ crate-type = ["cdylib"]
[[example]]
name = "load_permanent"
crate-type = ["cdylib"]

[[example]]
name = "sum_int"
crate-type = ["cdylib"]

[[example]]
name = "sum_int_aux"
crate-type = ["cdylib"]
8 changes: 8 additions & 0 deletions Dockerfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
FROM debian:bullseye-slim

RUN apt-get update && apt-get install -y curl valgrind build-essential clang
# Install Rust
ENV RUST_VERSION=stable
RUN curl https://sh.rustup.rs -sSf | sh -s -- -y --default-toolchain=$RUST_VERSION
# Install cargo-valgrind
RUN /bin/bash -c "source /root/.cargo/env && cargo install cargo-valgrind"
10 changes: 9 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,14 @@ select * from xxx;

Some real-world non-Rust examples of traditional virtual tables in SQLite include the [CSV virtual table](https://www.sqlite.org/csv.html), the full-text search [fts5 extension](https://www.sqlite.org/fts5.html#fts5_table_creation_and_initialization), and the [R-Tree extension](https://www.sqlite.org/rtree.html#creating_an_r_tree_index).


### Window / Aggregate functions

A window function can be defined using the `define_window_function`. The step and final function must be defined. See the [`sum_int.rs`](./examples/sum_int.rs) implementation and the [sqlite's own example](https://www.sqlite.org/windowfunctions.html#user_defined_aggregate_window_functions) for a full solution.

There is also a [`define_window_function_with_aux`](./src/window.rs), in case a mutable auxillary object is required in place of the context aggregate pointer provided by sqlite3. In this case, the object being passed is not required to implement the Copy trait.


## Examples

The [`examples/`](./examples/) directory has a few bare-bones examples of extensions, which you can build with:
Expand Down Expand Up @@ -241,7 +249,7 @@ A hello world extension in C is `17KB`, while one in Rust is `469k`. It's still

- [ ] Stabilize scalar function interface
- [ ] Stabilize virtual table interface
- [ ] Support [aggregate window functions](https://www.sqlite.org/windowfunctions.html#udfwinfunc) ([#1](https://github.com/asg017/sqlite-loadable-rs/issues/1))
- [x] Support [aggregate window functions](https://www.sqlite.org/windowfunctions.html#udfwinfunc) ([#1](https://github.com/asg017/sqlite-loadable-rs/issues/1))
- [ ] Support [collating sequences](https://www.sqlite.org/c3ref/create_collation.html) ([#2](https://github.com/asg017/sqlite-loadable-rs/issues/2))
- [ ] Support [virtual file systems](sqlite.org/vfs.html) ([#3](https://github.com/asg017/sqlite-loadable-rs/issues/3))

Expand Down
50 changes: 50 additions & 0 deletions examples/sum_int.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
//! cargo build --example sum_int
//! sqlite3 :memory: '.read examples/test.sql'

use libsqlite3_sys::sqlite3_int64;
use sqlite_loadable::prelude::*;
use sqlite_loadable::window::define_window_function;
use sqlite_loadable::{api, Result};

/// Example inspired by sqlite3's sumint
/// https://www.sqlite.org/windowfunctions.html#user_defined_aggregate_window_functions
pub fn x_step(context: *mut sqlite3_context, values: &[*mut sqlite3_value]) -> Result<()> {
assert!(values.len() == 1);
let new_value = api::value_int64(values.get(0).expect("should be one"));
let previous_value = api::get_aggregate_context_value::<sqlite3_int64>(context)?;
api::set_aggregate_context_value::<sqlite3_int64>(context, previous_value + new_value)?;
Ok(())
}


pub fn x_final(context: *mut sqlite3_context) -> Result<()> {
let value = api::get_aggregate_context_value::<sqlite3_int64>(context)?;
api::result_int64(context, value);
Ok(())
}


pub fn x_value(context: *mut sqlite3_context) -> Result<()> {
let value = api::get_aggregate_context_value::<sqlite3_int64>(context)?;
api::result_int64(context, value);
Ok(())
}

pub fn x_inverse(context: *mut sqlite3_context, values: &[*mut sqlite3_value]) -> Result<()> {
assert!(values.len() == 1);
let new_value = api::value_int64(values.get(0).expect("should be one"));
let previous_value = api::get_aggregate_context_value::<sqlite3_int64>(context)?;
api::set_aggregate_context_value::<sqlite3_int64>(context, previous_value - new_value)?;
Ok(())
}

#[sqlite_entrypoint]
pub fn sqlite3_sumint_init(db: *mut sqlite3) -> Result<()> {
let flags = FunctionFlags::UTF8 | FunctionFlags::DETERMINISTIC;
define_window_function(
db, "sumint", -1, flags,
x_step, x_final, Some(x_value), Some(x_inverse),
)?;
Ok(())
}

44 changes: 44 additions & 0 deletions examples/sum_int_aux.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
//! cargo build --example sum_int_aux

use sqlite_loadable::prelude::*;
use sqlite_loadable::window::define_window_function_with_aux;
use sqlite_loadable::{api, Result};

/// Example inspired by sqlite3's sumint
/// https://www.sqlite.org/windowfunctions.html#user_defined_aggregate_window_functions
pub fn x_step(_context: *mut sqlite3_context, values: &[*mut sqlite3_value], i: &mut i64) -> Result<()> {
assert!(values.len() == 1);
let new_value = api::value_int64(values.get(0).expect("should be one"));
*i = *i + new_value;
Ok(())
}


pub fn x_final(context: *mut sqlite3_context, i: &mut i64) -> Result<()> {
api::result_int64(context, *i);
Ok(())
}


pub fn x_value(context: *mut sqlite3_context, i: &mut i64) -> Result<()> {
api::result_int64(context, *i);
Ok(())
}

pub fn x_inverse(_context: *mut sqlite3_context, values: &[*mut sqlite3_value], i: &mut i64) -> Result<()> {
assert!(values.len() == 1);
let new_value = api::value_int64(values.get(0).expect("should be one"));
*i = *i - new_value;
Ok(())
}

#[sqlite_entrypoint]
pub fn sqlite3_sumintaux_init(db: *mut sqlite3) -> Result<()> {
let flags = FunctionFlags::UTF8 | FunctionFlags::DETERMINISTIC;
define_window_function_with_aux::<i64>(
db, "sumint_aux", -1, flags,
x_step, x_final, Some(x_value), Some(x_inverse),
0,
)?;
Ok(())
}
8 changes: 8 additions & 0 deletions run-docker.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
#!/bin/sh
NAME="io_uring:1.0"
docker image inspect "$NAME" || docker build -t "$NAME" .
docker run -it -v $PWD:/tmp -w /tmp $NAME

# see https://github.com/jfrimmel/cargo-valgrind/pull/58/commits/1c168f296e0b3daa50279c642dd37aecbd85c5ff#L59
# scan for double frees and leaks
# VALGRINDFLAGS="--leak-check=yes --trace-children=yes" cargo valgrind test
6 changes: 3 additions & 3 deletions sqlite-loadable-macros/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

39 changes: 39 additions & 0 deletions src/api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ use crate::ext::{
};
use crate::Error;
use sqlite3ext_sys::{SQLITE_BLOB, SQLITE_FLOAT, SQLITE_INTEGER, SQLITE_NULL, SQLITE_TEXT};
use sqlite3ext_sys::sqlite3_aggregate_context;
use std::os::raw::c_int;
use std::slice::from_raw_parts;
use std::str::Utf8Error;
Expand Down Expand Up @@ -568,3 +569,41 @@ impl ExtendedColumnAffinity {
ExtendedColumnAffinity::Numeric
}
}

// TODO write test
pub fn get_aggregate_context_value<T>(context: *mut sqlite3_context) -> Result<T, String>
where
T: Copy,
{
let p_value: *mut T = unsafe {
sqlite3_aggregate_context(context, std::mem::size_of::<T>() as i32) as *mut T
};

if p_value.is_null() {
return Err("sqlite3_aggregate_context returned a null pointer.".to_string());
}

let value: T = unsafe { *p_value };

Ok(value)
}

// TODO write test
pub fn set_aggregate_context_value<T>(context: *mut sqlite3_context, value: T) -> Result<(), String>
where
T: Copy,
{
let p_value: *mut T = unsafe {
sqlite3_aggregate_context(context, std::mem::size_of::<T>() as i32) as *mut T
};

if p_value.is_null() {
return Err("sqlite3_aggregate_context returned a null pointer.".to_string());
}

unsafe {
*p_value = value;
}

Ok(())
}
27 changes: 27 additions & 0 deletions src/bit_flags.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
use bitflags::bitflags;

use sqlite3ext_sys::{
SQLITE_DETERMINISTIC, SQLITE_DIRECTONLY, SQLITE_INNOCUOUS, SQLITE_SUBTYPE, SQLITE_UTF16,
SQLITE_UTF16BE, SQLITE_UTF16LE, SQLITE_UTF8,
};

bitflags! {
/// Represents the possible flag values that can be passed into sqlite3_create_function_v2
/// or sqlite3_create_window_function, as the 4th "eTextRep" parameter.
/// Includes both the encoding options (utf8, utf16, etc.) and function-level parameters
/// (deterministion, innocuous, etc.).
pub struct FunctionFlags: i32 {
const UTF8 = SQLITE_UTF8 as i32;
const UTF16LE = SQLITE_UTF16LE as i32;
const UTF16BE = SQLITE_UTF16BE as i32;
const UTF16 = SQLITE_UTF16 as i32;

/// "... to signal that the function will always return the same result given the same
/// inputs within a single SQL statement."
/// <https://www.sqlite.org/c3ref/create_function.html#:~:text=ORed%20with%20SQLITE_DETERMINISTIC>
const DETERMINISTIC = SQLITE_DETERMINISTIC as i32;
const DIRECTONLY = SQLITE_DIRECTONLY as i32;
const SUBTYPE = SQLITE_SUBTYPE as i32;
const INNOCUOUS = SQLITE_INNOCUOUS as i32;
}
}
2 changes: 2 additions & 0 deletions src/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ impl Error {
ErrorKind::CStringUtf8Error(_) => "utf8 err".to_owned(),
ErrorKind::Message(msg) => msg,
ErrorKind::TableFunction(_) => "table func error".to_owned(),
ErrorKind::DefineWindowFunction(_) => "Error defining window function".to_owned(),
}
}
}
Expand All @@ -53,6 +54,7 @@ impl Error {
#[derive(Debug, PartialEq, Eq)]
pub enum ErrorKind {
DefineScalarFunction(c_int),
DefineWindowFunction(c_int),
CStringError(NulError),
CStringUtf8Error(std::str::Utf8Error),
TableFunction(c_int),
Expand Down
38 changes: 37 additions & 1 deletion src/ext.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ pub use sqlite3ext_sys::{
sqlite3, sqlite3_api_routines, sqlite3_context, sqlite3_index_info,
sqlite3_index_info_sqlite3_index_constraint, sqlite3_index_info_sqlite3_index_constraint_usage,
sqlite3_index_info_sqlite3_index_orderby, sqlite3_module, sqlite3_stmt, sqlite3_value,
sqlite3_vtab, sqlite3_vtab_cursor,
sqlite3_vtab, sqlite3_vtab_cursor, sqlite3_create_window_function
};

/// If creating a dynmically loadable extension, this MUST be redefined to point
Expand Down Expand Up @@ -413,6 +413,42 @@ pub unsafe fn sqlite3ext_get_auxdata(context: *mut sqlite3_context, n: c_int) ->
((*SQLITE3_API).get_auxdata.expect(EXPECT_MESSAGE))(context, n)
}

#[cfg(feature = "static")]
pub unsafe fn sqlite3ext_create_window_function(
db: *mut sqlite3,
s: *const c_char,
argc: i32,
text_rep: i32,
p_app: *mut c_void,
x_step: Option<unsafe extern "C" fn(*mut sqlite3_context, i32, *mut *mut sqlite3_value)>,
x_final: Option<unsafe extern "C" fn(*mut sqlite3_context)>,
x_value: Option<unsafe extern "C" fn(*mut sqlite3_context)>,
x_inverse: Option<unsafe extern "C" fn(*mut sqlite3_context, i32, *mut *mut sqlite3_value)>,
destroy: Option<unsafe extern "C" fn(*mut c_void)>
) -> c_int {
sqlite3_create_window_function(
db, s, argc, text_rep, p_app, x_step, x_final, x_value, x_inverse, destroy,
)
}

#[cfg(not(feature = "static"))]
pub unsafe fn sqlite3ext_create_window_function(
db: *mut sqlite3,
s: *const c_char,
argc: i32,
text_rep: i32,
p_app: *mut c_void,
x_step: Option<unsafe extern "C" fn(*mut sqlite3_context, i32, *mut *mut sqlite3_value)>,
x_final: Option<unsafe extern "C" fn(*mut sqlite3_context)>,
x_value: Option<unsafe extern "C" fn(*mut sqlite3_context)>,
x_inverse: Option<unsafe extern "C" fn(*mut sqlite3_context, i32, *mut *mut sqlite3_value)>,
destroy: Option<unsafe extern "C" fn(*mut c_void)>
) -> c_int {
((*SQLITE3_API).create_window_function.expect(EXPECT_MESSAGE))(
db, s, argc, text_rep, p_app, x_step, x_final, x_value, x_inverse, destroy,
)
}

#[cfg(feature = "static")]
pub unsafe fn sqlite3ext_create_function_v2(
db: *mut sqlite3,
Expand Down
10 changes: 9 additions & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,20 @@ pub mod prelude;
pub mod scalar;
pub mod table;
pub mod vtab_argparse;
pub mod window;
pub mod bit_flags;

#[doc(inline)]
pub use errors::{Error, ErrorKind, Result};

#[doc(inline)]
pub use scalar::{define_scalar_function, define_scalar_function_with_aux, FunctionFlags};
pub use bit_flags::FunctionFlags;

#[doc(inline)]
pub use scalar::{define_scalar_function, define_scalar_function_with_aux};

#[doc(inline)]
pub use window::{define_window_function,define_window_function_with_aux};

#[doc(inline)]
pub use collation::define_collation;
Expand Down
Loading