Skip to content

Commit

Permalink
support for docuemnt ids
Browse files Browse the repository at this point in the history
  • Loading branch information
rustonaut committed Nov 16, 2023
1 parent dcd3385 commit ddb9c80
Show file tree
Hide file tree
Showing 2 changed files with 104 additions and 4 deletions.
98 changes: 95 additions & 3 deletions web-api/src/personalization/filter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ use serde_json::{json, Number, Value};

use crate::{
error::common::InvalidDocumentProperty,
models::{DocumentProperty, DocumentPropertyId},
models::{DocumentId, DocumentProperty, DocumentPropertyId},
storage::property_filter::IndexedPropertiesSchema,
};

Expand Down Expand Up @@ -350,16 +350,59 @@ impl<'de> Deserialize<'de> for Combine {
}
}

#[derive(Clone, Debug, PartialEq)]
pub(crate) struct Ids {
pub(crate) ids: Vec<DocumentId>,
}

impl<'de> Deserialize<'de> for Ids {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
#[derive(Deserialize)]
struct SerdeProxy {
#[serde(rename = "$ids")]
pub(crate) ids: Vec<String>,
}

let SerdeProxy { ids } = SerdeProxy::deserialize(deserializer)?;

let len = ids.len();
let max = CompareOp::MAX_VALUES_PER_IN;
if !(1..=max).contains(&len) {
return Err(D::Error::invalid_length(
len,
&format!("$ids must contain 1..={max} ids").as_str(),
));
}

// HINT: (de-)serialization doesn't check id validity as it can't differentiate between from-db
// deserialization and query filter deserialization so we need to run a post validation step
let ids = ids
.into_iter()
.map(|id| {
DocumentId::new(&id).map_err(|_| {
D::Error::invalid_value(Unexpected::Str(&id), &"a valid DocumentId")
})
})
.try_collect()?;

Ok(Self { ids })
}
}

#[derive(Clone, Debug, PartialEq)]
pub(crate) enum Filter {
Compare(Compare),
Combine(Combine),
Ids(Ids),
}

impl Filter {
fn is_below_depth(&self, max: usize) -> bool {
match self {
Self::Compare(_) => true,
Self::Compare(_) | Self::Ids(_) => true,
Self::Combine(combine) => combine.filters.is_below_depth(max),
}
}
Expand Down Expand Up @@ -388,6 +431,10 @@ impl Filter {
operation: CombineOp::And,
filters: Filters(vec![compare, published_after]),
}),
ids @ Self::Ids(_) => Self::Combine(Combine {
operation: CombineOp::And,
filters: Filters(vec![ids, published_after]),
}),
Self::Combine(mut combine) if matches!(combine.operation, CombineOp::And) => {
combine.filters.push(published_after);
Self::Combine(combine)
Expand All @@ -412,6 +459,7 @@ impl Filter {
schema: &IndexedPropertiesSchema,
) -> Result<(), InvalidDocumentProperty> {
match self {
Self::Ids(_) => Ok(()),
Self::Compare(compare) => schema.validate_filter(&compare.field, &compare.value),
Self::Combine(combine) => combine
.filters
Expand All @@ -433,6 +481,7 @@ impl<'de> Deserialize<'de> for Filter {

let filter = Content::deserialize(deserializer)?;
let deserializer = ContentRefDeserializer::<D::Error>::new(&filter);

let compare = match Compare::deserialize(deserializer) {
Ok(compare) => return Ok(Filter::Compare(compare)),
Err(error) => error,
Expand All @@ -441,9 +490,13 @@ impl<'de> Deserialize<'de> for Filter {
Ok(combine) => return Ok(Filter::Combine(combine)),
Err(error) => error,
};
let ids = match Ids::deserialize(deserializer) {
Ok(ids) => return Ok(Filter::Ids(ids)),
Err(error) => error,
};

Err(D::Error::custom(format!(
"invalid variant, expected one of: Compare({compare}); Combine({combine})",
"invalid variant, expected one of: Compare({compare}); Combine({combine}); Ids({ids})",
)))
}
}
Expand Down Expand Up @@ -1063,4 +1116,43 @@ mod tests {
};
assert_eq!(filter.validate(&schema).unwrap_err(), error);
}

#[test]
fn test_ids_filter_can_be_provided() {
let filter = serde_json::from_value::<Filter>(json!({
"$ids": ["foo", "bar", "baz"]
}))
.unwrap();

assert_eq!(
filter,
Filter::Ids(Ids {
ids: vec![
"foo".parse().unwrap(),
"bar".parse().unwrap(),
"baz".parse().unwrap(),
]
})
);

serde_json::from_value::<Filter>(json!({
"$ids": ["$$$$"]
}))
.unwrap_err();

serde_json::from_value::<Filter>(json!({
"$ids": []
}))
.unwrap_err();

serde_json::from_value::<Filter>(json!({
"$ids": vec!["foo"; CompareOp::MAX_VALUES_PER_IN]
}))
.unwrap();

serde_json::from_value::<Filter>(json!({
"$ids": vec!["foo"; CompareOp::MAX_VALUES_PER_IN + 1]
}))
.unwrap_err();
}
}
10 changes: 9 additions & 1 deletion web-api/src/storage/elastic/filter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,7 @@ enum Clause<'a> {
Range(Range<'a>),
Filter(Filter<'a>),
Should(Should<'a>),
Ids(Ids<'a, DocumentId>),
}

fn merge_range_and(mut clause: Vec<Clause<'_>>) -> Vec<Clause<'_>> {
Expand Down Expand Up @@ -504,6 +505,10 @@ impl<'a> Clause<'a> {
}),
}
}

filter::Filter::Ids(ids) => Self::Ids(Ids {
values: Cow::Borrowed(&ids.ids),
}),
}
}
}
Expand Down Expand Up @@ -540,7 +545,10 @@ impl<'a> Clauses<'a> {
};
if let Some(filter) = filter {
match Clause::new(filter, true) {
clause @ (Clause::Term(_) | Clause::Terms(_) | Clause::Range(_)) => {
clause @ (Clause::Term(_)
| Clause::Terms(_)
| Clause::Range(_)
| Clause::Ids(_)) => {
clauses.filter.filter.push(clause);
}
Clause::Filter(clause) => clauses.filter = clause,
Expand Down

0 comments on commit ddb9c80

Please sign in to comment.