diff --git a/web-api/src/personalization/filter.rs b/web-api/src/personalization/filter.rs index 11737de1c..fbaa6dc6b 100644 --- a/web-api/src/personalization/filter.rs +++ b/web-api/src/personalization/filter.rs @@ -28,7 +28,7 @@ use serde_json::{json, Number, Value}; use crate::{ error::common::InvalidDocumentProperty, - models::{DocumentProperty, DocumentPropertyId}, + models::{DocumentId, DocumentProperty, DocumentPropertyId}, storage::property_filter::IndexedPropertiesSchema, }; @@ -350,16 +350,59 @@ impl<'de> Deserialize<'de> for Combine { } } +#[derive(Clone, Debug, PartialEq)] +pub(crate) struct Ids { + pub(crate) ids: Vec, +} + +impl<'de> Deserialize<'de> for Ids { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + #[derive(Deserialize)] + struct SerdeProxy { + #[serde(rename = "$ids")] + pub(crate) ids: Vec, + } + + let SerdeProxy { ids } = SerdeProxy::deserialize(deserializer)?; + + let len = ids.len(); + let max = CompareOp::MAX_VALUES_PER_IN; + if !(1..=max).contains(&len) { + return Err(D::Error::invalid_length( + len, + &format!("$ids must contain 1..={max} ids").as_str(), + )); + } + + // HINT: (de-)serialization doesn't check id validity as it can't differentiate between from-db + // deserialization and query filter deserialization so we need to run a post validation step + let ids = ids + .into_iter() + .map(|id| { + DocumentId::new(&id).map_err(|_| { + D::Error::invalid_value(Unexpected::Str(&id), &"a valid DocumentId") + }) + }) + .try_collect()?; + + Ok(Self { ids }) + } +} + #[derive(Clone, Debug, PartialEq)] pub(crate) enum Filter { Compare(Compare), Combine(Combine), + Ids(Ids), } impl Filter { fn is_below_depth(&self, max: usize) -> bool { match self { - Self::Compare(_) => true, + Self::Compare(_) | Self::Ids(_) => true, Self::Combine(combine) => combine.filters.is_below_depth(max), } } @@ -388,6 +431,10 @@ impl Filter { operation: CombineOp::And, filters: Filters(vec![compare, published_after]), }), + ids @ Self::Ids(_) => Self::Combine(Combine { + operation: CombineOp::And, + filters: Filters(vec![ids, published_after]), + }), Self::Combine(mut combine) if matches!(combine.operation, CombineOp::And) => { combine.filters.push(published_after); Self::Combine(combine) @@ -412,6 +459,7 @@ impl Filter { schema: &IndexedPropertiesSchema, ) -> Result<(), InvalidDocumentProperty> { match self { + Self::Ids(_) => Ok(()), Self::Compare(compare) => schema.validate_filter(&compare.field, &compare.value), Self::Combine(combine) => combine .filters @@ -433,6 +481,7 @@ impl<'de> Deserialize<'de> for Filter { let filter = Content::deserialize(deserializer)?; let deserializer = ContentRefDeserializer::::new(&filter); + let compare = match Compare::deserialize(deserializer) { Ok(compare) => return Ok(Filter::Compare(compare)), Err(error) => error, @@ -441,9 +490,13 @@ impl<'de> Deserialize<'de> for Filter { Ok(combine) => return Ok(Filter::Combine(combine)), Err(error) => error, }; + let ids = match Ids::deserialize(deserializer) { + Ok(ids) => return Ok(Filter::Ids(ids)), + Err(error) => error, + }; Err(D::Error::custom(format!( - "invalid variant, expected one of: Compare({compare}); Combine({combine})", + "invalid variant, expected one of: Compare({compare}); Combine({combine}); Ids({ids})", ))) } } @@ -1063,4 +1116,43 @@ mod tests { }; assert_eq!(filter.validate(&schema).unwrap_err(), error); } + + #[test] + fn test_ids_filter_can_be_provided() { + let filter = serde_json::from_value::(json!({ + "$ids": ["foo", "bar", "baz"] + })) + .unwrap(); + + assert_eq!( + filter, + Filter::Ids(Ids { + ids: vec![ + "foo".parse().unwrap(), + "bar".parse().unwrap(), + "baz".parse().unwrap(), + ] + }) + ); + + serde_json::from_value::(json!({ + "$ids": ["$$$$"] + })) + .unwrap_err(); + + serde_json::from_value::(json!({ + "$ids": [] + })) + .unwrap_err(); + + serde_json::from_value::(json!({ + "$ids": vec!["foo"; CompareOp::MAX_VALUES_PER_IN] + })) + .unwrap(); + + serde_json::from_value::(json!({ + "$ids": vec!["foo"; CompareOp::MAX_VALUES_PER_IN + 1] + })) + .unwrap_err(); + } } diff --git a/web-api/src/storage/elastic/filter.rs b/web-api/src/storage/elastic/filter.rs index 583b0999c..bbba7c8c8 100644 --- a/web-api/src/storage/elastic/filter.rs +++ b/web-api/src/storage/elastic/filter.rs @@ -329,6 +329,7 @@ enum Clause<'a> { Range(Range<'a>), Filter(Filter<'a>), Should(Should<'a>), + Ids(Ids<'a, DocumentId>), } fn merge_range_and(mut clause: Vec>) -> Vec> { @@ -504,6 +505,10 @@ impl<'a> Clause<'a> { }), } } + + filter::Filter::Ids(ids) => Self::Ids(Ids { + values: Cow::Borrowed(&ids.ids), + }), } } } @@ -540,7 +545,10 @@ impl<'a> Clauses<'a> { }; if let Some(filter) = filter { match Clause::new(filter, true) { - clause @ (Clause::Term(_) | Clause::Terms(_) | Clause::Range(_)) => { + clause @ (Clause::Term(_) + | Clause::Terms(_) + | Clause::Range(_) + | Clause::Ids(_)) => { clauses.filter.filter.push(clause); } Clause::Filter(clause) => clauses.filter = clause,