Skip to content

Commit

Permalink
Add client class and extend CLI
Browse files Browse the repository at this point in the history
  • Loading branch information
fritzm committed Oct 17, 2023
1 parent e03f5d5 commit 9b5b987
Show file tree
Hide file tree
Showing 5 changed files with 290 additions and 27 deletions.
156 changes: 140 additions & 16 deletions src/lsst/cmservice/cli/commands.py
Original file line number Diff line number Diff line change
@@ -1,34 +1,158 @@
import json
from dataclasses import dataclass
from typing import Generator, Iterable, TypeVar

import click
import structlog
import uvicorn
import yaml
from safir.asyncio import run_with_asyncio
from safir.database import create_database_engine, initialize_database
from tabulate import tabulate

from .. import models
from ..client import CMClient
from ..config import config
from ..db import Base
from . import options

T = TypeVar("T")


@click.group(context_settings={"help_option_names": ["-h", "--help"]})
@click.group()
@click.version_option(package_name="lsst-cm-service")
def main() -> None:
"""Administrative command-line interface for cm-service."""


@main.command(name="help")
@click.argument("topic", default=None, required=False, nargs=1)
@click.pass_context
def help_command(ctx: click.Context, topic: str | None) -> None:
"""Show help for any command."""
# The help command implementation is taken from
# https://www.burgundywall.com/post/having-click-help-subcommand
if topic:
if topic in main.commands:
click.echo(main.commands[topic].get_help(ctx))
else:
raise click.UsageError(f"Unknown help topic {topic}", ctx)
else:
assert ctx.parent
click.echo(ctx.parent.get_help())
@main.group()
def get() -> None:
"""Display one or many resources."""


@get.command()
@options.cmclient()
@options.output()
def productions(client: CMClient, output: options.OutputEnum | None) -> None:
"""Display one or more productions."""
productions = client.get_productions()
match output:
case options.OutputEnum.json:
jtable = [p.dict() for p in productions]
click.echo(json.dumps(jtable, indent=4))
case options.OutputEnum.yaml:
ytable = [p.dict() for p in productions]
click.echo(yaml.dump(ytable))
case _:
ptable = [[p.name, p.id] for p in productions]
click.echo(tabulate(ptable, headers=["NAME", "ID"], tablefmt="plain"))


@get.command()
@options.cmclient()
@options.output()
def campaigns(client: CMClient, output: options.OutputEnum | None) -> None:
"""Display one or more campaigns."""
campaigns = client.get_campaigns()
match output:
case options.OutputEnum.json:
jtable = [c.dict() for c in campaigns]
click.echo(json.dumps(jtable, indent=4))
case options.OutputEnum.yaml:
ytable = [c.dict() for c in campaigns]
click.echo(yaml.dump(ytable))
case _:
productions = client.get_productions()
pbyid = {p.id: p.name for p in productions}
ctable = [[c.name, pbyid.get(c.production, ""), c.id] for c in campaigns]
click.echo(tabulate(ctable, headers=["NAME", "CAMPAIGN", "ID"], tablefmt="plain"))


@main.group()
def create() -> None:
"""Create a resource."""


@main.group()
def apply() -> None:
"""Apply configuration to a resource."""


@main.group()
def delete() -> None:
"""Delete a resource."""


def _lookahead(iterable: Iterable[T]) -> Generator[tuple[T, bool], None, None]:
"""A generator which returns all elements of the provided iteratable as
tuples with an additional `bool`; the `bool` will be `True` on the last
element and `False` otherwise."""
it = iter(iterable)
try:
last = next(it)
except StopIteration:
return
for val in it:
yield last, False
last = val
yield last, True


@dataclass
class Root:
name: str = "."


def _tree_prefix(last: list[bool]) -> str:
prefix = ""
for li, ll in _lookahead(last):
match (ll, li):
case (False, False):
prefix += "│ "
case (False, True):
prefix += " "
case (True, False):
prefix += "├── "
case (True, True): # pragma: no branch
prefix += "└── "
return prefix


def _tree_children(
client: CMClient,
node: Root | models.Production | models.Campaign | models.Step | models.Group,
) -> list:
match node:
case Root():
return client.get_productions()
case models.Production():
return client.get_campaigns(node.id)
case models.Campaign():
return client.get_steps(node.id)
case models.Step():
return client.get_groups(node.id)
case _:
return []


def _tree1(
client: CMClient,
node: Root | models.Production | models.Campaign | models.Step | models.Group,
last: list[bool],
) -> None:
click.echo(_tree_prefix(last) + node.name)
last.append(True)
for child, last[-1] in _lookahead(_tree_children(client, node)):
_tree1(client, child, last)
last.pop()


@main.command()
@options.cmclient()
@click.argument("path", required=False)
def tree(client: CMClient, path: str | None) -> None:
"""List resources recursively beneath PATH."""
_tree1(client, Root(), [])


@main.command()
Expand Down
66 changes: 66 additions & 0 deletions src/lsst/cmservice/cli/options.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
from enum import Enum, auto
from functools import partial
from typing import Any, Callable, Type

import click
from click.decorators import FC

from ..client import CMClient

__all__ = [
"cmclient",
"output",
"OutputEnum",
]


class EnumChoice(click.Choice):
"""A version of click.Choice specialized for enum types."""

def __init__(self, enum: Type[Enum], case_sensitive: bool = True) -> None:
self._enum = enum
super().__init__(list(enum.__members__.keys()), case_sensitive=case_sensitive)

