Skip to content

Commit

Permalink
improving the solution (now it works)
Browse files Browse the repository at this point in the history
  • Loading branch information
angelip2303 committed Jun 29, 2023
1 parent fa719a5 commit be3c397
Show file tree
Hide file tree
Showing 5 changed files with 35 additions and 83 deletions.
2 changes: 1 addition & 1 deletion examples/wikidata_dump.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ fn main() -> Result<(), String> {
));

// Load Wikidata entities
let edges = match DuckDB::import("wikidata-20170821-all.duckdb") {
let edges = match DuckDB::import("250_000-lines.duckdb") {
Ok(edges) => edges,
Err(_) => return Err(String::from("Error creating the edges :(")),
};
Expand Down
9 changes: 6 additions & 3 deletions src/backends/ntriples.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use std::io::BufWriter;
use std::{fs::File, io::BufReader};

use polars::df;
use polars::enable_string_cache;
use polars::prelude::*;
use pregel_rs::pregel::Column;
use rio_api::formatter::TriplesFormatter;
Expand All @@ -17,6 +18,8 @@ pub struct NTriples;

impl Backend for NTriples {
fn import(path: &str) -> Result<DataFrame, String> {
enable_string_cache(true);

let mut subjects = Vec::<String>::new();
let mut predicates = Vec::<String>::new();
let mut objects = Vec::<String>::new();
Expand Down Expand Up @@ -44,9 +47,9 @@ impl Backend for NTriples {
}

match df![
Column::Subject.as_ref() => Series::new(Column::Subject.as_ref(), subjects),
Column::Predicate.as_ref() => Series::new(Column::Subject.as_ref(), predicates),
Column::Object.as_ref() => Series::new(Column::Subject.as_ref(), objects),
Column::Subject.as_ref() => Series::new(Column::Subject.as_ref(), subjects).cast(&DataType::Categorical(None)).unwrap(),
Column::Predicate.as_ref() => Series::new(Column::Predicate.as_ref(), predicates).cast(&DataType::Categorical(None)).unwrap(),
Column::Object.as_ref() => Series::new(Column::Object.as_ref(), objects).cast(&DataType::Categorical(None)).unwrap(),
] {
Ok(edges) => Ok(edges),
Err(_) => Err(String::from("Error creating the edges DataFrame")),
Expand Down
75 changes: 17 additions & 58 deletions src/pschema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,10 +77,6 @@ impl<T: Literal + Clone> PSchema<T> {
// the latter is used during the phase where the vertices are updated.
let start = self.start;
let mut send_messages_iter = ShapeTree::new(start.to_owned()).into_iter(); // iterator to send messages
let mut v_prog_iter = ShapeTree::new(start.to_owned()).into_iter(); // iterator to update vertices
v_prog_iter.next(); // skip the leaf nodes :D
// Then, we can define the algorithm that will be executed on the graph. The algorithm
// will be executed in parallel on all vertices of the graph.
let pregel = PregelBuilder::new(graph.to_owned())
.max_iterations(ShapeTree::new(start).iterations())
.with_vertex_column(Column::Custom("labels"))
Expand All @@ -89,14 +85,14 @@ impl<T: Literal + Clone> PSchema<T> {
Self::send_messages(send_messages_iter.by_ref())
})
.aggregate_messages_function(Self::aggregate_messages)
.v_prog_function(|| Self::v_prog(v_prog_iter.by_ref()))
.v_prog_function(Self::v_prog)
.build();
// Finally, we can run the algorithm and get the result. The result is a DataFrame
// containing the labels of the vertices.
match pregel.run() {
Ok(result) => result
.lazy()
.select(&[
.select([
col(Column::VertexId.as_ref()),
col(Column::Custom("labels").as_ref())
.explode()
Expand All @@ -109,7 +105,7 @@ impl<T: Literal + Clone> PSchema<T> {
col(Column::Custom("labels").as_ref())
.list()
.lengths()
.gt(lit(0)),
.gt(0),
)
.left_join(
graph.edges.lazy(),
Expand All @@ -127,12 +123,6 @@ impl<T: Literal + Clone> PSchema<T> {
}
}

/// Builds the expression used as the initial message.
///
/// Returns:
///
/// An `Expr` wrapping the `NULL` literal (`lit(NULL)`); the function takes no
/// arguments and performs no other work.
///
/// NOTE(review): presumably this is handed to the Pregel builder as the
/// message vertices start with — confirm against the builder call site.
fn initial_message() -> Expr {
lit(NULL)
}
Expand All @@ -144,19 +134,13 @@ impl<T: Literal + Clone> PSchema<T> {
ans = match node {
Shape::TripleConstraint(shape) => shape.validate(ans),
Shape::ShapeReference(shape) => shape.validate(ans),
Shape::ShapeAnd(_) => ans,
Shape::ShapeOr(_) => ans,
Shape::Cardinality(_) => ans,
Shape::ShapeAnd(shape) => shape.validate(ans),
Shape::ShapeOr(shape) => shape.validate(ans),
Shape::Cardinality(shape) => shape.validate(ans),
}
}
}
match concat_list([
Column::subject(Column::Custom("labels")),
ans.cast(DataType::Categorical(None)),
]) {
Ok(concat) => concat,
Err(_) => Column::subject(Column::Custom("labels")),
}
ans
}

/// The function returns an expression that aggregates messages by exploding a
Expand All @@ -169,35 +153,11 @@ impl<T: Literal + Clone> PSchema<T> {
/// element in the column), and drops any rows that have NULL values in the
/// resulting column.
fn aggregate_messages() -> Expr {
Column::msg(None).explode()
Column::msg(None).filter(Column::msg(None).is_not_null())
}

/// The function takes a shape iterator, validates the shapes in it, concatenates
/// the validation results, and returns a unique array.
///
/// Arguments:
///
/// * `iterator`: The `iterator` parameter is a mutable reference to a
/// `ShapeIterator`. It is used to iterate over a collection of `WShape` nodes.
///
/// Returns:
///
/// The function `v_prog` returns an `Expr` which is the result of calling the
/// `unique` method on an array created from the `ans` variable.
fn v_prog(iterator: &mut dyn Iterator<Item = ShapeTreeItem<T>>) -> Expr {
let mut ans = Column::msg(None);
if let Some(nodes) = iterator.next() {
for node in nodes {
ans = match node {
Shape::TripleConstraint(_) => ans,
Shape::ShapeReference(_) => ans,
Shape::ShapeAnd(shape) => shape.validate(ans),
Shape::ShapeOr(shape) => shape.validate(ans),
Shape::Cardinality(shape) => shape.validate(ans),
}
}
}
ans
fn v_prog() -> Expr {
Column::msg(None)
}
}

Expand Down Expand Up @@ -249,7 +209,6 @@ mod tests {
assert(expected, actual)
}
Err(error) => {
println!("asd");
println!("{}", error);
Err(error.to_string())
}
Expand All @@ -263,27 +222,27 @@ mod tests {

#[test]
fn paper_test() -> Result<(), String> {
test(paper_graph(), vec![4u32, 1u32], paper_schema())
test(paper_graph(), vec![1u32], paper_schema())
}

#[test]
fn complex_test() -> Result<(), String> {
test(paper_graph(), vec![4u32, 1u32, 1u32], complex_schema())
test(paper_graph(), vec![1u32], complex_schema())
}

#[test]
fn reference_test() -> Result<(), String> {
test(paper_graph(), vec![2u32, 1u32, 1u32], reference_schema())
test(paper_graph(), vec![1u32], reference_schema())
}

#[test]
fn optional_test() -> Result<(), String> {
test(paper_graph(), vec![3u32, 1u32, 1u32], optional_schema())
test(paper_graph(), vec![1u32, 1u32], optional_schema())
}

#[test]
fn conditional_test() -> Result<(), String> {
test(paper_graph(), vec![2u32, 2u32, 2u32], conditional_schema())
test(paper_graph(), vec![1u32, 1u32, 1u32], conditional_schema())
}

#[test]
Expand All @@ -293,12 +252,12 @@ mod tests {

#[test]
fn cardinality_test() -> Result<(), String> {
test(paper_graph(), vec![3u32, 1u32], cardinality_schema())
test(paper_graph(), vec![1u32, 1u32], cardinality_schema())
}

#[test]
fn vprog_to_vprog_test() -> Result<(), String> {
test(paper_graph(), vec![3u32], vprog_to_vprog())
test(paper_graph(), vec![1u32], vprog_to_vprog())
}

#[test]
Expand Down
6 changes: 3 additions & 3 deletions src/shape/shape_tree.rs
Original file line number Diff line number Diff line change
Expand Up @@ -116,9 +116,9 @@ impl<T: Literal + Clone> ShapeTree<T> {
match shape {
Shape::TripleConstraint(_) => continue,
Shape::ShapeReference(_) => continue,
Shape::ShapeAnd(_) => return true,
Shape::ShapeOr(_) => return true,
Shape::Cardinality(_) => return true,
Shape::ShapeAnd(_) => continue,
Shape::ShapeOr(_) => continue,
Shape::Cardinality(_) => continue,
};
}
}
Expand Down
26 changes: 8 additions & 18 deletions src/shape/shex.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
use polars::lazy::dsl::concat_list;
use polars::prelude::*;
use pregel_rs::pregel::Column;
use pregel_rs::pregel::Column::{Custom, Object, Predicate};
Expand Down Expand Up @@ -391,12 +390,9 @@ impl<T: Literal + Clone> Validate for ShapeAnd<T> {
/// and `prev` using the `
fn validate(self, prev: Expr) -> Expr {
when(self.shapes.iter().fold(lit(true), |acc, shape| {
acc.and(lit(shape.get_label()).is_in(Column::msg(None)))
acc.and(lit(shape.get_label()).is_in(Column::subject(Column::Custom("labels"))))
}))
.then(match concat_list([lit(self.label), prev.to_owned()]) {
Ok(concat) => concat,
Err(_) => prev.to_owned(),
})
.then(lit(self.label))
.otherwise(prev)
}
}
Expand All @@ -422,12 +418,9 @@ impl<T: Literal + Clone> From<ShapeOr<T>> for Shape<T> {
impl<T: Literal + Clone> Validate for ShapeOr<T> {
fn validate(self, prev: Expr) -> Expr {
when(self.shapes.iter().fold(lit(false), |acc, shape| {
acc.or(lit(shape.get_label()).is_in(Column::msg(None)))
acc.or(lit(shape.get_label()).is_in(Column::subject(Column::Custom("labels"))))
}))
.then(match concat_list([lit(self.label), prev.to_owned()]) {
Ok(concat) => concat,
Err(_) => prev.to_owned(),
})
.then(lit(self.label))
.otherwise(prev)
}
}
Expand Down Expand Up @@ -492,11 +485,11 @@ impl<T: Literal + Clone> Validate for Cardinality<T> {
///
/// The `validate` function is returning an `Expr` object.
fn validate(self, prev: Expr) -> Expr {
let count = Column::msg(None)
let count = Column::subject(Column::Custom("labels"))
.list()
.eval(col("").eq(lit(self.shape.get_label())).cumsum(false), true)
.eval(col("").eq(lit(self.shape.get_label())), true)
.list()
.first();
.sum();

when(
match self.min {
Expand All @@ -512,10 +505,7 @@ impl<T: Literal + Clone> Validate for Cardinality<T> {
Bound::Many => count.lt_eq(lit(u8::MAX)),
}),
)
.then(match concat_list([lit(self.label), prev.to_owned()]) {
Ok(concat) => concat,
Err(_) => prev.to_owned(),
})
.then(lit(self.label))
.otherwise(prev)
}
}
Expand Down

0 comments on commit be3c397

Please sign in to comment.