-
Notifications
You must be signed in to change notification settings - Fork 0
/
run_dc_task.py
179 lines (146 loc) · 6.27 KB
/
run_dc_task.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
#! /usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Main program
Created on Sep 06 2022
@author: Jesús Cid-Sueiro
"""
import pathlib
import argparse
import inspect
import yaml
# Local imports
from src.menu_navigator.menu_navigator import MenuNavigator
from src.task_manager import TaskManagerIMT # , TaskManagerCMD
# ########################
# Main body of application
def main():
# ################
# Get menu options
# Get the menu options from a options_menu file.
path2menu = pathlib.Path('config', 'options_menu.yaml')
# This is a fake menu because it is used only to get the list of options
options = MenuNavigator(None, path2menu).get_options(tasks_only=True)
# option_names = [x[0] for x in options]
options_txt = "\n".join([f' - {x[0]}: {x[1]}' for x in options])
# ####################
# Read input arguments
# Read input arguments
parser = argparse.ArgumentParser(
formatter_class=argparse.RawTextHelpFormatter)
parser.add_argument(
'--task', required=True,
help=f"Command/task to be executed. It must be one of the following:"
f"\n{options_txt}")
parser.add_argument(
'--p', required=True,
help="path to a new or an existing project")
parser.add_argument(
'--source', default="../datasets",
help="path to the source data folder")
parser.add_argument(
'--zeroshot', default='../zero_shot_model/Sciro-Shot',
help="path to the zero-shot model folder")
parser.add_argument(
'--class_name',
help="Name of the labeled dataset")
parser.add_argument(
'-i', '--ignore_IMT_root_path', action='store_true',
help="Use the project path as it is, ignoring the imt root")
args, other_args = parser.parse_known_args()
# ################################################
# Read input arguments for the task manager method
# This if is used to avoid the cases which would re-read parameter
# class_name, because it has been already read, raising an error
if args.task not in {'load_labels', 'reset_labels'}:
params = inspect.getfullargspec(getattr(TaskManagerIMT, args.task))
n_params = len(params.args)
n_defaults = 0 if params.defaults is None else len(params.defaults)
# Index starts from 1 because argument 0 is 'self'
arg_names = params.args[1:n_params - n_defaults]
kwarg_names = params.args[n_params - n_defaults:]
default_values = [] if n_defaults == 0 else params.defaults
for arg in arg_names:
arg_type = str # Default
if arg in params.annotations:
arg_type = params.annotations[arg]
parser.add_argument(f'--{arg}', type=arg_type, required=True)
for arg, value in zip(kwarg_names, default_values):
arg_type = str # Default
if arg in params.annotations:
arg_type = params.annotations[arg]
parser.add_argument(f'--{arg}', type=arg_type, default=value)
args = parser.parse_args()
# Create task manager object
if args.ignore_IMT_root_path:
project_path = pathlib.Path(args.p)
else:
# This is used for the IMT: the project path is rooted into a default
# project path taken from the parameters.default.yaml.
# FIXME: This must be changed, because I do not think it is a good idea
# to take parameters from a file of default values.
with open('config/parameters.default.yaml', 'r', encoding='utf8') as f:
parameter_default = yaml.safe_load(f)
project_path = pathlib.Path(
parameter_default['project_folder_path']) / pathlib.Path(args.p)
tm = TaskManagerIMT(project_path, path2source=args.source,
path2zeroshot=args.zeroshot)
# #####################
# Run preparation tasks
# If the task is load or create, we simply need to run it.
if args.task not in {'create', 'load'}:
# Load or create project
if project_path.is_dir():
tm.load()
else:
tm.create()
# Load labels if the task requires it
options_needing_labels = {
'load_labels', 'evaluate_PUlabels', 'train_PUmodel',
'evaluate_PUmodel', 'performance_metrics_PU',
'performance_metrics_PN', 'get_feedback', 'sample_documents',
'get_labels_from_docs', 'annotate', 'retrain_model',
'reevaluate_model', 'import_annotations', 'export_annotations',
# Options added for the IMT:
'inference', 'on_retrain', 'on_classify', 'on_evaluate',
'on_sample', 'on_save_feedback'}
option = args.task
if option in options_needing_labels:
if args.class_name is not None:
tm.load_labels(args.class_name)
else:
raise TypeError(
f"Task {args.task} requires argument --class_name")
# ########
# Run task
if args.task not in {'load_labels', 'reset_labels'}:
# Get args
arg_values = [getattr(args, name) for name in arg_names]
# Get kwargs
# - default values
kwargs = dict(zip(kwarg_names, default_values))
# - user-defined values
for arg in kwarg_names:
if arg in args:
kwargs[arg] = getattr(args, arg)
# Run task
getattr(tm, args.task)(*arg_values, **kwargs)
# If the task is reset_labels run it with the class name.
elif args.task == 'reset_labels':
tm.reset_labels(args.class_name)
# The case load_labels is ignored, because that task has been already done.
# TEST PENDING:
# - get_feedback: Get relevance feedback from user
# - evaluate_PUlabels: Evaluate subcorpus with respect to a gold standard
# - evaluate_PUmodel: Evaluate PU classifier model with the available labs
# - retrain_model: Retrain model with manual annotations
# - reevaluate_model: Evaluate retrained model.
# - performance_metrics_PN: Show all performance metrics
# - import_annotations: Import annotations (overwrites existing annots)
# - export_annotations: Export annotations (delete older annot files)
print("\n*** END.\n")
return
# ############
# Execute main
if __name__ == '__main__':
main()