Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

docs: add missing docstrings to async_utils.py #3442

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 91 additions & 0 deletions cve_bin_tool/async_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,19 @@


def async_wrap(func):
"""
Wrapper to use synchronous functions in asynchronous context.
"""

@wraps(func)
async def run(*args, loop=None, executor=None, **kwargs):
"""
Takes a synchronous function and executes it using specified executor.

Parameters :
loop (optional, event loop): Event loop to be used.
executor (optional, executor): Executor for calling the synchronous function in.
"""
if loop is None:
loop = asyncio.get_event_loop()
pfunc = partial(func, *args, **kwargs)
Expand All @@ -55,6 +66,9 @@ async def run(*args, loop=None, executor=None, **kwargs):


def get_event_loop():
"""
Gets or creates an event loop.
"""
try:
loop = asyncio.get_running_loop()
except RuntimeError:
Expand All @@ -68,13 +82,27 @@ def get_event_loop():


def run_coroutine(coro):
"""
Runs an asynchronous coroutine and returns its result.
"""
loop = get_event_loop()
aws = asyncio.ensure_future(coro, loop=loop)
result = loop.run_until_complete(aws)
return result


async def aio_run_command(args, process_can_fail=True):
"""
Asynchronously run a command in a subprocess and return its output, error and return code

Parameters :
process_can_fail (Optional, bool) : If False, non-zero return codes result in errors.

Returns :
stdout: The output of the subprocess.
stderr: The error of the subprocess.
returncode: The returncode of the subprocess
"""
process = await asyncio.create_subprocess_exec(
*args,
stdout=asyncio.subprocess.PIPE,
Expand All @@ -89,18 +117,33 @@ async def aio_run_command(args, process_can_fail=True):


class ChangeDirContext:
"""
Allows temporary changes in the current working directory.
Manages context to allow going to destination directory and return back to original.
"""

def __init__(self, destination_dir):
self.current_dir = os.getcwd()
self.destination_dir = destination_dir

async def __aenter__(self):
"""
Changes into specified destination directory.
"""
os.chdir(self.destination_dir)

async def __aexit__(self, exc_type, exc_val, exc_tb):
"""
Revert changes to return to current working directory.
"""
os.chdir(self.current_dir)


class FileIO:
"""
Provides asynchronous methods for file operations
"""

_open = async_wrap(open)
_name_idx: int | None = 0
_mode_idx = 1
Expand Down Expand Up @@ -131,12 +174,18 @@ async def __call__(self):
return await self.open()

async def open(self):
"""
Opens the file asynchronously.
"""
file = await self.__class__._open(*self._args, **self._kwargs)
self._file = file
self._setup()
return self

def _setup(self):
"""
Sets up the file object with asynchronous methods.
"""
if not self._file:
raise RuntimeError("Invalid Use: Call open() before calling _setup()")
common_async_attrs = {
Expand Down Expand Up @@ -189,43 +238,74 @@ def _setup(self):
]

async def __aenter__(self):
"""
Enters the asynchronous context.
"""
return await self.open()

async def __aexit__(self, exc_type, exc_val, exc_tb):
"""
Exits the asynchronous context.
"""
return await self.close()

async def __anext__(self):
"""
Retrieves next line from the file asynchronously.
"""
line = await self.readline()
if line:
return line
else:
raise StopAsyncIteration

def __aiter__(self):
"""
Returns an asynchronous iterator for the file.
"""
return self


class TemporaryFile(FileIO):
"""
Asynchronous temporary FileIO wrapper.
"""

_open = async_wrap(tempfile.TemporaryFile)
_name_idx: int | None = None
_mode_idx = 0
_mode = "w+b"

def _setup(self):
"""
Sets up the temporary file.
"""
super()._setup()
self.name = self._file.name


class NamedTemporaryFile(TemporaryFile):
"""
Asynchronous Named Temporary File I/O Wrapper.
"""

_open = async_wrap(tempfile.NamedTemporaryFile)


class SpooledTemporaryFile(TemporaryFile):
"""
Asynchronous Spooled Temporary File I/O Wrapper.
"""

_open = async_wrap(tempfile.SpooledTemporaryFile)
_mode_idx = 1


class GzipFile(FileIO):
"""
Asynchronous Gzip File I/O Wrapper.
"""

_open = async_wrap(gzip.GzipFile)


Expand All @@ -248,16 +328,24 @@ def __init__(self, client):
self.updated_at = time.monotonic()

async def get(self, *args, **kwargs):
"""
Waits for a token then performs a get request."""
await self.wait_for_token()
return self.client.get(*args, **kwargs)

async def wait_for_token(self):
"""
Waits for a token to be available.
"""
while self.tokens < 1:
self.add_new_tokens()
await asyncio.sleep(0.1)
self.tokens -= 1

def add_new_tokens(self):
"""
Add new tokens if needed. Updates token count as required.
"""
now = time.monotonic()
time_since_update = now - self.updated_at
new_tokens = time_since_update * self.RATE
Expand All @@ -266,6 +354,9 @@ def add_new_tokens(self):
self.updated_at = now

async def close(self):
"""
Closes the client connection.
"""
await self.client.close()


Expand Down