Skip to content
This repository has been archived by the owner on Jul 16, 2024. It is now read-only.

Migrate from aioredis to redis.asyncio #134

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changes/134.fix.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Migrate from aioredis to redis-py v4.2.1 in favor of the official support
6 changes: 4 additions & 2 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ install_requires =
pyzmq>=22.1.0
aiohttp>=3.8.0
aiodns>=3.0
aioredis[hiredis]~=2.0.1
redis~=4.2.1
hiredis~=2.0
aiotools>=1.5.5
async-timeout~=4.0.1
asyncudp>=0.4
Expand Down Expand Up @@ -87,8 +88,9 @@ lint =
typecheck =
mypy>=0.942
types-python-dateutil
types-toml
types-redis
types-setuptools
types-toml
dev =
monitor =
backend.ai-monitor-sentry>=0.2.1
Expand Down
20 changes: 10 additions & 10 deletions src/ai/backend/common/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,15 @@
from typing_extensions import TypeAlias
import uuid

import aioredis
import aioredis.exceptions
import aioredis.sentinel
import redis.asyncio
import redis.asyncio.sentinel
import redis.exceptions
from aiotools.context import aclosing
from aiotools.server import process_index
from aiotools.taskgroup import PersistentTaskGroup
import attr

from . import msgpack, redis
from . import msgpack, redis_helper
from .logging import BraceStyleAdapter
from .types import (
EtcdRedisConfig,
Expand Down Expand Up @@ -522,7 +522,7 @@ class BgtaskFailedEvent(BgtaskDoneEventArgs, AbstractEvent):
class RedisConnectorFunc(Protocol):
def __call__(
self,
) -> aioredis.ConnectionPool:
) -> redis.asyncio.ConnectionPool:
...


