Skip to content

Commit

Permalink
Merge pull request #157 from lsst-dm/tickets/DM-48047
Browse files Browse the repository at this point in the history
DM-48047 : Refactor sync calls in async functions
  • Loading branch information
tcjennings authored Jan 15, 2025
2 parents a24381e + f2961de commit 8fa552e
Show file tree
Hide file tree
Showing 19 changed files with 489 additions and 337 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ select = [
"COM", # pyflakes-commas
"FBT", # flake8-boolean-trap
"UP", # pyupgrade
"ASYNC", # flake8-async
]
extend-select = [
"RUF100", # Warn about unused noqa
Expand Down
198 changes: 123 additions & 75 deletions src/lsst/cmservice/common/bash.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
"""Utility functions for working with bash scripts"""

import contextlib
import os
import subprocess
import pathlib
from collections import deque
from typing import Any

import yaml
from anyio import Path, open_file, open_process
from anyio.streams.text import TextReceiveStream

from ..config import config
from .enums import StatusEnum
Expand All @@ -16,36 +17,64 @@ async def get_diagnostic_message(
log_url: str,
) -> str:
"""Read the last line of a log file, aspirationally hoping
that it contains a diagnostic error message"""
with open(log_url, encoding="utf-8") as fin:
lines = fin.readlines()
if lines:
return lines[-1].strip()
that it contains a diagnostic error message
Parameters
----------
log_url : `str`
The url of the log which may contain a diagnostic message
Returns
-------
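message : `str`
The last line of the log file, potentially containing a diagnostic message, or a short explanatory message if the log file does not exist, is empty, or cannot be read.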
"""
log_path = Path(log_url)
last_line: deque[str] = deque(maxlen=1)
if not await log_path.exists():
return f"Log file {log_url} does not exist"
try:
async with await open_file(log_url) as f:
async for line in f:
last_line.append(line)

if last_line:
return last_line.pop().strip()
return "Empty log file"
except Exception as e:
return f"Error reading log file: {e}"


def parse_bps_stdout(url: str) -> dict[str, str]:
"""Parse the std from a bps submit job"""
async def parse_bps_stdout(url: str | Path) -> dict[str, str]:
"""Parse the stdout from a bps submit job.
Parameters
----------
url : `str | anyio.Path`
url for BPS submit stdout
Returns
-------
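out_dict : `dict[str, str]`
a dictionary built from the lines of the BPS submit stdout that contain exactly one ":", mapping the text before the colon to the text after it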
"""
out_dict = {}
with open(url, encoding="utf8") as fin:
line = fin.readline()
while line:
async with await open_file(url, encoding="utf8") as f:
async for line in f:
tokens = line.split(":")
if len(tokens) != 2: # pragma: no cover
line = fin.readline()
continue
out_dict[tokens[0]] = tokens[1]
line = fin.readline()
return out_dict


def run_bash_job(
script_url: str,
async def run_bash_job(
script_url: str | Path,
log_url: str,
stamp_url: str,
stamp_url: str | Path,
fake_status: StatusEnum | None = None,
) -> None:
"""Run a bash job
"""Run a bash job and write a "stamp file" with the value of the script's
resulting status.
Parameters
----------
Expand All @@ -55,43 +84,64 @@ def run_bash_job(
log_url: str
Location of log file to write
log_url: str
stamp_url: str | anyio.Path
Location of stamp file to write
fake_status: StatusEnum | None,
If set, don't actually submit the job
"""
fake_status = fake_status or config.mock_status
if fake_status is not None:
with open(stamp_url, "w", encoding="utf-8") as fstamp:
fields = dict(status=StatusEnum.reviewable.name)
yaml.dump(fields, fstamp)
yaml_output = yaml.dump(dict(status=StatusEnum.reviewable.name))
await Path(stamp_url).write_text(yaml_output)
return
try:
with open(log_url, "w", encoding="utf-8") as fout:
os.system(f"chmod +x {script_url}")
with subprocess.Popen(
[os.path.abspath(script_url)],
stdout=fout,
stderr=fout,
) as process:
process.wait()
if process.returncode != 0: # pragma: no cover
assert process.stderr
msg = process.stderr.read().decode()
raise CMBashSubmitError(f"Bad bash submit: {msg}")
await submit_file_to_run_in_bash(script_url, log_url)
except Exception as msg:
raise CMBashSubmitError(f"Bad bash submit: {msg}") from msg
with open(stamp_url, "w", encoding="utf-8") as fstamp:
fields = dict(status="accepted")
yaml.dump(fields, fstamp)
fields = dict(status=StatusEnum.accepted.name)
yaml_output = yaml.dump(fields)
await Path(stamp_url).write_text(yaml_output)


async def submit_file_to_run_in_bash(script_url: str | Path, log_url: str | pathlib.Path) -> None:
"""Make a script executable, then submit to run in bash.
def check_stamp_file(
stamp_file: str | None,
Parameters
----------
script_url : `str | anyio.Path`
Path to the script to run. Must be or will be cast as an async Path.
log_url : `str | pathlib.Path`
Path to output the logs. Will be cast as an `anyio.Path` before the log file is opened for writing.
"""
script_path = Path(script_url)
log_path = Path(log_url)

if await script_path.exists():
await script_path.chmod(0o755)
else:
raise CMBashSubmitError(f"No script at path {script_url}")

script_command = await script_path.resolve()

async with await open_process([script_command]) as process:
await process.wait()
assert process.stdout
assert process.stderr
async with await open_file(log_path, "w") as log_out:
async for text in TextReceiveStream(process.stdout):
await log_out.write(text)
async for text in TextReceiveStream(process.stderr):
await log_out.write(text)
if process.returncode != 0: # pragma: no cover
raise CMBashSubmitError("Bad bash submit, check log file.")


async def check_stamp_file(
stamp_file: str | Path | None,
default_status: StatusEnum,
) -> StatusEnum:
"""Check a 'stamp' file for a status code
"""Check a 'stamp' file for a status code.
Parameters
----------
Expand All @@ -108,70 +158,68 @@ def check_stamp_file(
"""
if stamp_file is None:
return default_status
if not os.path.exists(stamp_file):
stamp_file = Path(stamp_file)
if not await stamp_file.exists():
return default_status
with open(stamp_file, encoding="utf-8") as fin:
fields = yaml.safe_load(fin)
return StatusEnum[fields["status"]]
stamp = await Path(stamp_file).read_text()
fields = yaml.safe_load(stamp)
return StatusEnum[fields["status"]]


def write_bash_script(
script_url: str,
async def write_bash_script(
script_url: str | Path,
command: str,
**kwargs: Any,
) -> str:
"""Utility function to write a bash script for later execution
) -> Path:
"""Utility function to write a bash script for later execution.
Parameters
----------
script_url: str
script_url: `str | anyio.Path`
Location to write the script
command: str
command: `str`
Main command line(s) in the script
Keywords
--------
prepend: str | None
prepend: `str | None`
Text to prepend before command
append: str | None
append: `str | None`
Text to append after command
stamp: str | None
stamp: `str | None`
Text to echo to stamp file when script completes
stamp_url: str | None
stamp_url: `str | None`
Stamp file to write to when script completes
fake: str | None
fake: `str | None`
Echo command instead of running it
rollback: str | None
rollback: `str | Path | None`
Prefix to script_url used when rolling back
processing
processing. Will default to CWD (".").
Returns
-------
script_url : str
script_url : `anyio.Path`
The path to the newly written script
"""
prepend = kwargs.get("prepend")
append = kwargs.get("append")
fake = kwargs.get("fake")
rollback_prefix = kwargs.get("rollback", "")

script_url = f"{rollback_prefix}{script_url}"
with contextlib.suppress(OSError):
os.makedirs(os.path.dirname(script_url))

with open(script_url, "w", encoding="utf-8") as fout:
if prepend:
fout.write(f"{prepend}\n")
if fake:
command = f"echo '{command}'"
fout.write(command)
fout.write("\n")
if append:
fout.write(f"{append}\n")
return script_url
rollback_prefix = Path(kwargs.get("rollback", "."))

script_path = rollback_prefix / script_url

if fake:
command = f"echo '{command}'"

await script_path.parent.mkdir(parents=True, exist_ok=True)
contents = (prepend if prepend else "") + "\n" + command + "\n" + (append if append else "")

async with await open_file(script_path, "w") as fout:
await fout.write(contents)
return script_path
33 changes: 25 additions & 8 deletions src/lsst/cmservice/common/butler.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
"""Utility functions for working with butler commands"""

from functools import partial

from anyio import to_thread

from lsst.daf.butler import Butler, MissingCollectionError

from ..common import errors


def remove_run_collections(
async def remove_run_collections(
butler_repo: str,
collection_name: str,
*,
Expand All @@ -25,20 +29,27 @@ def remove_run_collections(
Allow for missing butler
"""
try:
butler = Butler.from_config(butler_repo, collections=[collection_name], without_datastore=True)
butler_f = partial(
Butler.from_config,
butler_repo,
collections=[collection_name],
without_datastore=True,
)
butler = await to_thread.run_sync(butler_f)
except Exception as e:
if fake_reset:
return
raise errors.CMNoButlerError(e) from e # pragma: no cover
try: # pragma: no cover
butler.registry.removeCollection(collection_name)
await to_thread.run_sync(butler.registry.removeCollection, collection_name)
except MissingCollectionError:
pass
except Exception as msg:
raise errors.CMButlerCallError(msg) from msg


def remove_non_run_collections(
# FIXME: this duplicates `remove_run_collections` except that it does not ignore MissingCollectionError; consider merging the two.
async def remove_non_run_collections(
butler_repo: str,
collection_name: str,
*,
Expand All @@ -58,18 +69,24 @@ def remove_non_run_collections(
Allow for missing butler
"""
try:
butler = Butler.from_config(butler_repo, collections=[collection_name], without_datastore=True)
butler_f = partial(
Butler.from_config,
butler_repo,
collections=[collection_name],
without_datastore=True,
)
butler = await to_thread.run_sync(butler_f)
except Exception as e:
if fake_reset:
return
raise errors.CMNoButlerError(e) from e # pragma: no cover
try: # pragma: no cover
butler.registry.removeCollection(collection_name)
await to_thread.run_sync(butler.registry.removeCollection, collection_name)
except Exception as msg:
raise errors.CMButlerCallError(msg) from msg


def remove_collection_from_chain( # pylint: disable=unused-argument
async def remove_collection_from_chain(
butler_repo: str,
chain_collection: str,
collection_name: str,
Expand Down Expand Up @@ -97,7 +114,7 @@ def remove_collection_from_chain( # pylint: disable=unused-argument
raise NotImplementedError


def remove_datasets_from_collections( # pylint: disable=unused-argument
async def remove_datasets_from_collections(
butler_repo: str,
tagged_collection: str,
collection_name: str,
Expand Down
4 changes: 2 additions & 2 deletions src/lsst/cmservice/common/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,8 +103,8 @@ class CMButlerCallError(RuntimeError):
"""Raised when a call to butler fails"""


class CMSpecficiationError(KeyError):
"""Raised when Specification calls out an non-existing fragement"""
class CMSpecificationError(KeyError):
"""Raised when Specification calls out a non-existing fragment"""


class CMTooFewAcceptedJobsError(KeyError):
Expand Down
Loading

0 comments on commit 8fa552e

Please sign in to comment.