From 9004599d02d97f32cf3a6ceb4cf164a175fce16b Mon Sep 17 00:00:00 2001 From: Jelmer Date: Wed, 4 Sep 2024 09:28:27 +0200 Subject: [PATCH] Move sample command to opuscleaner-datasets (#162) --- opuscleaner/datasets.py | 112 ++++++++++++++++++++++++++++++++++++---- opuscleaner/server.py | 99 ++++------------------------------- pyproject.toml | 1 + 3 files changed, 114 insertions(+), 98 deletions(-) diff --git a/opuscleaner/datasets.py b/opuscleaner/datasets.py index b639b02..be91559 100644 --- a/opuscleaner/datasets.py +++ b/opuscleaner/datasets.py @@ -1,11 +1,17 @@ #!/usr/bin/env python3 """Lists datasets given a directory. It works by scanning the directory and looking for gz files.""" +import asyncio +import os +import pprint +import sys from glob import glob from itertools import groupby -from pathlib import Path as Path -from typing import Dict, List, Tuple +from pathlib import Path +from shutil import copyfileobj +from tempfile import TemporaryFile +from typing import Dict, List, Tuple, Iterable -from opuscleaner.config import DATA_PATH +from opuscleaner.config import DATA_PATH, SAMPLE_PY, SAMPLE_SIZE def list_datasets(path:str) -> Dict[str,List[Tuple[str,Path]]]: @@ -39,13 +45,101 @@ def list_datasets(path:str) -> Dict[str,List[Tuple[str,Path]]]: } -def main() -> None: - import sys - import pprint - if len(sys.argv) == 1: - pprint.pprint(list_datasets(DATA_PATH)) +def dataset_path(name:str, template:str) -> str: + # TODO: fix this hack to get the file path from the name this is silly we + # should just use get_dataset(name).path or something + root = DATA_PATH.split('*')[0] + + # If the dataset name is a subdirectory, do some hacky shit to get to a + # .sample.gz file in said subdirectory. + parts = name.rsplit('/', maxsplit=2) + if len(parts) == 2: + root = os.path.join(root, parts[0]) + filename = parts[1] else: - pprint.pprint(list_datasets(sys.argv[1])) + filename = parts[0] + + return os.path.join(root, template.format(filename)) + + +def filter_configuration_path(name:str) -> str: + return dataset_path(name, '{}.filters.json') + + +def sample_path(name:str, langs:Iterable[str]) -> str: + languages = '.'.join(sorted(langs)) + return dataset_path(name, f'.sample.{{}}.{languages}') + + +def main_list_commands(args): + print("Error: No command specified.\n\n" + "Available commands:\n" + " list list datasets\n" + " sample sample all datasets\n" + "", file=sys.stderr) + sys.exit(1) + + +async def sample_all_datasets(args): + tasks = [] + + for name, columns in list_datasets(DATA_PATH).items(): + langs = [lang for lang, _ in columns] + if not os.path.exists(sample_path(name, langs)) or args.force: + print(f"Sampling {name}...", file=sys.stderr) + tasks.append([name, columns]) + + for task, result in zip(tasks, await asyncio.gather(*[compute_sample(*task) for task in tasks], return_exceptions=True)): + if isinstance(result, Exception): + print(f"Could not compute sample for {task[0]}: {result!s}", file=sys.stderr) + + +async def compute_sample(name:str, columns:List[Tuple[str,Path]]) -> None: + langs = [lang for lang, _ in columns] + with TemporaryFile() as tempfile: + proc = await asyncio.subprocess.create_subprocess_exec( + *SAMPLE_PY, + '-n', str(SAMPLE_SIZE), + *[str(file.resolve()) for _, file in columns], + stdout=tempfile, + stderr=asyncio.subprocess.PIPE) + + _, stderr = await proc.communicate() + + if proc.returncode != 0: + raise Exception(f'sample.py failed with exit code {proc.returncode}: {stderr.decode()}') + + tempfile.seek(0) + + with open(sample_path(name, langs), 'wb') as fdest: + copyfileobj(tempfile, fdest) + + +def main_list(args): + pprint.pprint(list_datasets(args.path)) + + +def main_sample(args): + asyncio.run(sample_all_datasets(args)) + + +def main(argv=sys.argv): + import argparse + + parser = argparse.ArgumentParser(description='Fill up those seats on your empty train.') + parser.set_defaults(func=main_list_commands) + subparsers = parser.add_subparsers() + + parser_serve = subparsers.add_parser('list') + parser_serve.add_argument('path', nargs="?", type=str, default=DATA_PATH) + parser_serve.set_defaults(func=main_list) + + parser_sample = subparsers.add_parser('sample') + parser_sample.add_argument("--force", "-f", action="store_true") + parser_sample.set_defaults(func=main_sample) + + args = parser.parse_args() + args.func(args) if __name__ == '__main__': diff --git a/opuscleaner/server.py b/opuscleaner/server.py index bdaa884..f58ac3c 100644 --- a/opuscleaner/server.py +++ b/opuscleaner/server.py @@ -12,9 +12,8 @@ from enum import Enum from glob import glob from itertools import chain, zip_longest +from pathlib import Path from pprint import pprint -from shutil import copyfileobj -from tempfile import TemporaryFile from typing import NamedTuple, Optional, Iterable, TypeVar, Union, Literal, Any, AsyncIterator, Awaitable, cast, IO, List, Dict, Tuple, AsyncIterator from warnings import warn @@ -32,7 +31,7 @@ from opuscleaner._util import none_throws from opuscleaner.categories import app as categories_app from opuscleaner.config import DATA_PATH, FILTER_PATH, COL_PY, SAMPLE_PY, SAMPLE_SIZE -from opuscleaner.datasets import list_datasets, Path +from opuscleaner.datasets import list_datasets, dataset_path, sample_path, filter_configuration_path, compute_sample from opuscleaner.download import app as download_app from opuscleaner.filters import filter_format_command, format_shell, get_global_filter, get_global_filters, set_global_filters, list_filters, FilterType, FilterStep, FilterPipeline from opuscleaner.sample import sample @@ -67,53 +66,6 @@ class FilterPipelinePatch(BaseModel): filters: List[FilterStep] -def dataset_path(name:str, template:str) -> str: - # TODO: fix this hack to get the file path from the name this is silly we - # should just use get_dataset(name).path or something - root = DATA_PATH.split('*')[0] - - # If the dataset name is a subdirectory, do some hacky shit to get to a - # .sample.gz file in said subdirectory. - parts = name.rsplit('/', maxsplit=2) - if len(parts) == 2: - root = os.path.join(root, parts[0]) - filename = parts[1] - else: - filename = parts[0] - - return os.path.join(root, template.format(filename)) - - -def sample_path(name:str, langs:Iterable[str]) -> str: - languages = '.'.join(sorted(langs)) - return dataset_path(name, f'.sample.{{}}.{languages}') - - -def filter_configuration_path(name:str) -> str: - return dataset_path(name, '{}.filters.json') - - -async def compute_sample(name:str, columns:List[Tuple[str,Path]]) -> None: - langs = [lang for lang, _ in columns] - with TemporaryFile() as tempfile: - proc = await asyncio.subprocess.create_subprocess_exec( - *SAMPLE_PY, - '-n', str(SAMPLE_SIZE), - *[str(file.resolve()) for _, file in columns], - stdout=tempfile, - stderr=asyncio.subprocess.PIPE) - - _, stderr = await proc.communicate() - - if proc.returncode != 0: - raise Exception(f'sample.py returned {proc.returncode}: {stderr.decode()}') - - tempfile.seek(0) - - with open(sample_path(name, langs), 'wb') as fdest: - copyfileobj(tempfile, fdest) - - class FilterOutput(NamedTuple): langs: List[str] # order of columns returncode: int @@ -173,7 +125,7 @@ async def get_dataset_sample(name:str, columns:List[Tuple[str,Path]]) -> FilterO return FilterOutput([lang for lang, _ in columns], 0, stdout, bytes()) -async def exec_filter_step(filter_step: FilterStep, langs: List[str], input: bytes) -> Tuple[bytes,bytes]: +async def exec_filter_step(filter_step: FilterStep, langs: List[str], input: bytes) -> FilterOutput: filter_definition = get_global_filter(filter_step.filter) command = filter_format_command(filter_definition, filter_step, langs) @@ -198,6 +150,7 @@ async def exec_filter_step(filter_step: FilterStep, langs: List[str], input: byt # Check exit codes, testing most obvious problems first. stdout, stderr = await p_filter.communicate(input=input) + assert p_filter.returncode is not None return FilterOutput(langs, p_filter.returncode, stdout, stderr) @@ -472,56 +425,24 @@ def redirect_to_interface(): app.mount('/api/categories/', categories_app) + def main_serve(args): import uvicorn uvicorn.run(f'opuscleaner.server:app', host=args.host, port=args.port, reload=args.reload, log_level='info') -async def sample_all_datasets(args): - tasks = [] - - for name, columns in list_datasets(DATA_PATH).items(): - langs = [lang for lang, _ in columns] - if not os.path.exists(sample_path(name, langs)): - print(f"Sampling {name}...", file=sys.stderr) - tasks.append([name, columns]) - - for task, result in zip(tasks, await asyncio.gather(*[compute_sample(*task) for task in tasks], return_exceptions=True)): - if isinstance(result, Exception): - print(f"Could not compute sample for {task[0]}: {result!s}", file=sys.stderr) - - -def main_sample(args): - asyncio.run(sample_all_datasets(args)) - - -def main_list_commands(args): - print("Error: No command specified.\n\n" - "Available commands:\n" - " serve run webserver\n" - " sample sample all datasets\n" - "", file=sys.stderr) - sys.exit(1) - - def main(argv=sys.argv): import argparse parser = argparse.ArgumentParser(description='Fill up those seats on your empty train.') - parser.set_defaults(func=main_list_commands) - subparsers = parser.add_subparsers() - - parser_serve = subparsers.add_parser('serve') - parser_serve.add_argument('--host', type=str, default='127.0.0.1', help='Bind socket to this host. (default: 127.0.0.1)') - parser_serve.add_argument('-p', '--port', type=int, default=8000, help='Bind socket to this port. (default: 8000)') - parser_serve.add_argument('--reload', action='store_true', help='Enable auto-reload.') - parser_serve.set_defaults(func=main_serve) - - parser_sample = subparsers.add_parser('sample') - parser_sample.set_defaults(func=main_sample) + parser.add_argument('--host', type=str, default='127.0.0.1', help='Bind socket to this host. (default: 127.0.0.1)') + parser.add_argument('-p', '--port', type=int, default=8000, help='Bind socket to this port. (default: 8000)') + parser.add_argument('--reload', action='store_true', help='Enable auto-reload.') + parser.set_defaults(func=main_serve) args = parser.parse_args() args.func(args) + if __name__ == '__main__': main() diff --git a/pyproject.toml b/pyproject.toml index 3b5c5b7..2e36cbe 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,6 +27,7 @@ opuscleaner-clean = "opuscleaner.clean:main" opuscleaner-col = "opuscleaner.col:main" opuscleaner-threshold = "opuscleaner.threshold:main" opuscleaner-sample = "opuscleaner.sample:main" +opuscleaner-datasets = "opuscleaner.datasets:main" opuscleaner-download = "opuscleaner.download:main" [project.urls]