def convert(self, value: Any, param: click.Parameter | None, ctx: click.Context | None) -> Enum:
converted_str = super().convert(value, param, ctx)
return self._enum.__members__[converted_str]


class PartialOption:
"""Wraps click.option decorator with partial arguments for convenient
reuse."""

def __init__(self, *param_decls: str, **attrs: Any) -> None:
self._partial = partial(click.option, *param_decls, cls=partial(click.Option), **attrs)

def __call__(self, *param_decls: str, **attrs: Any) -> Callable[[FC], FC]:
return self._partial(*param_decls, **attrs)


class OutputEnum(Enum):
yaml = auto()
json = auto()


output = PartialOption(
"--output",
"-o",
type=EnumChoice(OutputEnum),
help="Output format. Summary table if not specified.",
)


def make_client(ctx: click.Context, param: click.Parameter, value: Any) -> CMClient:
return CMClient(value)


cmclient = PartialOption(
"--server",
"client",
type=str,
default="http://localhost:8080/cm-service/v1",
envvar="CM_SERVICE",
show_envvar=True,
callback=make_client,
help="URL of cm service.",
)
49 changes: 49 additions & 0 deletions src/lsst/cmservice/client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import httpx
from pydantic import parse_obj_as

from . import models

__all__ = ["CMClient"]


class CMClient:
"""Interface for accessing remote cm-service"""

def __init__(self, url: str) -> None:
self._client = httpx.Client(base_url=url)

def get_productions(self) -> list[models.Production]:
skip = 0
productions = []
query = "productions?"
while (results := self._client.get(f"{query}skip={skip}").json()) != []:
productions.extend(parse_obj_as(list[models.Production], results))
skip += len(results)
return productions

def get_campaigns(self, production: int | None = None) -> list[models.Campaign]:
skip = 0
campaigns = []
query = f"campaigns?{f'production={production}&' if production else ''}"
while (results := self._client.get(f"{query}skip={skip}").json()) != []:
campaigns.extend(parse_obj_as(list[models.Campaign], results))
skip += len(results)
return campaigns

def get_steps(self, campaign: int | None = None) -> list[models.Step]:
skip = 0
steps = []
query = f"steps?{f'campaign={campaign}&' if campaign else ''}"
while (results := self._client.get(f"{query}skip={skip}").json()) != []:
steps.extend(parse_obj_as(list[models.Step], results))
skip += len(results)
return steps

def get_groups(self, step: int | None = None) -> list[models.Group]:
skip = 0
groups = []
query = f"groups?{f'step={step}&' if step else ''}"
while (results := self._client.get(f"{query}skip={skip}").json()) != []:
groups.extend(parse_obj_as(list[models.Group], results))
skip += len(results)
return groups
37 changes: 26 additions & 11 deletions tests/cli/test_commands.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,42 @@
from click.testing import CliRunner
from safir.testing.uvicorn import UvicornProcess

from lsst.cmservice.cli.commands import main
from lsst.cmservice.config import config


def test_commands() -> None:
runner = CliRunner()
def test_commands(uvicorn: UvicornProcess) -> None:
env = {"CM_SERVICE": f"{uvicorn.url}{config.prefix}"}
runner = CliRunner(env=env)

result = runner.invoke(main, ["--version"])
result = runner.invoke(main, "--version")
assert result.exit_code == 0
assert "version" in result.output

result = runner.invoke(main, ["--help"])
result = runner.invoke(main, "--help")
assert result.exit_code == 0
assert "Usage:" in result.output

result = runner.invoke(main, ["help"])
result = runner.invoke(main, "init")
assert result.exit_code == 0
assert "Usage:" in result.output

result = runner.invoke(main, ["help", "init"])
result = runner.invoke(main, "get productions")
assert result.exit_code == 0
assert "Usage:" in result.output

result = runner.invoke(main, ["help", "bogus"])
assert result.exit_code == 2
assert "Usage:" in result.output
result = runner.invoke(main, "get productions -o yaml")
assert result.exit_code == 0

result = runner.invoke(main, "get productions -o json")
assert result.exit_code == 0

result = runner.invoke(main, "get campaigns")
assert result.exit_code == 0

result = runner.invoke(main, "get campaigns -o yaml")
assert result.exit_code == 0

result = runner.invoke(main, "get campaigns -o json")
assert result.exit_code == 0

result = runner.invoke(main, "tree")
assert result.exit_code == 0
9 changes: 9 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from asyncio import AbstractEventLoop, get_event_loop_policy
from pathlib import Path
from typing import AsyncIterator, Iterator

import pytest
Expand All @@ -8,6 +9,7 @@
from fastapi import FastAPI
from httpx import AsyncClient
from safir.database import create_database_engine, initialize_database
from safir.testing.uvicorn import UvicornProcess, spawn_uvicorn
from sqlalchemy.ext.asyncio import AsyncEngine

from lsst.cmservice import main
Expand Down Expand Up @@ -49,3 +51,10 @@ async def client(app: FastAPI) -> AsyncIterator[AsyncClient]:
"""Return an ``httpx.AsyncClient`` configured to talk to the test app."""
async with AsyncClient(app=app, base_url="https:") as client:
yield client


@pytest_asyncio.fixture
async def uvicorn(tmp_path: Path) -> AsyncIterator[UvicornProcess]:
uvicorn = spawn_uvicorn(working_directory=tmp_path, app="lsst.cmservice.main:app")
yield uvicorn
uvicorn.process.terminate()

0 comments on commit 9b5b987

Please sign in to comment.