Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

docs(async_utils): add docstrings #3398

Closed
wants to merge 2 commits into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 60 additions & 2 deletions cve_bin_tool/async_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,16 @@


def async_wrap(func):
'''
A decorator that wraps a synchronous function in an asynchronous wrapper.
Takes in a func(The synchronous function to be wrapped).
And returns an asynchronous function that wraps the synchronous function.
'''
@wraps(func)
async def run(*args, loop=None, executor=None, **kwargs):
'''
Runs the synchronous function in an asynchronous wrapper.
'''
if loop is None:
loop = asyncio.get_event_loop()
pfunc = partial(func, *args, **kwargs)
Expand All @@ -55,6 +63,7 @@ async def run(*args, loop=None, executor=None, **kwargs):


def get_event_loop():
"""Get the current event loop or create a new one if none exists."""
try:
loop = asyncio.get_running_loop()
except RuntimeError:
Expand All @@ -68,13 +77,19 @@ def get_event_loop():


def run_coroutine(coro):
"""Run a coroutine in a new event loop.
Takes in a coroutine and returns the result of the coroutine.
"""
loop = get_event_loop()
aws = asyncio.ensure_future(coro, loop=loop)
result = loop.run_until_complete(aws)
return result


async def aio_run_command(args, process_can_fail=True):
"""Run a command asynchronously.
Takes in a list of arguments and returns the stdout, stderr, and return code.
"""
process = await asyncio.create_subprocess_exec(
*args,
stdout=asyncio.subprocess.PIPE,
Expand All @@ -90,10 +105,18 @@ async def aio_run_command(args, process_can_fail=True):

class ChangeDirContext:
def __init__(self, destination_dir):
"""Context manager to change directory to destination_dir and then change back to current_dir
Usage:
async with ChangeDirContext(destination_dir):
# do something
"""
self.current_dir = os.getcwd()
self.destination_dir = destination_dir

async def __aenter__(self):
"""Change directory to destination_dir
Takes in a destination directory and changes the current working directory to that directory.
"""
os.chdir(self.destination_dir)

async def __aexit__(self, exc_type, exc_val, exc_tb):
Expand All @@ -107,6 +130,10 @@ class FileIO:
_mode = "r"

def __init__(self, *args, **kwargs):
"""Initialize FileIO object
Takes in a list of arguments and keyword arguments and sets up the FileIO object by
setting the name and mode attributes and storing the arguments and keyword arguments.
"""
# Do some trick to get exact filename and mode regardless of args or kwargs
flatargs = list(itertools.chain(args, kwargs.values()))
if (
Expand All @@ -131,12 +158,16 @@ async def __call__(self):
return await self.open()

async def open(self):
"""Open file and setup async methods"""
file = await self.__class__._open(*self._args, **self._kwargs)
self._file = file
self._setup()
return self

def _setup(self):
"""Setup async methods
Takes in a file object and sets up the async methods for that file object.
"""
if not self._file:
raise RuntimeError("Invalid Use: Call open() before calling _setup()")
common_async_attrs = {
Expand Down Expand Up @@ -189,19 +220,25 @@ def _setup(self):
]

async def __aenter__(self):
"""Open file and setup async methods
Takes in a file object and sets up the async methods for that file object.
"""
return await self.open()

async def __aexit__(self, exc_type, exc_val, exc_tb):
"""Close file"""
return await self.close()

async def __anext__(self):
"""Read next line from file"""
line = await self.readline()
if line:
return line
else:
raise StopAsyncIteration

def __aiter__(self):
"""Returns FileIO object"""
return self


Expand All @@ -212,6 +249,9 @@ class TemporaryFile(FileIO):
_mode = "w+b"

def _setup(self):
"""Setup async methods
Takes in a file object and sets up the async methods for that file object.
"""
super()._setup()
self.name = self._file.name

Expand All @@ -238,26 +278,43 @@ class RateLimiter:
Copyright 2018 Quentin Pradet
See license at top of file.
"""

RATE = 10
MAX_TOKENS = 10

def __init__(self, client):
"""Initialize RateLimiter object
Takes in a client object and sets up the rate limiter for that client.
"""
self.client = client
self.tokens = self.MAX_TOKENS
self.updated_at = time.monotonic()

async def get(self, *args, **kwargs):
"""Make a GET request
Takes in a list of arguments and keyword arguments and makes a GET request.
"""
await self.wait_for_token()
return self.client.get(*args, **kwargs)

async def wait_for_token(self):
"""Wait for a token to become available
Keeps adding new tokens until one is available.
Delays for 0.1 seconds between each check.
After checking, subtracts one token from the total.
"""
while self.tokens < 1:
self.add_new_tokens()
await asyncio.sleep(0.1)
self.tokens -= 1

def add_new_tokens(self):
""" This method calculates the number of new tokens to add based on the elapsed time since
the last update and the rate limit. If the total number of tokens (including the new ones)
is less than the maximum allowed, the new tokens are added to the token pool and the update
time is set to the current time. Otherwise, the token pool is filled to the maximum allowed.
This method is called by the `wait_for_token` method to ensure that enough tokens are available
before making a request.
"""
now = time.monotonic()
time_since_update = now - self.updated_at
new_tokens = time_since_update * self.RATE
Expand All @@ -266,6 +323,7 @@ def add_new_tokens(self):
self.updated_at = now

async def close(self):
"""This method closes the client."""
await self.client.close()


Expand All @@ -275,4 +333,4 @@ async def close(self):
aio_glob = async_wrap(glob.glob)
aio_mkdtemp = async_wrap(tempfile.mkdtemp)
aio_makedirs = async_wrap(os.makedirs)
aio_inpath = async_wrap(inpath)
aio_inpath = async_wrap(inpath)