diff --git a/examples/wikidata_dump.rs b/examples/wikidata_dump.rs index 77783c2..ac7de44 100644 --- a/examples/wikidata_dump.rs +++ b/examples/wikidata_dump.rs @@ -32,7 +32,7 @@ fn main() -> Result<(), String> { )); // Load Wikidata entities - let edges = match DuckDB::import("wikidata-20170821-all.duckdb") { + let edges = match DuckDB::import("250_000-lines.duckdb") { Ok(edges) => edges, Err(_) => return Err(String::from("Error creating the edges :(")), }; diff --git a/src/backends/ntriples.rs b/src/backends/ntriples.rs index 9b33b4d..2620058 100644 --- a/src/backends/ntriples.rs +++ b/src/backends/ntriples.rs @@ -2,6 +2,7 @@ use std::io::BufWriter; use std::{fs::File, io::BufReader}; use polars::df; +use polars::enable_string_cache; use polars::prelude::*; use pregel_rs::pregel::Column; use rio_api::formatter::TriplesFormatter; @@ -17,6 +18,8 @@ pub struct NTriples; impl Backend for NTriples { fn import(path: &str) -> Result { + enable_string_cache(true); + let mut subjects = Vec::::new(); let mut predicates = Vec::::new(); let mut objects = Vec::::new(); @@ -44,9 +47,9 @@ impl Backend for NTriples { } match df![ - Column::Subject.as_ref() => Series::new(Column::Subject.as_ref(), subjects), - Column::Predicate.as_ref() => Series::new(Column::Subject.as_ref(), predicates), - Column::Object.as_ref() => Series::new(Column::Subject.as_ref(), objects), + Column::Subject.as_ref() => Series::new(Column::Subject.as_ref(), subjects).cast(&DataType::Categorical(None)).unwrap(), + Column::Predicate.as_ref() => Series::new(Column::Predicate.as_ref(), predicates).cast(&DataType::Categorical(None)).unwrap(), + Column::Object.as_ref() => Series::new(Column::Object.as_ref(), objects).cast(&DataType::Categorical(None)).unwrap(), ] { Ok(edges) => Ok(edges), Err(_) => Err(String::from("Error creating the edges DataFrame")), diff --git a/src/pschema.rs b/src/pschema.rs index 8c2e900..4c08859 100644 --- a/src/pschema.rs +++ b/src/pschema.rs @@ -77,10 +77,6 @@ impl PSchema { // the latter is used during the phase where the vertices are updated. let start = self.start; let mut send_messages_iter = ShapeTree::new(start.to_owned()).into_iter(); // iterator to send messages - let mut v_prog_iter = ShapeTree::new(start.to_owned()).into_iter(); // iterator to update vertices - v_prog_iter.next(); // skip the leaf nodes :D - // Then, we can define the algorithm that will be executed on the graph. The algorithm - // will be executed in parallel on all vertices of the graph. let pregel = PregelBuilder::new(graph.to_owned()) .max_iterations(ShapeTree::new(start).iterations()) .with_vertex_column(Column::Custom("labels")) @@ -89,14 +85,14 @@ impl PSchema { Self::send_messages(send_messages_iter.by_ref()) }) .aggregate_messages_function(Self::aggregate_messages) - .v_prog_function(|| Self::v_prog(v_prog_iter.by_ref())) + .v_prog_function(Self::v_prog) .build(); // Finally, we can run the algorithm and get the result. The result is a DataFrame // containing the labels of the vertices. match pregel.run() { Ok(result) => result .lazy() - .select(&[ + .select([ col(Column::VertexId.as_ref()), col(Column::Custom("labels").as_ref()) .explode() @@ -109,7 +105,7 @@ impl PSchema { col(Column::Custom("labels").as_ref()) .list() .lengths() - .gt(lit(0)), + .gt(0), ) .left_join( graph.edges.lazy(), @@ -127,12 +123,6 @@ impl PSchema { } } - /// The function returns a NULL value. - /// - /// Returns: - /// - /// The function `initial_message()` is returning a NULL value, represented by the - /// `NULL` literal. fn initial_message() -> Expr { lit(NULL) } @@ -144,19 +134,13 @@ impl PSchema { ans = match node { Shape::TripleConstraint(shape) => shape.validate(ans), Shape::ShapeReference(shape) => shape.validate(ans), - Shape::ShapeAnd(_) => ans, - Shape::ShapeOr(_) => ans, - Shape::Cardinality(_) => ans, + Shape::ShapeAnd(shape) => shape.validate(ans), + Shape::ShapeOr(shape) => shape.validate(ans), + Shape::Cardinality(shape) => shape.validate(ans), } } } - match concat_list([ - Column::subject(Column::Custom("labels")), - ans.cast(DataType::Categorical(None)), - ]) { - Ok(concat) => concat, - Err(_) => Column::subject(Column::Custom("labels")), - } + ans } /// The function returns an expression that aggregates messages by exploding a @@ -169,35 +153,11 @@ impl PSchema { /// element in the column), and drops any rows that have NULL values in the /// resulting column. fn aggregate_messages() -> Expr { - Column::msg(None).explode() + Column::msg(None).filter(Column::msg(None).is_not_null()) } - /// The function takes a shape iterator, validates the shapes in it, concatenates - /// the validation results, and returns a unique array. - /// - /// Arguments: - /// - /// * `iterator`: The `iterator` parameter is a mutable reference to a - /// `ShapeIterator`. It is used to iterate over a collection of `WShape` nodes. - /// - /// Returns: - /// - /// The function `v_prog` returns an `Expr` which is the result of calling the - /// `unique` method on an array created from the `ans` variable. - fn v_prog(iterator: &mut dyn Iterator>) -> Expr { - let mut ans = Column::msg(None); - if let Some(nodes) = iterator.next() { - for node in nodes { - ans = match node { - Shape::TripleConstraint(_) => ans, - Shape::ShapeReference(_) => ans, - Shape::ShapeAnd(shape) => shape.validate(ans), - Shape::ShapeOr(shape) => shape.validate(ans), - Shape::Cardinality(shape) => shape.validate(ans), - } - } - } - ans + fn v_prog() -> Expr { + Column::msg(None) } } @@ -249,7 +209,6 @@ mod tests { assert(expected, actual) } Err(error) => { - println!("asd"); println!("{}", error); Err(error.to_string()) } @@ -263,27 +222,27 @@ mod tests { #[test] fn paper_test() -> Result<(), String> { - test(paper_graph(), vec![4u32, 1u32], paper_schema()) + test(paper_graph(), vec![1u32], paper_schema()) } #[test] fn complex_test() -> Result<(), String> { - test(paper_graph(), vec![4u32, 1u32, 1u32], complex_schema()) + test(paper_graph(), vec![1u32], complex_schema()) } #[test] fn reference_test() -> Result<(), String> { - test(paper_graph(), vec![2u32, 1u32, 1u32], reference_schema()) + test(paper_graph(), vec![1u32], reference_schema()) } #[test] fn optional_test() -> Result<(), String> { - test(paper_graph(), vec![3u32, 1u32, 1u32], optional_schema()) + test(paper_graph(), vec![1u32, 1u32], optional_schema()) } #[test] fn conditional_test() -> Result<(), String> { - test(paper_graph(), vec![2u32, 2u32, 2u32], conditional_schema()) + test(paper_graph(), vec![1u32, 1u32, 1u32], conditional_schema()) } #[test] @@ -293,12 +252,12 @@ mod tests { #[test] fn cardinality_test() -> Result<(), String> { - test(paper_graph(), vec![3u32, 1u32], cardinality_schema()) + test(paper_graph(), vec![1u32, 1u32], cardinality_schema()) } #[test] fn vprog_to_vprog_test() -> Result<(), String> { - test(paper_graph(), vec![3u32], vprog_to_vprog()) + test(paper_graph(), vec![1u32], vprog_to_vprog()) } #[test] diff --git a/src/shape/shape_tree.rs b/src/shape/shape_tree.rs index 2a5d22a..32b8615 100644 --- a/src/shape/shape_tree.rs +++ b/src/shape/shape_tree.rs @@ -116,9 +116,9 @@ impl ShapeTree { match shape { Shape::TripleConstraint(_) => continue, Shape::ShapeReference(_) => continue, - Shape::ShapeAnd(_) => return true, - Shape::ShapeOr(_) => return true, - Shape::Cardinality(_) => return true, + Shape::ShapeAnd(_) => continue, + Shape::ShapeOr(_) => continue, + Shape::Cardinality(_) => continue, }; } } diff --git a/src/shape/shex.rs b/src/shape/shex.rs index 686bb1f..51dab31 100644 --- a/src/shape/shex.rs +++ b/src/shape/shex.rs @@ -1,4 +1,3 @@ -use polars::lazy::dsl::concat_list; use polars::prelude::*; use pregel_rs::pregel::Column; use pregel_rs::pregel::Column::{Custom, Object, Predicate}; @@ -391,12 +390,9 @@ impl Validate for ShapeAnd { /// and `prev` using the ` fn validate(self, prev: Expr) -> Expr { when(self.shapes.iter().fold(lit(true), |acc, shape| { - acc.and(lit(shape.get_label()).is_in(Column::msg(None))) + acc.and(lit(shape.get_label()).is_in(Column::subject(Column::Custom("labels")))) })) - .then(match concat_list([lit(self.label), prev.to_owned()]) { - Ok(concat) => concat, - Err(_) => prev.to_owned(), - }) + .then(lit(self.label)) .otherwise(prev) } } @@ -422,12 +418,9 @@ impl From> for Shape { impl Validate for ShapeOr { fn validate(self, prev: Expr) -> Expr { when(self.shapes.iter().fold(lit(false), |acc, shape| { - acc.or(lit(shape.get_label()).is_in(Column::msg(None))) + acc.or(lit(shape.get_label()).is_in(Column::subject(Column::Custom("labels")))) })) - .then(match concat_list([lit(self.label), prev.to_owned()]) { - Ok(concat) => concat, - Err(_) => prev.to_owned(), - }) + .then(lit(self.label)) .otherwise(prev) } } @@ -492,11 +485,11 @@ impl Validate for Cardinality { /// /// The `validate` function is returning an `Expr` object. fn validate(self, prev: Expr) -> Expr { - let count = Column::msg(None) + let count = Column::subject(Column::Custom("labels")) .list() - .eval(col("").eq(lit(self.shape.get_label())).cumsum(false), true) + .eval(col("").eq(lit(self.shape.get_label())), true) .list() - .first(); + .sum(); when( match self.min { @@ -512,10 +505,7 @@ impl Validate for Cardinality { Bound::Many => count.lt_eq(lit(u8::MAX)), }), ) - .then(match concat_list([lit(self.label), prev.to_owned()]) { - Ok(concat) => concat, - Err(_) => prev.to_owned(), - }) + .then(lit(self.label)) .otherwise(prev) } }