diff --git a/CHANGELOG.md b/CHANGELOG.md index 4d51bdd..481eb52 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,9 @@ # Changelog +## Next version + +- [#52](https://github.com/sdss/sdsstools/pull/52) Add basic support for `GatheringTaskGroup` for Python versions 3.10 and below. + ## [1.8.1](https://github.com/sdss/sdsstools/compare/1.8.0...1.8.1) - Fixed import of `GatheringTaskGroup` for Python < 3.11. diff --git a/src/sdsstools/utils.py b/src/sdsstools/utils.py index 56d1ec6..4119042 100644 --- a/src/sdsstools/utils.py +++ b/src/sdsstools/utils.py @@ -23,6 +23,7 @@ "get_temporary_file_path", "run_in_executor", "cancel_task", + "GatheringTaskGroup", ] @@ -139,4 +140,50 @@ def results(self): return [task.result() for task in self.__tasks] +else: + + class GatheringTaskGroup: + """Simple implementation of ``asyncio.TaskGroup`` for Python 3.10 and below. + + The behaviour of this class is not exactly the same as ``asyncio.TaskGroup``, + especially when it comes to handling of exceptions during execution. + + """ + + def __init__(self): + self._tasks = [] + self._joined: bool = False + + def __repr__(self): + return f"" + + async def __aenter__(self): + self._joined = False + self._tasks = [] + + return self + + async def __aexit__(self, exc_type, exc_value, traceback): + if exc_type is not None: + return False + + await asyncio.gather(*self._tasks) + self._joined = True + + def create_task(self, coro): + """Creates a task and appends it to the list of tasks.""" + + task = asyncio.create_task(coro) + self._tasks.append(task) + + return task + + def results(self): + """Returns the results of the tasks in the same order they were created.""" + + if not self._joined: + raise RuntimeError("Tasks have not been gathered yet.") + + return [task.result() for task in self._tasks] + __all__.append("GatheringTaskGroup") diff --git a/test/test_utils.py b/test/test_utils.py index d634109..0c61212 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -9,13 +9,13 @@ from __future__ import annotations import asyncio -import sys import warnings from time import sleep import pytest from sdsstools.utils import ( + GatheringTaskGroup, Timer, cancel_task, get_temporary_file_path, @@ -23,10 +23,6 @@ ) -if sys.version_info >= (3, 11): - from sdsstools.utils import GatheringTaskGroup - - def test_timer(): with Timer() as timer: sleep(0.1) @@ -101,7 +97,6 @@ async def test_cancel_task_None(): await cancel_task(task) -@pytest.mark.skipif(sys.version_info < (3, 11), reason="requires python3.11 or higher") async def test_gathering_task_group(): async def _task(i): await asyncio.sleep(0.1) @@ -114,7 +109,6 @@ async def _task(i): assert group.results() == list(range(10)) -@pytest.mark.skipif(sys.version_info < (3, 11), reason="requires python3.11 or higher") async def test_gathering_task_group_results_fails(): async def _task(i): await asyncio.sleep(0.1)