Expand Down Expand Up @@ -650,7 +650,7 @@ def __init__(
_redis_config = redis_config.copy()
if service_name:
_redis_config['service_name'] = service_name
self.redis_client = redis.get_redis_object(_redis_config, db=db)
self.redis_client = redis_helper.get_redis_object(_redis_config, db=db)
self._log_events = log_events
self._closed = False
self.consumers = defaultdict(set)
Expand Down Expand Up @@ -778,7 +778,7 @@ async def dispatch_subscribers(
await asyncio.sleep(0)

async def _consume_loop(self) -> None:
async with aclosing(redis.read_stream_by_group(
async with aclosing(redis_helper.read_stream_by_group(
self.redis_client,
self._stream_key,
self._consumer_group,
Expand All @@ -801,7 +801,7 @@ async def _consume_loop(self) -> None:
log.exception('EventDispatcher.consume(): unexpected-error')

async def _subscribe_loop(self) -> None:
async with aclosing(redis.read_stream(
async with aclosing(redis_helper.read_stream(
self.redis_client,
self._stream_key,
)) as agen:
Expand Down Expand Up @@ -839,7 +839,7 @@ def __init__(
if service_name:
_redis_config['service_name'] = service_name
self._closed = False
self.redis_client = redis.get_redis_object(_redis_config, db=db)
self.redis_client = redis_helper.get_redis_object(_redis_config, db=db)
self._log_events = log_events
self._stream_key = stream_key

Expand All @@ -863,7 +863,7 @@ async def produce_event(
b'source': source.encode(),
b'args': msgpack.packb(event.serialize()),
}
await redis.execute(
await redis_helper.execute(
self.redis_client,
lambda r: r.xadd(self._stream_key, raw_event), # type: ignore # aio-libs/aioredis-py#1182
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,11 @@
Union,
)

import aioredis
import aioredis.client
import aioredis.sentinel
import aioredis.exceptions
import redis.asyncio
import redis.asyncio.client
import redis.asyncio.sentinel
import redis.client
import redis.exceptions
import yarl

from .types import EtcdRedisConfig, RedisConnectionInfo
Expand Down Expand Up @@ -76,7 +77,7 @@ def _parse_stream_msg_id(msg_id: bytes) -> Tuple[int, int]:


async def subscribe(
channel: aioredis.client.PubSub,
channel: redis.asyncio.client.PubSub,
*,
reconnect_poll_interval: float = 0.3,
) -> AsyncIterator[Any]:
Expand All @@ -88,7 +89,7 @@ async def _reset_chan():
channel.connection = None
try:
await channel.ping()
except aioredis.exceptions.ConnectionError:
except redis.exceptions.ConnectionError:
pass
else:
assert channel.connection is not None
Expand All @@ -102,24 +103,24 @@ async def _reset_chan():
if message is not None:
yield message["data"]
except (
aioredis.exceptions.ConnectionError,
aioredis.sentinel.MasterNotFoundError,
aioredis.sentinel.SlaveNotFoundError,
aioredis.exceptions.ReadOnlyError,
aioredis.exceptions.ResponseError,
redis.exceptions.ConnectionError,
redis.asyncio.sentinel.MasterNotFoundError,
redis.asyncio.sentinel.SlaveNotFoundError,
redis.exceptions.ReadOnlyError,
redis.exceptions.ResponseError,
ConnectionResetError,
ConnectionNotAvailable,
):
await asyncio.sleep(reconnect_poll_interval)
await _reset_chan()
continue
except aioredis.exceptions.ResponseError as e:
except redis.exceptions.ResponseError as e:
if e.args[0].startswith("NOREPLICAS "):
await asyncio.sleep(reconnect_poll_interval)
await _reset_chan()
continue
raise
except (TimeoutError, asyncio.TimeoutError):
except (redis.exceptions.TimeoutError, asyncio.TimeoutError):
continue
except asyncio.CancelledError:
raise
Expand All @@ -128,7 +129,7 @@ async def _reset_chan():


async def blpop(
redis: RedisConnectionInfo | aioredis.Redis | aioredis.sentinel.Sentinel,
redis_connector: RedisConnectionInfo | redis.asyncio.Redis | redis.asyncio.Sentinel,
key: str,
*,
service_name: str = None,
Expand All @@ -142,18 +143,18 @@ async def blpop(
**_default_conn_opts,
'socket_connect_timeout': reconnect_poll_interval,
}
if isinstance(redis, RedisConnectionInfo):
redis_client = redis.client
service_name = service_name or redis.service_name
if isinstance(redis_connector, RedisConnectionInfo):
redis_client = redis_connector.client
service_name = service_name or redis_connector.service_name
else:
redis_client = redis
redis_client = redis_connector

if isinstance(redis_client, aioredis.sentinel.Sentinel):
if isinstance(redis_client, redis.asyncio.Sentinel):
assert service_name is not None
r = redis_client.master_for(
service_name,
redis_class=aioredis.Redis,
connection_pool_class=aioredis.sentinel.SentinelConnectionPool,
redis_class=redis.asyncio.Redis,
connection_pool_class=redis.asyncio.SentinelConnectionPool,
**_conn_opts,
)
else:
Expand All @@ -165,20 +166,20 @@ async def blpop(
continue
yield raw_msg[1]
except (
aioredis.exceptions.ConnectionError,
aioredis.sentinel.MasterNotFoundError,
aioredis.exceptions.ReadOnlyError,
aioredis.exceptions.ResponseError,
redis.exceptions.ConnectionError,
redis.asyncio.sentinel.MasterNotFoundError,
redis.exceptions.ReadOnlyError,
redis.exceptions.ResponseError,
ConnectionResetError,
):
await asyncio.sleep(reconnect_poll_interval)
continue
except aioredis.exceptions.ResponseError as e:
except redis.exceptions.ResponseError as e:
if e.args[0].startswith("NOREPLICAS "):
await asyncio.sleep(reconnect_poll_interval)
continue
raise
except (TimeoutError, asyncio.TimeoutError):
except (redis.exceptions.TimeoutError, asyncio.TimeoutError):
continue
except asyncio.CancelledError:
raise
Expand All @@ -187,8 +188,8 @@ async def blpop(


async def execute(
redis: RedisConnectionInfo | aioredis.Redis | aioredis.sentinel.Sentinel,
func: Callable[[aioredis.Redis], Awaitable[Any]],
redis_connector: RedisConnectionInfo | redis.asyncio.Redis | redis.asyncio.Sentinel,
func: Callable[[redis.asyncio.Redis], Awaitable[Any]],
*,
service_name: str = None,
read_only: bool = False,
Expand All @@ -206,26 +207,26 @@ async def execute(
**_default_conn_opts,
'socket_connect_timeout': reconnect_poll_interval,
}
if isinstance(redis, RedisConnectionInfo):
redis_client = redis.client
service_name = service_name or redis.service_name
if isinstance(redis_connector, RedisConnectionInfo):
redis_client = redis_connector.client
service_name = service_name or redis_connector.service_name
else:
redis_client = redis
redis_client = redis_connector

if isinstance(redis_client, aioredis.sentinel.Sentinel):
if isinstance(redis_client, redis.asyncio.Sentinel):
assert service_name is not None
if read_only:
r = redis_client.slave_for(
service_name,
redis_class=aioredis.Redis,
connection_pool_class=aioredis.sentinel.SentinelConnectionPool,
redis_class=redis.asyncio.Redis,
connection_pool_class=redis.asyncio.SentinelConnectionPool,
**_conn_opts,
)
else:
r = redis_client.master_for(
service_name,
redis_class=aioredis.Redis,
connection_pool_class=aioredis.sentinel.SentinelConnectionPool,
redis_class=redis.asyncio.Redis,
connection_pool_class=redis.asyncio.SentinelConnectionPool,
**_conn_opts,
)
else:
Expand All @@ -238,14 +239,14 @@ async def execute(
else:
raise TypeError('The func must be a function or a coroutinefunction '
'with no arguments.')
if isinstance(aw_or_pipe, aioredis.client.Pipeline):
if isinstance(aw_or_pipe, redis.asyncio.client.Pipeline):
result = await aw_or_pipe.execute()
elif inspect.isawaitable(aw_or_pipe):
result = await aw_or_pipe
else:
raise TypeError('The return value must be an awaitable'
'or aioredis.commands.Pipeline object')
if isinstance(result, aioredis.client.Pipeline):
if isinstance(result, redis.asyncio.client.Pipeline):
# This happens when func is an async function that returns a pipeline.
result = await result.execute()
if encoding:
Expand All @@ -259,20 +260,20 @@ async def execute(
else:
return result
except (
aioredis.exceptions.ConnectionError,
aioredis.sentinel.MasterNotFoundError,
aioredis.sentinel.SlaveNotFoundError,
aioredis.exceptions.ReadOnlyError,
redis.exceptions.ConnectionError,
redis.asyncio.sentinel.MasterNotFoundError,
redis.asyncio.sentinel.SlaveNotFoundError,
redis.exceptions.ReadOnlyError,
ConnectionResetError,
):
await asyncio.sleep(reconnect_poll_interval)
continue
except aioredis.exceptions.ResponseError as e:
except redis.exceptions.ResponseError as e:
if "NOREPLICAS" in e.args[0]:
await asyncio.sleep(reconnect_poll_interval)
continue
raise
except (TimeoutError, asyncio.TimeoutError):
except (redis.exceptions.TimeoutError, asyncio.TimeoutError):
continue
except asyncio.CancelledError:
raise
Expand All @@ -281,7 +282,7 @@ async def execute(


async def execute_script(
redis: RedisConnectionInfo | aioredis.Redis | aioredis.sentinel.Sentinel,
redis_connector: RedisConnectionInfo | redis.asyncio.Redis | redis.asyncio.Sentinel,
script_id: str,
script: str,
keys: Sequence[str],
Expand All @@ -303,20 +304,20 @@ async def execute_script(
script_hash = _scripts.get(script_id, 'x')
while True:
try:
ret = await execute(redis, lambda r: r.evalsha(
ret = await execute(redis_connector, lambda r: r.evalsha(
script_hash,
len(keys),
*keys, *args,
))
break
except aioredis.exceptions.NoScriptError:
except redis.exceptions.NoScriptError:
# Redis may have been restarted.
script_hash = await execute(redis, lambda r: r.script_load(script))
script_hash = await execute(redis_connector, lambda r: r.script_load(script))
_scripts[script_id] = script_hash
except aioredis.exceptions.ResponseError as e:
except redis.exceptions.ResponseError as e:
if 'NOSCRIPT' in e.args[0]:
# Redis may have been restarted.
script_hash = await execute(redis, lambda r: r.script_load(script))
script_hash = await execute(redis_connector, lambda r: r.script_load(script))
_scripts[script_id] = script_hash
else:
raise
Expand Down Expand Up @@ -393,7 +394,7 @@ async def read_stream_by_group(
autoclaim_start_id,
),
)
for msg_id, msg_data in aioredis.client.parse_stream_list(reply[1]):
for msg_id, msg_data in redis.client.parse_stream_list(reply[1]): # type: ignore
messages.append((msg_id, msg_data))
if reply[0] == b'0-0':
break
Expand Down Expand Up @@ -422,7 +423,7 @@ async def read_stream_by_group(
yield msg_id, msg_data
except asyncio.CancelledError:
raise
except aioredis.exceptions.ResponseError as e:
except redis.exceptions.ResponseError as e:
if e.args[0].startswith("NOGROUP "):
try:
await execute(
Expand All @@ -434,7 +435,7 @@ async def read_stream_by_group(
mkstream=True,
),
)
except aioredis.exceptions.ResponseError as e:
except redis.exceptions.ResponseError as e:
if e.args[0].startswith("BUSYGROUP "):
pass
else:
Expand All @@ -456,7 +457,7 @@ def get_redis_object(
sentinel_addresses = _sentinel_addresses

assert redis_config.get('service_name') is not None
sentinel = aioredis.sentinel.Sentinel(
sentinel = redis.asyncio.Sentinel(
[(str(host), port) for host, port in sentinel_addresses],
password=redis_config.get('password'),
db=str(db),
Expand All @@ -478,6 +479,6 @@ def get_redis_object(
.with_password(redis_config.get('password')) / str(db)
)
return RedisConnectionInfo(
client=aioredis.Redis.from_url(str(url), **kwargs),
client=redis.asyncio.Redis.from_url(str(url), **kwargs),
service_name=None,
)
Loading