-
Notifications
You must be signed in to change notification settings - Fork 191
/
service.py
173 lines (142 loc) · 5.76 KB
/
service.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
import datetime
import importlib
import inspect
import json
import logging
import os
from typing import Dict
from urllib.parse import parse_qs
from fastapi import FastAPI, HTTPException, Request
from jsonargparse import Namespace
from pydantic import validate_call
from data_juicer.core.exporter import Exporter
from data_juicer.format.load import load_formatter
# Root folder under which timestamped export paths are built.
DJ_OUTPUT = 'outputs'

# Method names that may be exposed as endpoints when registering a class.
allowed_methods = {
    'run',
    'process',
    'compute_stats',
    'compute_hash',
    'analyze',
    'compute',
}

# Messages from this module go through the logger named 'uvicorn.error'.
logger = logging.getLogger('uvicorn.error')

# Application object that receives every generated route.
app = FastAPI()
def register_objects_from_init(directory: str):
    """
    Walk `directory` and register whatever each package exports.

    Every folder holding an ``__init__.py`` is imported as a module, using
    its path with the separators replaced by dots. Each name listed in that
    module's ``__all__`` is passed to ``register_class`` when it is a class,
    or to ``register_function`` when it is any other callable; names that
    are neither are ignored.
    """
    root = os.path.normpath(directory)
    for pkg_dir, _subdirs, files in os.walk(root):
        # Only folders that are packages are imported.
        if '__init__.py' not in files:
            continue

        dotted_name = pkg_dir.replace(os.sep, '.')
        module = importlib.import_module(dotted_name)

        # Modules without an explicit export list contribute nothing.
        if not hasattr(module, '__all__'):
            continue

        for exported in module.__all__:
            target = getattr(module, exported)
            if inspect.isclass(target):
                register_class(module, target)
            elif callable(target):
                register_function(module, target)
def register_class(module, cls):
    """Register a class's allowed public methods as POST endpoints.

    One route ``/<module path>/<class name>/<method name>`` is added to
    ``app`` for every public method of ``cls`` whose name is listed in
    ``allowed_methods``. On each request the JSON body (if any) provides
    the constructor arguments and the query string provides the method
    arguments. Any exception raised while serving a request is reported
    as an HTTP 500 whose detail is the exception text.

    Args:
        module: Imported module exporting ``cls``; its dotted name forms
            the leading part of the route.
        cls: Class whose methods are exposed.
    """
    # Marker kept on the class itself so that every route generated for it
    # (and any repeated registration) shares the same "already wrapped"
    # state.
    validated_flag = '_dj_service_init_validated'

    def create_class_call(cls, method_name: str):

        async def class_call(request: Request):
            try:
                # Wrap ``__init__`` with validation only once per class:
                # re-wrapping on every request would pile one more
                # validation layer on top of the previous ones each time.
                if not vars(cls).get(validated_flag, False):
                    cls.__init__ = validate_call(
                        cls.__init__,
                        config=dict(arbitrary_types_allowed=True))
                    setattr(cls, validated_flag, True)
                # parse json body as cls init args
                if await request.body():
                    init_args = await request.json()
                else:
                    init_args = {}
                # create an instance
                instance = cls(**_setup_cfg(init_args))
                # wrap called method
                method = validate_call(getattr(instance, method_name))
                result = _invoke(method, request)
                return {'status': 'success', 'result': result}
            except Exception as e:
                # Keep the original exception attached as the cause.
                raise HTTPException(status_code=500,
                                    detail=str(e)) from e

        return class_call

    # Routes are URL paths, so the separator is always '/', independent of
    # the platform's filesystem separator.
    module_path = module.__name__.replace('.', '/')
    cls_name = cls.__name__
    for method_name in _get_public_methods(cls, allowed_methods):
        api_path = f'/{module_path}/{cls_name}/{method_name}'
        class_call = create_class_call(cls, method_name)
        app.add_api_route(api_path,
                          class_call,
                          methods=['POST'],
                          tags=['POST'])
        logger.debug('Registered %s', api_path)
