Skip to content

Commit

Permalink
fix: revert tests
Browse files Browse the repository at this point in the history
  • Loading branch information
phi-friday committed Nov 16, 2024
1 parent 036dc03 commit be98953
Showing 1 changed file with 33 additions and 22 deletions.
55 changes: 33 additions & 22 deletions src/tests/test_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

import pickle
import time
from concurrent.futures import ThreadPoolExecutor, wait
from functools import partial
from itertools import product
from pathlib import Path
from typing import Literal
Expand Down Expand Up @@ -346,17 +348,21 @@ async def test_pop_default(self, uid):
)
def test_filter(self, tags, method, expected):
assert len(self.origin_cache) == 0
for index, add_tags in enumerate([
[],
["tag0"],
["tag1"],
["tag2"],
["tag0", "tag1"],
["tag0", "tag2"],
["tag1", "tag2"],
["tag0", "tag1", "tag2"],
]):
self.origin_cache.set(index, index, tags=add_tags)
with ThreadPoolExecutor() as pool:
futures = [
pool.submit(self.origin_cache.set, index, index, tags=add_tags)
for index, add_tags in enumerate([
[],
["tag0"],
["tag1"],
["tag2"],
["tag0", "tag1"],
["tag0", "tag2"],
["tag1", "tag2"],
["tag0", "tag1", "tag2"],
])
]
wait(futures)

assert len(self.origin_cache) == 8
select = set(self.origin_cache.filter(tags, method=method))
Expand All @@ -373,17 +379,22 @@ def test_filter(self, tags, method, expected):
)
async def test_afilter(self, tags, method, expected):
assert len(self.origin_cache) == 0
for index, add_tags in enumerate([
[],
["tag0"],
["tag1"],
["tag2"],
["tag0", "tag1"],
["tag0", "tag2"],
["tag1", "tag2"],
["tag0", "tag1", "tag2"],
]):
await self.origin_cache.aset(index, index, tags=add_tags)
async with anyio.create_task_group() as task_group:
for index, add_tags in enumerate([
[],
["tag0"],
["tag1"],
["tag2"],
["tag0", "tag1"],
["tag0", "tag2"],
["tag1", "tag2"],
["tag0", "tag1", "tag2"],
]):
task_group.start_soon(
partial(
self.origin_cache.aset, index, index, tags=add_tags, retry=True
)
)

assert len(self.origin_cache) == 8
select = [x async for x in self.origin_cache.afilter(tags, method=method)]
Expand Down

0 comments on commit be98953

Please sign in to comment.