Skip to content

Commit

Permalink
Merge pull request #21 from quant-aq/dmcc/wandb-tags/sc-14754
Browse files Browse the repository at this point in the history
Support associating tasks with tags, minor cleanups
  • Loading branch information
dmcc authored Sep 10, 2024
2 parents 8d6a11b + d520c8a commit 082d9d6
Show file tree
Hide file tree
Showing 18 changed files with 183 additions and 68 deletions.
12 changes: 4 additions & 8 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,25 +2,21 @@
# See https://pre-commit.com/hooks.html for more hooks
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: 'v4.4.0'
rev: 'v4.6.0'
hooks:
- id: trailing-whitespace
- id: end-of-file-fixer
- id: check-added-large-files

- repo: https://github.com/charliermarsh/ruff-pre-commit
rev: 'v0.0.289'
rev: 'v0.5.5'
hooks:
- id: ruff
args: [--fix, --exit-non-zero-on-fix, --show-fixes]

- repo: https://github.com/psf/black
rev: '23.9.1'
hooks:
- id: black
- id: ruff-format

- repo: https://github.com/google/keep-sorted
rev: v0.1.0
rev: v0.4.0
hooks:
- id: keep-sorted

Expand Down
3 changes: 3 additions & 0 deletions docker/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
# This creates an image for running Aeromancy tasks, used by runner.py.
# See Aeromancy scaffolding docs for an overview of how this fits together.

# NOTE: If this file is updated, make a new Aeromancy release and point
# aeromancy.runner.build_docker to use it.

# This Python version must match a version required by pyproject.toml.
FROM python:3.10.5-bullseye
# To set this in Aeromancy, pass --extra-debian-package to `pdm go`.
Expand Down
19 changes: 13 additions & 6 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,10 @@ doc.shell = "cd docs && mkdocs serve -a localhost:8030"
lint.help = "Run linters over all files"
lint.cmd = "pre-commit run --all-files"

# PDM hook to automatically sort dependencies when they change.
# "true" is there since GitHub Actions will fail if this returns a non-zero exit code.
post_install.shell = "pre-commit run keep-sorted --files pyproject.toml > /dev/null 2> /dev/null; true"

#
# Aeromancy-specific scripts.
#
Expand All @@ -100,7 +104,13 @@ line-length = 88
target-version = ["py310"]

[tool.ruff]
src = ["src"]
line-length = 88
extend-exclude = ["tests/fixtures", "__pycache__"]
target-version = "py310"
namespace-packages = ["docs", "tasks"]

