diff --git a/CHANGELOG.md b/CHANGELOG.md index 8322ff2c7c..b6df4f7113 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,7 @@ #### Changes * Python: Added OBJECT ENCODING command ([#1471](https://github.com/aws/glide-for-redis/pull/1471)) * Python: Added OBJECT FREQ command ([#1472](https://github.com/aws/glide-for-redis/pull/1472)) +* Python: Added GEOSEARCH command ([#1482](https://github.com/aws/glide-for-redis/pull/1482)) ## 0.4.0 (2024-05-26) @@ -44,7 +45,7 @@ * Python: Added SDIFFSTORE command ([#1449](https://github.com/aws/glide-for-redis/pull/1449)) * Python: Added SINTERSTORE command ([#1459](https://github.com/aws/glide-for-redis/pull/1459)) * Python: Added SMISMEMBER command ([#1461](https://github.com/aws/glide-for-redis/pull/1461)) -* Python: Added SETRANGE command ([#1453](https://github.com/aws/glide-for-redis/pull/1453) +* Python: Added SETRANGE command ([#1453](https://github.com/aws/glide-for-redis/pull/1453)) #### Fixes * Python: Fix typing error "‘type’ object is not subscriptable" ([#1203](https://github.com/aws/glide-for-redis/pull/1203)) diff --git a/glide-core/src/client/value_conversion.rs b/glide-core/src/client/value_conversion.rs index 8b3569015e..fedf951b5d 100644 --- a/glide-core/src/client/value_conversion.rs +++ b/glide-core/src/client/value_conversion.rs @@ -26,6 +26,7 @@ pub(crate) enum ExpectedReturnType { ArrayOfMemberScorePairs, ZMPopReturnType, KeyWithMemberAndScore, + GeoSearchReturnType(bool, bool, bool), } pub(crate) fn convert_to_expected_type( @@ -379,6 +380,68 @@ pub(crate) fn convert_to_expected_type( ) .into()), }, + ExpectedReturnType::GeoSearchReturnType(dist, hash, coord) => match value { + Value::Array(array) => { + if !dist && !hash && !coord { + return Ok(Value::Array(array)); + } + + let mut converted_array = Vec::with_capacity(array.len()); + let num_true = dist as usize + hash as usize + coord as usize; + for item in array.iter() { + if let Value::Array(inner_array) = item { + if let Some((name, rest)) = inner_array.split_first() { + if rest.len() != num_true { + return Err(( + ErrorKind::TypeError, + "Response couldn't be converted to GeoSearchReturn type", + format!("(response was {:?})", array), + ) + .into()); + } + let rest = rest + .iter() + .map(|v| match v { + Value::Array(coord) => { + if coord.len() != 2 { + return Err(( + ErrorKind::TypeError, + "Inner Array must contain exactly two elements, longitude and latitude", + ) + .into()); + } + let longitude = convert_to_expected_type( + coord[0].clone(), + Some(ExpectedReturnType::Double), + )?; + let latitude = convert_to_expected_type( + coord[1].clone(), + Some(ExpectedReturnType::Double), + )?; + Ok(Value::Array(vec![longitude, latitude])) + } + Value::BulkString(dist) => convert_to_expected_type( + Value::BulkString(dist.to_vec()), + Some(ExpectedReturnType::Double), + ), + _ => Ok(v).cloned(), + }) + .collect::, _>>()?; + + converted_array + .push(Value::Array(vec![name.clone(), Value::Array(rest)])); + } + } + } + Ok(Value::Array(converted_array)) + } + _ => Err(( + ErrorKind::TypeError, + "Response couldn't be converted to GeoSearchReturnType", + format!("(response was {:?})", value), + ) + .into()), + }, } } @@ -567,6 +630,11 @@ pub(crate) fn expected_type_for_cmd(cmd: &Cmd) -> Option { } } b"LOLWUT" => Some(ExpectedReturnType::Lolwut), + b"GEOSEARCH" => Some(ExpectedReturnType::GeoSearchReturnType( + cmd.position(b"WITHDIST").is_some(), + cmd.position(b"WITHHASH").is_some(), + cmd.position(b"WITHCOORD").is_some(), + )), _ => None, } } diff --git a/glide-core/src/protobuf/redis_request.proto b/glide-core/src/protobuf/redis_request.proto index a2ef588281..026e997358 100644 --- a/glide-core/src/protobuf/redis_request.proto +++ b/glide-core/src/protobuf/redis_request.proto @@ -194,6 +194,7 @@ enum RequestType { ExpireTime = 156; PExpireTime = 157; BLMPop = 158; + GeoSearch = 159; } message Command { diff --git a/glide-core/src/request_type.rs b/glide-core/src/request_type.rs index be4befae0b..f82014d733 100644 --- a/glide-core/src/request_type.rs +++ b/glide-core/src/request_type.rs @@ -164,6 +164,7 @@ pub enum RequestType { ExpireTime = 156, PExpireTime = 157, BLMPop = 158, + GeoSearch = 159, } fn get_two_word_command(first: &str, second: &str) -> Cmd { @@ -331,6 +332,7 @@ impl From<::protobuf::EnumOrUnknown> for RequestType { ProtobufRequestType::HStrlen => RequestType::HStrlen, ProtobufRequestType::ExpireTime => RequestType::ExpireTime, ProtobufRequestType::PExpireTime => RequestType::PExpireTime, + ProtobufRequestType::GeoSearch => RequestType::GeoSearch, } } } @@ -494,6 +496,7 @@ impl RequestType { RequestType::HStrlen => Some(cmd("HSTRLEN")), RequestType::ExpireTime => Some(cmd("EXPIRETIME")), RequestType::PExpireTime => Some(cmd("PEXPIRETIME")), + RequestType::GeoSearch => Some(cmd("GEOSEARCH")), } } } diff --git a/python/python/glide/__init__.py b/python/python/glide/__init__.py index 8928249303..9a9c12dd00 100644 --- a/python/python/glide/__init__.py +++ b/python/python/glide/__init__.py @@ -5,8 +5,6 @@ ExpireOptions, ExpirySet, ExpiryType, - GeospatialData, - GeoUnit, InfoSection, InsertPosition, StreamAddOptions, @@ -18,6 +16,11 @@ from glide.async_commands.redis_modules import json from glide.async_commands.sorted_set import ( AggregationType, + GeoSearchByBox, + GeoSearchByRadius, + GeoSearchCount, + GeospatialData, + GeoUnit, InfBound, LexBoundary, Limit, @@ -89,6 +92,9 @@ "ExpireOptions", "ExpirySet", "ExpiryType", + "GeoSearchByBox", + "GeoSearchByRadius", + "GeoSearchCount", "GeoUnit", "GeospatialData", "AggregationType", diff --git a/python/python/glide/async_commands/core.py b/python/python/glide/async_commands/core.py index cd86ef09fb..d35bc62e2e 100644 --- a/python/python/glide/async_commands/core.py +++ b/python/python/glide/async_commands/core.py @@ -18,13 +18,20 @@ from glide.async_commands.sorted_set import ( AggregationType, + GeoSearchByBox, + GeoSearchByRadius, + GeoSearchCount, + GeospatialData, + GeoUnit, InfBound, LexBoundary, + OrderBy, RangeByIndex, RangeByLex, RangeByScore, ScoreBoundary, ScoreFilter, + _create_geosearch_args, _create_z_cmd_store_args, _create_zrange_args, ) @@ -135,21 +142,62 @@ class UpdateOptions(Enum): GREATER_THAN = "GT" -class GeospatialData: - def __init__(self, longitude: float, latitude: float): - """ - Represents a geographic position defined by longitude and latitude. - - The exact limits, as specified by EPSG:900913 / EPSG:3785 / OSGEO:41001 are the following: - - Valid longitudes are from -180 to 180 degrees. - - Valid latitudes are from -85.05112878 to 85.05112878 degrees. +class ExpirySet: + """SET option: Represents the expiry type and value to be executed with "SET" command.""" + def __init__( + self, + expiry_type: ExpiryType, + value: Optional[Union[int, datetime, timedelta]], + ) -> None: + """ Args: - longitude (float): The longitude coordinate. - latitude (float): The latitude coordinate. + - expiry_type (ExpiryType): The expiry type. + - value (Optional[Union[int, datetime, timedelta]]): The value of the expiration type. The type of expiration + determines the type of expiration value: + - SEC: Union[int, timedelta] + - MILLSEC: Union[int, timedelta] + - UNIX_SEC: Union[int, datetime] + - UNIX_MILLSEC: Union[int, datetime] + - KEEP_TTL: Type[None] """ - self.longitude = longitude - self.latitude = latitude + self.set_expiry_type_and_value(expiry_type, value) + + def set_expiry_type_and_value( + self, expiry_type: ExpiryType, value: Optional[Union[int, datetime, timedelta]] + ): + if not isinstance(value, get_args(expiry_type.value[1])): + raise ValueError( + f"The value of {expiry_type} should be of type {expiry_type.value[1]}" + ) + self.expiry_type = expiry_type + if self.expiry_type == ExpiryType.SEC: + self.cmd_arg = "EX" + if isinstance(value, timedelta): + value = int(value.total_seconds()) + elif self.expiry_type == ExpiryType.MILLSEC: + self.cmd_arg = "PX" + if isinstance(value, timedelta): + value = int(value.total_seconds() * 1000) + elif self.expiry_type == ExpiryType.UNIX_SEC: + self.cmd_arg = "EXAT" + if isinstance(value, datetime): + value = int(value.timestamp()) + elif self.expiry_type == ExpiryType.UNIX_MILLSEC: + self.cmd_arg = "PXAT" + if isinstance(value, datetime): + value = int(value.timestamp() * 1000) + elif self.expiry_type == ExpiryType.KEEP_TTL: + self.cmd_arg = "KEEPTTL" + self.value = str(value) if value else None + + def get_cmd_args(self) -> List[str]: + return [self.cmd_arg] if self.value is None else [self.cmd_arg, self.value] + + +class InsertPosition(Enum): + BEFORE = "BEFORE" + AFTER = "AFTER" class StreamTrimOptions(ABC): @@ -167,7 +215,6 @@ def __init__( ): """ Initialize stream trim options. - Args: exact (bool): If `true`, the stream will be trimmed exactly. Otherwise the stream will be trimmed in a near-exact manner, which is more efficient. @@ -188,7 +235,6 @@ def __init__( def to_args(self) -> List[str]: """ Convert options to arguments for Redis command. - Returns: List[str]: List of arguments for Redis command. """ @@ -210,7 +256,6 @@ class TrimByMinId(StreamTrimOptions): def __init__(self, exact: bool, threshold: str, limit: Optional[int] = None): """ Initialize trim option by minimum ID. - Args: exact (bool): If `true`, the stream will be trimmed exactly. Otherwise the stream will be trimmed in a near-exact manner, which is more efficient. @@ -229,7 +274,6 @@ class TrimByMaxLen(StreamTrimOptions): def __init__(self, exact: bool, threshold: int, limit: Optional[int] = None): """ Initialize trim option by maximum length. - Args: exact (bool): If `true`, the stream will be trimmed exactly. Otherwise the stream will be trimmed in a near-exact manner, which is more efficient. @@ -253,7 +297,6 @@ def __init__( ): """ Initialize stream add options. - Args: id (Optional[str]): ID for the new entry. If set, the new entry will be added with this ID. If not specified, '*' is used. make_stream (bool, optional): If set to False, a new stream won't be created if no stream matches the given key. @@ -266,7 +309,6 @@ def __init__( def to_args(self) -> List[str]: """ Convert options to arguments for Redis command. - Returns: List[str]: List of arguments for Redis command. """ @@ -280,87 +322,6 @@ def to_args(self) -> List[str]: return option_args -class GeoUnit(Enum): - """ - Enumeration representing distance units options for the `GEODIST` command. - """ - - METERS = "m" - """ - Represents distance in meters. - """ - KILOMETERS = "km" - """ - Represents distance in kilometers. - """ - MILES = "mi" - """ - Represents distance in miles. - """ - FEET = "ft" - """ - Represents distance in feet. - """ - - -class ExpirySet: - """SET option: Represents the expiry type and value to be executed with "SET" command.""" - - def __init__( - self, - expiry_type: ExpiryType, - value: Optional[Union[int, datetime, timedelta]], - ) -> None: - """ - Args: - - expiry_type (ExpiryType): The expiry type. - - value (Optional[Union[int, datetime, timedelta]]): The value of the expiration type. The type of expiration - determines the type of expiration value: - - SEC: Union[int, timedelta] - - MILLSEC: Union[int, timedelta] - - UNIX_SEC: Union[int, datetime] - - UNIX_MILLSEC: Union[int, datetime] - - KEEP_TTL: Type[None] - """ - self.set_expiry_type_and_value(expiry_type, value) - - def set_expiry_type_and_value( - self, expiry_type: ExpiryType, value: Optional[Union[int, datetime, timedelta]] - ): - if not isinstance(value, get_args(expiry_type.value[1])): - raise ValueError( - f"The value of {expiry_type} should be of type {expiry_type.value[1]}" - ) - self.expiry_type = expiry_type - if self.expiry_type == ExpiryType.SEC: - self.cmd_arg = "EX" - if isinstance(value, timedelta): - value = int(value.total_seconds()) - elif self.expiry_type == ExpiryType.MILLSEC: - self.cmd_arg = "PX" - if isinstance(value, timedelta): - value = int(value.total_seconds() * 1000) - elif self.expiry_type == ExpiryType.UNIX_SEC: - self.cmd_arg = "EXAT" - if isinstance(value, datetime): - value = int(value.timestamp()) - elif self.expiry_type == ExpiryType.UNIX_MILLSEC: - self.cmd_arg = "PXAT" - if isinstance(value, datetime): - value = int(value.timestamp() * 1000) - elif self.expiry_type == ExpiryType.KEEP_TTL: - self.cmd_arg = "KEEPTTL" - self.value = str(value) if value else None - - def get_cmd_args(self) -> List[str]: - return [self.cmd_arg] if self.value is None else [self.cmd_arg, self.value] - - -class InsertPosition(Enum): - BEFORE = "BEFORE" - AFTER = "AFTER" - - class CoreCommands(Protocol): async def _execute_command( self, @@ -2306,6 +2267,72 @@ async def geopos( await self._execute_command(RequestType.GeoPos, [key] + members), ) + async def geosearch( + self, + key: str, + search_from: Union[str, GeospatialData], + seach_by: Union[GeoSearchByRadius, GeoSearchByBox], + order_by: Optional[OrderBy] = None, + count: Optional[GeoSearchCount] = None, + with_coord: bool = False, + with_dist: bool = False, + with_hash: bool = False, + ) -> List[Union[str, List[Union[str, float, int, List[float]]]]]: + """ + Searches for members in a sorted set stored at `key` representing geospatial data within a circular or rectengal area. + + See https://valkey.io/commands/geosearch/ for more details. + + Args: + key (str): The key of the sorted set representing geospatial data. + search_from (Union[str, GeospatialData]): The location to search from. Can be specified either as a member + from the sorted set or as a geospatial data (see `GeospatialData`). + search_by (Union[GeoSearchByRadius, GeoSearchByBox]): The search criteria. + For circular area search, see `GeoSearchByRadius`. + For rectengal area search, see `GeoSearchByBox`. + order_by (Optional[OrderBy]): Specifies the order in which the results should be returned. + Defaults to None. See `OrderBy`. + count (Optional[GeoSearchCount]): Specifies the maximum number of results to return. + Defaults to None. See `GeoSearchCount`. + with_coord (bool): Whether to include coordinates of the returned items. Defaults to False. + with_dist (bool): Whether to include distance from the center in the returned items. + The distance is returned in the same unit as specified for the `search_by` arguments. Defaults to False. + with_hash (bool): Whether to include geohash of the returned items. Defaults to False. + + Returns: + List[Union[str, List[Union[str, float, int, List[float]]]]]: If no with_* was set to True, returns a list of memebers + names. + If `with_coord`, `with_dist` or `with_hash` was specified, returns an array of arrays, we're each sub array represents a single item with: + (str): The member name. + (float): The distance from the center as a floating point number, in the same unit specified in the radius, if `with_dist` is set to True. + (int) The Geohash integer, if `with_dist` is set to True. + List[float]: The coordinates as a two items x,y array (longitude,latitude), if `with_dist` is set to True. + + Examples: + >>> await client.geoadd("my_geo_sorted_set", {"edge1": GeospatialData(12.758489, 38.788135), "edge2": GeospatialData(17.241510, 38.788135)}}) + >>> await client.geoadd("my_geo_sorted_set", {"Palermo": GeospatialData(13.361389, 38.115556), "Catania": GeospatialData(15.087269, 37.502669)}) + >>> await client.geosearch("my_location", "search_center", GeoSearchByRadius(5.0, GeoUnit.KILOMETERS)) + [ + ['Catania', [56.4413, 3479447370796909, [15.087267458438873, 37.50266842333162]]], + ['Palermo', [190.4424, 3479099956230698, [13.361389338970184, 38.1155563954963]]] + ] + """ + args = _create_geosearch_args( + key, + search_from, + seach_by, + order_by, + count, + with_coord, + with_dist, + with_hash, + ) + + return cast( + List[Union[str, List[Union[str, float, int, List[float]]]]], + await self._execute_command(RequestType.GeoSearch, args), + ) + async def zadd( self, key: str, diff --git a/python/python/glide/async_commands/sorted_set.py b/python/python/glide/async_commands/sorted_set.py index 9fff352159..73a97d0ab8 100644 --- a/python/python/glide/async_commands/sorted_set.py +++ b/python/python/glide/async_commands/sorted_set.py @@ -94,6 +94,7 @@ class Limit: The optional LIMIT argument can be used to obtain a sub-range from the matching elements (similar to SELECT LIMIT offset, count in SQL). + Args: offset (int): The offset from the start of the range. count (int): The number of elements to include in the range. @@ -169,6 +170,145 @@ def __init__( self.limit = limit +class GeospatialData: + """ + Represents a geographic position defined by longitude and latitude. + + The exact limits, as specified by EPSG:900913 / EPSG:3785 / OSGEO:41001 are the following: + - Valid longitudes are from -180 to 180 degrees. + - Valid latitudes are from -85.05112878 to 85.05112878 degrees. + + Args: + longitude (float): The longitude coordinate. + latitude (float): The latitude coordinate. + """ + + def __init__(self, longitude: float, latitude: float): + self.longitude = longitude + self.latitude = latitude + + +class GeoUnit(Enum): + """ + Enumeration representing distance units options for the `GEODIST` command. + """ + + METERS = "m" + """ + Represents distance in meters. + """ + KILOMETERS = "km" + """ + Represents distance in kilometers. + """ + MILES = "mi" + """ + Represents distance in miles. + """ + FEET = "ft" + """ + Represents distance in feet. + """ + + +class OrderBy(Enum): + """ + Enumeration representing sorting order options for `GEOSEARCH` command. + """ + + ASCENDING = "ASC" + """ + Sort returned items from the nearest to the farthest, relative to the center point. + """ + DESCENDING = "DESC" + """ + Sort returned items from the farthest to the nearest, relative to the center point. + """ + + +class GeoSearchByRadius: + """ + Represents search criteria of searching within a certain radius from a specified point. + + Args: + radius (float): Radius of the search area. + unit (GeoUnit): Unit of the radius. See `GeoUnit`. + """ + + def __init__(self, radius: float, unit: GeoUnit): + """ + Initialize the search criteria. + """ + self.radius = radius + self.unit = unit + + def __str__(self) -> str: + """ + Convert the search criteria to the corresponding part of the Redis command. + + Returns: + str: String representation of the search criteria. + """ + return f"BYRADIUS {self.radius} {self.unit.value}" + + +class GeoSearchByBox: + """ + Represents search criteria of searching within a specified rectangular area. + + Args: + width (float): Width of the bounding box. + height (float): Height of the bounding box + unit (GeoUnit): Unit of the radius. See `GeoUnit`. + """ + + def __init__(self, width: float, height: float, unit: GeoUnit): + """ + Initialize the search criteria. + """ + self.width = width + self.height = height + self.unit = unit + + def __str__(self) -> str: + """ + Convert the search criteria to the corresponding part of the Redis command. + + Returns: + str: String representation of the search criteria. + """ + return f"BYBOX {self.width} {self.height} {self.unit.value}" + + +class GeoSearchCount: + """ + Represents the count option for limiting the number of results in a GeoSearch. + + Args: + count (int): The maximum number of results to return. + any_option (bool): Whether to allow returning as enough matches are found. + This means that the results returned may not be the ones closest to the specified point. Default to False. + """ + + def __init__(self, count: int, any_option: bool = False): + """ + Initialize the count option. + """ + self.count = count + self.any_option = any_option + + def __str__(self) -> str: + """ + Convert the count option to the corresponding part of the Redis command. + + Returns: + str: String representation of the count option. + """ + if self.any_option: + return f"COUNT {self.count} ANY" + return f"COUNT {self.count}" + + def _create_zrange_args( key: str, range_query: Union[RangeByLex, RangeByScore, RangeByIndex], @@ -240,3 +380,40 @@ def _create_z_cmd_store_args( args.append(aggregation_type.value) return args + + +def _create_geosearch_args( + key: str, + search_from: Union[str, GeospatialData], + seach_by: Union[GeoSearchByRadius, GeoSearchByBox], + order_by: Optional[OrderBy] = None, + count: Optional[GeoSearchCount] = None, + with_coord: bool = False, + with_dist: bool = False, + with_hash: bool = False, +) -> List[str]: + args = [key] + if isinstance(search_from, str): + args += ["FROMMEMBER", search_from] + else: + args += [ + "FROMLONLAT", + str(search_from.longitude), + str(search_from.latitude), + ] + + args += str(seach_by).split() + + if order_by: + args.append(order_by.value) + if count: + args.extend(str(count).split()) + + if with_coord: + args.append("WITHCOORD") + if with_dist: + args.append("WITHDIST") + if with_hash: + args.append("WITHHASH") + + return args diff --git a/python/python/glide/async_commands/transaction.py b/python/python/glide/async_commands/transaction.py index 80130da109..23091a6db1 100644 --- a/python/python/glide/async_commands/transaction.py +++ b/python/python/glide/async_commands/transaction.py @@ -17,13 +17,18 @@ ) from glide.async_commands.sorted_set import ( AggregationType, + GeoSearchByBox, + GeoSearchByRadius, + GeoSearchCount, InfBound, LexBoundary, + OrderBy, RangeByIndex, RangeByLex, RangeByScore, ScoreBoundary, ScoreFilter, + _create_geosearch_args, _create_z_cmd_store_args, _create_zrange_args, ) @@ -1590,6 +1595,69 @@ def geopos( """ return self.append_command(RequestType.GeoPos, [key] + members) + def geosearch( + self: TTransaction, + key: str, + search_from: Union[str, GeospatialData], + seach_by: Union[GeoSearchByRadius, GeoSearchByBox], + order_by: Optional[OrderBy] = None, + count: Optional[GeoSearchCount] = None, + with_coord: bool = False, + with_dist: bool = False, + with_hash: bool = False, + ) -> TTransaction: + """ + Searches for members in a sorted set stored at `key` representing geospatial data within a circular or rectengal area. + + See https://valkey.io/commands/geosearch/ for more details. + + Args: + key (str): The key of the sorted set representing geospatial data. + search_from (Union[str, GeospatialData]): The location to search from. Can be specified either as a member + from the sorted set or as a geospatial data (see `GeospatialData`). + search_by (Union[GeoSearchByRadius, GeoSearchByBox]): The search criteria. + For circular area search, see `GeoSearchByRadius`. + For rectengal area search, see `GeoSearchByBox`. + order_by (Optional[OrderBy]): Specifies the order in which the results should be returned. + Defaults to None. See `OrderBy`. + count (Optional[GeoSearchCount]): Specifies the maximum number of results to return. + Defaults to None. See `GeoSearchCount`. + with_coord (bool): Whether to include coordinates of the returned items. Defaults to False. + with_dist (bool): Whether to include distance from the center in the returned items. + The distance is returned in the same unit as specified for the `search_by` arguments. Defaults to False. + with_hash (bool): Whether to include geohash of the returned items. Defaults to False. + + Returns: + List[Union[str, List[Union[str, float, int, List[float]]]]]: If no with_* was set to True, returns a list of memebers + names. + If `with_coord`, `with_dist` or `with_hash` was specified, returns an array of arrays, we're each sub array represents a single item with: + (str): The member name. + (float): The distance from the center as a floating point number, in the same unit specified in the radius, if `with_dist` is set to True. + (int) The Geohash integer, if `with_dist` is set to True. + List[float]: The coordinates as a two items x,y array (longitude,latitude), if `with_dist` is set to True. + + Examples: + >>> await client.geoadd("my_geo_sorted_set", {"edge1": GeospatialData(12.758489, 38.788135), "edge2": GeospatialData(17.241510, 38.788135)}}) + >>> await client.geoadd("my_geo_sorted_set", {"Palermo": GeospatialData(13.361389, 38.115556), "Catania": GeospatialData(15.087269, 37.502669)}) + >>> await client.geosearch("my_location", "search_center", GeoSearchByRadius(5.0, GeoUnit.KILOMETERS)) + [ + ['Catania', [56.4413, 3479447370796909, [15.087267458438873, 37.50266842333162]]], + ['Palermo', [190.4424, 3479099956230698, [13.361389338970184, 38.1155563954963]]] + ] + """ + args = _create_geosearch_args( + key, + search_from, + seach_by, + order_by, + count, + with_coord, + with_dist, + with_hash, + ) + + return self.append_command(RequestType.GeoSearch, args) + def zadd( self: TTransaction, key: str, diff --git a/python/python/tests/test_async_client.py b/python/python/tests/test_async_client.py index 0e639a4afe..eb8f023fa3 100644 --- a/python/python/tests/test_async_client.py +++ b/python/python/tests/test_async_client.py @@ -28,9 +28,13 @@ ) from glide.async_commands.sorted_set import ( AggregationType, + GeoSearchByBox, + GeoSearchByRadius, + GeoSearchCount, InfBound, LexBoundary, Limit, + OrderBy, RangeByIndex, RangeByLex, RangeByScore, @@ -1664,6 +1668,244 @@ async def test_geoadd_invalid_args(self, redis_client: TRedisClient): with pytest.raises(RequestError): await redis_client.geoadd(key, {"Place": GeospatialData(0, -86)}) + @pytest.mark.parametrize("cluster_mode", [True, False]) + @pytest.mark.parametrize("protocol", [ProtocolVersion.RESP2, ProtocolVersion.RESP3]) + async def test_geosearch_by_box(self, redis_client: TRedisClient): + key = get_random_string(10) + members = ["Catania", "Palermo", "edge2", "edge1"] + members_coordinates = { + "Palermo": GeospatialData(13.361389, 38.115556), + "Catania": GeospatialData(15.087269, 37.502669), + "edge1": GeospatialData(12.758489, 38.788135), + "edge2": GeospatialData(17.241510, 38.788135), + } + result = [ + [ + "Catania", + [56.4413, 3479447370796909, [15.087267458438873, 37.50266842333162]], + ], + [ + "Palermo", + [190.4424, 3479099956230698, [13.361389338970184, 38.1155563954963]], + ], + [ + "edge2", + [279.7403, 3481342659049484, [17.241510450839996, 38.78813451624225]], + ], + [ + "edge1", + [279.7405, 3479273021651468, [12.75848776102066, 38.78813451624225]], + ], + ] + assert await redis_client.geoadd(key, members_coordinates) == 4 + + # Test search by box, unit: kilometers, from a geospatial data + assert ( + await redis_client.geosearch( + key, + GeospatialData(15, 37), + GeoSearchByBox(400, 400, GeoUnit.KILOMETERS), + OrderBy.ASCENDING, + ) + == members + ) + + assert ( + await redis_client.geosearch( + key, + GeospatialData(15, 37), + GeoSearchByBox(400, 400, GeoUnit.KILOMETERS), + OrderBy.DESCENDING, + with_coord=True, + with_dist=True, + with_hash=True, + ) + == result[::-1] + ) + + # Test search by box, unit: meters, from a member + meters = 400 * 1000 + assert await redis_client.geosearch( + key, + "Catania", + GeoSearchByBox(meters, meters, GeoUnit.METERS), + OrderBy.DESCENDING, + ) == ["edge2", "Palermo", "Catania"] + + # Test search by box, unit: feet, from a member, with limited count to 2 + feet = 400 * 3280.8399 + assert await redis_client.geosearch( + key, + "Palermo", + GeoSearchByBox(feet, feet, GeoUnit.FEET), + OrderBy.ASCENDING, + count=GeoSearchCount(2), + ) == ["Palermo", "edge1"] + + # Test search by box, unit: miles, from a geospatial data, with limited ANY count to 1 + assert ( + await redis_client.geosearch( + key, + GeospatialData(15, 37), + GeoSearchByBox(250, 250, GeoUnit.MILES), + OrderBy.ASCENDING, + count=GeoSearchCount(1, True), + ) + )[0] in members + + @pytest.mark.parametrize("cluster_mode", [True, False]) + @pytest.mark.parametrize("protocol", [ProtocolVersion.RESP2, ProtocolVersion.RESP3]) + async def test_geosearch_by_radius(self, redis_client: TRedisClient): + key = get_random_string(10) + members_coordinates = { + "Palermo": GeospatialData(13.361389, 38.115556), + "Catania": GeospatialData(15.087269, 37.502669), + "edge1": GeospatialData(12.758489, 38.788135), + "edge2": GeospatialData(17.241510, 38.788135), + } + result = [ + [ + "Catania", + [56.4413, 3479447370796909, [15.087267458438873, 37.50266842333162]], + ], + [ + "Palermo", + [190.4424, 3479099956230698, [13.361389338970184, 38.1155563954963]], + ], + ] + members = ["Catania", "Palermo", "edge2", "edge1"] + assert await redis_client.geoadd(key, members_coordinates) == 4 + + # Test search by radius, units: feet, from a member + feet = 200 * 3280.8399 + assert ( + await redis_client.geosearch( + key, + "Catania", + GeoSearchByRadius(feet, GeoUnit.FEET), + OrderBy.ASCENDING, + ) + == members[:2] + ) + + # Test search by radius, units: meters, from a member + meters = 200 * 1000 + assert ( + await redis_client.geosearch( + key, + "Catania", + GeoSearchByRadius(meters, GeoUnit.METERS), + OrderBy.DESCENDING, + ) + == members[:2][::-1] + ) + + # Test search by box, unit: miles, from a geospatial data, with limited count to 1 + assert ( + await redis_client.geosearch( + key, + GeospatialData(15, 37), + GeoSearchByRadius(175, GeoUnit.MILES), + OrderBy.DESCENDING, + ) + == members[::-1] + ) + + # Test search by box, unit: kilometers, from a geospatial data, with limited count to 2 + assert ( + await redis_client.geosearch( + key, + GeospatialData(15, 37), + GeoSearchByRadius(200, GeoUnit.KILOMETERS), + OrderBy.ASCENDING, + count=GeoSearchCount(2), + with_coord=True, + with_dist=True, + with_hash=True, + ) + == result[:2] + ) + + # Test search by box, unit: kilometers, from a geospatial data, with limited ANY count to 1 + assert ( + await redis_client.geosearch( + key, + GeospatialData(15, 37), + GeoSearchByRadius(200, GeoUnit.KILOMETERS), + OrderBy.ASCENDING, + count=GeoSearchCount(1, True), + with_coord=True, + with_dist=True, + with_hash=True, + ) + )[0] in result + + @pytest.mark.parametrize("cluster_mode", [True, False]) + @pytest.mark.parametrize("protocol", [ProtocolVersion.RESP2, ProtocolVersion.RESP3]) + async def test_geosearch_no_result(self, redis_client: TRedisClient): + key = get_random_string(10) + members_coordinates = { + "Palermo": GeospatialData(13.361389, 38.115556), + "Catania": GeospatialData(15.087269, 37.502669), + "edge1": GeospatialData(12.758489, 38.788135), + "edge2": GeospatialData(17.241510, 38.788135), + } + assert await redis_client.geoadd(key, members_coordinates) == 4 + + # No membes within the aea + assert ( + await redis_client.geosearch( + key, + GeospatialData(15, 37), + GeoSearchByBox(50, 50, GeoUnit.METERS), + OrderBy.ASCENDING, + ) + == [] + ) + + assert ( + await redis_client.geosearch( + key, + GeospatialData(15, 37), + GeoSearchByRadius(10, GeoUnit.METERS), + OrderBy.ASCENDING, + ) + == [] + ) + + # No members in the area (apart from the member we seach fom itself) + assert await redis_client.geosearch( + key, + "Catania", + GeoSearchByBox(10, 10, GeoUnit.KILOMETERS), + OrderBy.DESCENDING, + ) == ["Catania"] + + assert await redis_client.geosearch( + key, + "Catania", + GeoSearchByRadius(10, GeoUnit.METERS), + OrderBy.DESCENDING, + ) == ["Catania"] + + # Search from non exiting memeber + with pytest.raises(RequestError): + await redis_client.geosearch( + key, + "non_existing_member", + GeoSearchByBox(10, 10, GeoUnit.MILES), + OrderBy.DESCENDING, + ) + + assert await redis_client.set(key, "foo") == OK + with pytest.raises(RequestError): + await redis_client.geosearch( + key, + "Catania", + GeoSearchByBox(10, 10, GeoUnit.MILES), + OrderBy.DESCENDING, + ) + @pytest.mark.parametrize("cluster_mode", [True, False]) @pytest.mark.parametrize("protocol", [ProtocolVersion.RESP2, ProtocolVersion.RESP3]) async def test_geohash(self, redis_client: TRedisClient): diff --git a/python/python/tests/test_transaction.py b/python/python/tests/test_transaction.py index a8cb937a88..42ac275b11 100644 --- a/python/python/tests/test_transaction.py +++ b/python/python/tests/test_transaction.py @@ -13,8 +13,11 @@ ) from glide.async_commands.sorted_set import ( AggregationType, + GeoSearchByRadius, + GeoUnit, InfBound, LexBoundary, + OrderBy, RangeByIndex, ScoreBoundary, ScoreFilter, @@ -326,6 +329,10 @@ async def transaction_test( None, ] ) + transaction.geosearch( + key12, "Catania", GeoSearchByRadius(200, GeoUnit.KILOMETERS), OrderBy.ASCENDING + ) + args.append(["Catania", "Palermo"]) transaction.xadd(key11, [("foo", "bar")], StreamAddOptions(id="0-1")) args.append("0-1")