def register_function(module, func):
    """Register a function as a GET endpoint.

    The route ``/<module path>/<function name>`` is added to ``app``. The
    request's query string provides the call arguments. Any exception
    raised while serving a request is reported as an HTTP 500 whose detail
    is the exception text.

    Args:
        module: Imported module exporting ``func``; its dotted name forms
            the leading part of the route.
        func: Callable to expose.
    """

    def create_func_call(func):
        # Validating wrapper around ``func``; built on the first request
        # and reused afterwards instead of wrapping the previous wrapper
        # again on every call.
        validated = None

        async def func_call(request: Request):
            nonlocal validated
            try:
                if validated is None:
                    validated = validate_call(
                        func, config=dict(arbitrary_types_allowed=True))
                result = _invoke(validated, request)
                return {'status': 'success', 'result': result}
            except Exception as e:
                # Keep the original exception attached as the cause.
                raise HTTPException(status_code=500,
                                    detail=str(e)) from e

        return func_call

    # Routes are URL paths, so the separator is always '/', independent of
    # the platform's filesystem separator.
    module_path = module.__name__.replace('.', '/')
    func_name = func.__name__
    api_path = f'/{module_path}/{func_name}'
    func_call = create_func_call(func)
    app.add_api_route(api_path, func_call, methods=['GET'], tags=['GET'])
    logger.debug('Registered %s', api_path)
def _invoke(callable, request):
    """Call `callable` with arguments taken from the request query string.

    Query parameters become keyword arguments: a key given once maps to
    its single string value, a key given several times maps to the list of
    its values. The arguments are then passed through ``_setup_cfg`` and
    ``_setup_dataset`` before the call.

    The ``skip_return`` query parameter is consumed here and never passed
    to `callable`. When it is set, an empty string is returned instead of
    the result. Because query values are text, the values ``''``, ``0``,
    ``false``, ``no`` and ``off`` (case-insensitive) count as "not set".

    Args:
        callable: The function or bound method to invoke. (The parameter
            name shadows the builtin; it is kept for compatibility.)
        request: Incoming request; only ``request.url.query`` is read.

    Returns:
        The callable's result; or the exporter's ``export_path`` when
        ``_setup_dataset`` supplied an exporter; or ``''`` when
        ``skip_return`` is set.
    """
    # parse query params as cls method args
    q_params = parse_qs(request.url.query, keep_blank_values=True)
    # flatten lists with a single element
    d_params = {k: v if len(v) > 1 else v[0] for k, v in q_params.items()}

    # pre-processing
    d_params = _setup_cfg(d_params)
    exporter = _setup_dataset(d_params)
    skip_return = d_params.pop('skip_return', False)
    if isinstance(skip_return, str):
        # A plain truthiness test would treat "false" or "0" as enabled.
        skip_return = skip_return.strip().lower() not in (
            '', '0', 'false', 'no', 'off')

    # invoke callable
    result = callable(**d_params)

    # post-processing
    if exporter is not None:
        exporter.export(result)
        result = exporter.export_path
    if skip_return:
        result = ''
    return result
def _setup_cfg(params: Dict):
"""convert string `cfg` to Namespace"""
# TODO: Traverse method's signature and convert any arguments \
# that should be Namespace but are passed as str
if cfg_str := params.get('cfg'):
if isinstance(cfg_str, str):
cfg = Namespace(**json.loads(cfg_str))
params['cfg'] = cfg
return params
def _setup_dataset(params: Dict):
"""setup dataset loading and exporting"""
exporter = None
if dataset_path := params.get('dataset'):
if isinstance(dataset_path, str):
dataset = load_formatter(dataset_path).load_dataset()
params['dataset'] = dataset
timestamp = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
export_path = os.path.join(DJ_OUTPUT, timestamp,
'processed_data.jsonl')
exporter = Exporter(export_path,
keep_stats_in_res_ds=True,
keep_hashes_in_res_ds=True,
export_stats=False)
return exporter
def _get_public_methods(cls, allowed=None):
"""Get public methods of a class."""
all_methods = inspect.getmembers(cls, predicate=inspect.isfunction)
return [
name for name, _ in all_methods
if not name.startswith('_') and (allowed is None or name in allowed)
]
# Package directories whose exported objects are turned into endpoints
directories_to_search = [
    'data_juicer',
    # "tools",  # Uncomment to add more directories
]

# Walk each configured directory and register what it exports
for directory in directories_to_search:
    register_objects_from_init(directory)