[tool.ruff.lint]
select = [
# keep-sorted start
"ASYNC", # flake8-async
Expand Down Expand Up @@ -148,6 +158,7 @@ ignore = [
# keep-sorted start
"D203", # one-blank-line-before-class (incompatible with D211)
"D213", # multi-line-summary-second-line (incompatible with D212)
"D413", # blank-line-after-last-section (not needed for numpy conventions)
"PGH003", # blanket-type-ignore: Use specific rule codes when ignoring type issues (not always feasible)
"RET504", # unnecessary-assign: Sometimes worth doing for readability
"RET505", # superfluous-else-return: Currently has a bug
Expand All @@ -159,15 +170,11 @@ ignore = [
"TRY003", # raise-vanilla-args: Not for prototyping, not always an error
# keep-sorted end
]
src = ["src"]
extend-exclude = ["tests/fixtures", "__pycache__"]
target-version = "py310"
namespace-packages = ["docs", "tasks"]

[tool.ruff.per-file-ignores]
[tool.ruff.lint.per-file-ignores]
"**/tests/*" = ["S101"] # Allow using asserts in tests

[tool.ruff.isort]
[tool.ruff.lint.isort]
known-first-party = ["aeromancy"]

[tool.mypy]
Expand Down
35 changes: 24 additions & 11 deletions src/aeromancy/action.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@ class Action:
"""A specific piece of work to track.
This includes the code to run, artifacts it depends on, and artifacts it
produces. For organizational purposes, they can fill in class variables:
produces. For organizational purposes, subclasses can fill in class
variables:
- `job_type`
- `job_group`
Expand Down Expand Up @@ -38,6 +39,7 @@ def __init__(self, parents: list["Action"], **config):
self.parents = parents
self._tracker_class = WandbTracker
self._project_name = None
self._tags = None

def outputs(self) -> list[str]:
"""Describe what this `Action` will produce after being run.
Expand Down Expand Up @@ -73,22 +75,17 @@ def _run(self) -> None:
f"Must set project_name on your Action class: {self.__class__}",
)

# TODO: Exceptions that happen in this period can get squelched and
# tough to debug. We should check for bad interactions with pydoit.
with self._tracker_class(
job_type=self.job_type,
job_group=self.job_group,
config=self.config,
project_name=self._project_name,
tags=self._tags,
) as tracker:
self.run(tracker)

def _set_tracker(self, tracker_class: type[Tracker]) -> None:
"""Set a different class to use for tracking.
This should only be called under special circumstances (e.g., testing
environments, offline mode).
"""
self._tracker_class = tracker_class

def get_io(self, resolve_outputs=False) -> tuple[list[str], list[str]]:
"""Get inputs and outputs for this `Action`.
Expand Down Expand Up @@ -122,9 +119,25 @@ def get_io(self, resolve_outputs=False) -> tuple[list[str], list[str]]:
]
return (full_inputs, full_outputs)

def _set_runtime_properties(self, project_name: str, skip: bool):
"""Set properties that we won't know until we're ready to run."""
def _set_tracker(self, tracker_class: type[Tracker]) -> None:
    """Set a different class to use for tracking.

    This should only be called under special circumstances (e.g., testing
    environments, offline mode).

    Parameters
    ----------
    tracker_class
        `Tracker` class to instantiate (as a context manager) when this
        `Action` is run, replacing the `WandbTracker` default assigned in
        `__init__`.
    """
    # TODO: With some refactoring, this could be combined with
    # _set_buildtime_properties
    self._tracker_class = tracker_class

def _set_buildtime_properties(self, project_name: str, skip: bool):
"""Set properties that we won't know until ActionBuilder time."""
self._skip = skip
self._project_name = project_name

def _set_runtime_properties(self, tags: set[str] | None = None):
"""Set properties that we won't know until ActionRunner time."""
# TODO: With some refactoring, this could be combined with
# _set_buildtime_properties
self._tags = tags

skip = property(lambda self: self._skip, doc="Whether this action should be run")
2 changes: 1 addition & 1 deletion src/aeromancy/action_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def add_action(
-------
The `Action` passed as `action`, with additional run state added
"""
action._set_runtime_properties(self._project_name, skip=skip)
action._set_buildtime_properties(self._project_name, skip=skip)
actions.append(action)
return action

Expand Down
17 changes: 14 additions & 3 deletions src/aeromancy/action_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ def __init__(self, actions: list[Action]):
"""
self.actions = actions
self.job_name_filter = None
self.job_tags = set()

@override
def load_doit_config(self):
Expand All @@ -106,7 +107,12 @@ def load_doit_config(self):

@override
def load_tasks(self, **unused) -> list[DoitTask]:
return [self._convert_action_to_doittask(action) for action in self.actions]
tasks = []
for action in self.actions:
action._set_runtime_properties(tags=self.job_tags)
tasks.append(self._convert_action_to_doittask(action))

return tasks

def _convert_action_to_doittask(
self,
Expand Down Expand Up @@ -179,9 +185,10 @@ def _list_actions(self):

def run_actions(
self,
only: str | None,
only: set[str] | None,
graph: bool,
list_actions: bool,
tags: set[str] | None,
**unused_kwargs,
):
"""Run the stored `Action`s using pydoit.
Expand All @@ -196,14 +203,16 @@ def run_actions(
If True, show the action dependency graph and exit.
list_actions
If True, show a list of action names and exit.
tags
If set, the tags to apply to all jobs launched.
unused_kwargs
Should not be used -- this is here as part of some Click hackery to
show all options in the help menu.
"""
if only:

def job_name_filter(job_name):
for job_name_substring in only.split(","):
for job_name_substring in only:
if job_name_substring.strip() in job_name:
return True
return False
Expand All @@ -218,6 +227,8 @@ def job_name_filter(job_name):
self._list_actions()
raise SystemExit

self.job_tags = tags

if get_runtime_environment().dev_mode:
console.rule(
"[red bold]DEV MODE[/red bold]",
Expand Down
1 change: 1 addition & 0 deletions src/aeromancy/aeroview.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
shell> pdm aeroview <Weights and Biases artifact name>
"""

import subprocess
from pathlib import Path

Expand Down
5 changes: 1 addition & 4 deletions src/aeromancy/artifacts.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,10 +192,7 @@ def matches(self, other_artifact_name: "WandbArtifactName") -> bool:
):
return False

if self.artifact_name != other_artifact_name.artifact_name:
return False

return True
return self.artifact_name == other_artifact_name.artifact_name

def incorporate_overrides(self):
"""Incorporate artifact version overrides.
Expand Down
42 changes: 34 additions & 8 deletions src/aeromancy/click_options.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Groups of Click options for the main Aeromancy CLI interface."""

import functools
import shlex

import rich_click as click

Expand All @@ -9,7 +11,12 @@
"main.py": [
{
"name": "Task runner options",
"options": ["--only", "--graph", "--list-actions"],
"options": [
"--only",
"--graph",
"--list-actions",
"--tags",
],
},
{
"name": "Aeromancy runtime options",
Expand All @@ -27,6 +34,14 @@
}


def csv_string_to_set(ctx, param, value) -> set[str] | None:
"""Parse a CSV string into a set of strings."""
if value is not None:
return {piece.strip() for piece in value.split(",")}

return None


def runner_click_options(function):
"""Wrap `function` with all Click options for Aeromancy runtime."""
# NOTE: Keep in sync with OPTION_GROUPS.
Expand Down Expand Up @@ -58,14 +73,16 @@ def runner_click_options(function):
"mount). You should generally not need to change this, but it may be "
"set in pdm scripts when setting up a project."
),
# Parse a string into a list honoring shell-style quoting.
callback=lambda ctx, param, value: shlex.split(value),
)
@click.option(
"--extra-debian-package",
"extra_debian_packages",
metavar="PKGS",
metavar="PKG",
multiple=True,
help=(
"Names of Debian packages to include in the Docker image in addition to "
"Name of a Debian package to include in the Docker image in addition to "
"standard packages required by Aeromancy. Specify this option once per "
"extra package. You should generally not need to change this, but it may "
"be set in pdm scripts when setting up a project."
Expand All @@ -74,10 +91,10 @@ def runner_click_options(function):
@click.option(
"--extra-env-var",
"extra_env_vars",
metavar="VARS",
metavar="VAR",
multiple=True,
help=(
"Extra environment variables to passthrough to Aeromancy. Specify this "
"Extra environment variable to passthrough to Aeromancy. Specify this "
"option once per variable. You should generally not need to change this, "
"but it may be set in pdm scripts when setting up a project."
),
Expand All @@ -99,8 +116,8 @@ def runner_click_options(function):
"aeromain_path",
default="src/main.py",
metavar="PATH",
# To minimize confusion since this is only intended for
# debugging/testing Aeromancy itself.
# This is only intended for debugging/testing Aeromancy itself so hidden
# to minimize confusion.
hidden=True,
help="Set an alternate Aeromain file to run.",
)
Expand All @@ -125,7 +142,8 @@ def aeromancy_click_options(function):
type=str,
metavar="SUBSTRS",
help="If set: comma-separated list of substrings. We'll only run jobs which "
"match at least one of these (in dependency order).",
"match at least one of these.",
callback=csv_string_to_set,
)
@click.option(
"--graph",
Expand All @@ -138,6 +156,14 @@ def aeromancy_click_options(function):
is_flag=True,
help="If set: show a list of all job names and exit.",
)
@click.option(
"--tags",
"tags",
metavar="TAGS",
help="Comma-separated tags to add to each task launched. These tags are purely "
"for organizational purposes.",
callback=csv_string_to_set,
)
@runner_click_options
@functools.wraps(function)
def wrapper_aeromancy_options(*args, **kwargs):
Expand Down
1 change: 1 addition & 0 deletions src/aeromancy/export_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
Warning: Here be dragons (*cough* hacks). Code in this module relies on
Aeromancy internals and is subject to change.
"""

import os
from collections.abc import Sequence
from pathlib import Path
Expand Down
2 changes: 2 additions & 0 deletions src/aeromancy/fake_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,13 +84,15 @@ def __init__(
config: dict | None = None,
job_type: str | None = None,
job_group: str | None = None,
tags: set[str] | None = None,
):
Tracker.__init__(
self,
project_name=project_name,
config=config,
job_type=job_type,
job_group=job_group,
tags=tags,
)

self.cache_root_path = Path("~/FakeCache").expanduser().resolve()
Expand Down
1 change: 1 addition & 0 deletions src/aeromancy/rerun.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
After it creates a Git repo for that run, it will provide instructions for how to rerun.
"""

import subprocess
import tempfile
from dataclasses import dataclass
Expand Down
Loading

0 comments on commit 082d9d6

Please sign in to comment.