forked from dhlab-epfl/dhSegment
-
Notifications
You must be signed in to change notification settings - Fork 0
/
dh_segment_train.py
executable file
·145 lines (124 loc) · 6.97 KB
/
dh_segment_train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
#!/usr/bin/env python
"""Sacred experiment script that trains a dhSegment-text TensorFlow estimator.

Configuration defaults are declared in ``default_config``; ``run`` is the
experiment's main function.
"""
import json
import os
from logging import WARNING  # import DEBUG, INFO, ERROR for more/less verbosity

# The C++ log level has to be in the environment *before* TensorFlow is
# imported; setting it afterwards may be too late to silence native messages.
# '0' logs everything; '1', '2', '3' progressively hide INFO, WARNING, ERROR.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

import tensorflow as tf  # noqa: E402  (must follow the env-var assignment above)
from sacred import Experiment  # noqa: E402
from tqdm import trange  # noqa: E402

from dh_segment_text import estimator_fn, utils  # noqa: E402
from dh_segment_text.io import input  # noqa: E402  (shadows the builtin; used by run() as a module)

# Python-side TensorFlow logging level.
tf.logging.set_verbosity(WARNING)

try:
    # Optional; not referenced below, presumably imported only for its
    # traceback-formatting side effect.
    import better_exceptions  # noqa: F401
except ImportError:
    # Raw string: '\ ' is not a valid escape sequence and triggers a warning
    # on recent Pythons. The printed text is unchanged.
    print(r'/!\ W -- Not able to import package better_exceptions')

ex = Experiment('dhSegment_experiment')
@ex.config
def default_config():
    """Default Sacred configuration for a dhSegment-text training run.

    Sacred turns every local variable assigned in this body into a
    configuration entry, so no helper variables are introduced here.
    NOTE: keep the signature bare (no annotations) -- Sacred re-parses this
    function's source code.
    """
    train_data = None  # Directory with training data
    eval_data = None  # Directory with validation data
    model_output_dir = None  # Directory to output tf model
    restore_model = False  # Set to true to continue training
    classes_file = None  # txt file with classes values (unused for REGRESSION)
    gpu = ''  # GPU to be used for training
    use_embeddings = False
    weights_histogram = False
    seed_augment = False
    embeddings_dim = 300
    prediction_type = utils.PredictionType.CLASSIFICATION  # One of CLASSIFICATION, REGRESSION or MULTILABEL
    model_params = utils.ModelParams().to_dict()  # Model parameters
    embeddings_params = utils.EmbeddingsParams().to_dict()  # Embeddings parameters
    training_params = utils.TrainingParams().to_dict()  # Training parameters

    # Derive the number of model outputs from the prediction type. Explicit
    # raises (rather than `assert`) keep these checks active under `python -O`.
    if prediction_type == utils.PredictionType.CLASSIFICATION:
        if classes_file is None:
            raise ValueError('classes_file must be set when prediction_type is CLASSIFICATION')
        model_params['n_classes'] = utils.get_n_classes_from_file(classes_file)
    elif prediction_type == utils.PredictionType.REGRESSION:
        model_params['n_classes'] = 1
    elif prediction_type == utils.PredictionType.MULTILABEL:
        if classes_file is None:
            raise ValueError('classes_file must be set when prediction_type is MULTILABEL')
        model_params['n_classes'] = utils.get_n_classes_from_file_multilabel(classes_file)
    else:
        # Previously fell through silently, leaving 'n_classes' unset.
        raise ValueError('Unknown prediction_type {!r}; expected CLASSIFICATION, '
                         'REGRESSION or MULTILABEL'.format(prediction_type))
# Helpers must be defined above `@ex.automain`: when this file is executed as
# a script, the decorator starts the experiment immediately, so anything
# defined after it would not exist yet.


def _get_dirs_or_files(input_data):
    """Resolve a dataset location into ``(image_input, labels_input)``.

    :param input_data: a directory containing ``images`` and ``labels``
        sub-directories, or the path of a ``.csv`` file.
    :return: ``(images_dir, labels_dir)`` for a directory, or
        ``(csv_path, None)`` for a CSV file.
    :raises NotADirectoryError: if ``images`` or ``labels`` is missing.
    :raises TypeError: if ``input_data`` is neither a directory nor a CSV file.
    """
    if os.path.isdir(input_data):
        image_input = os.path.join(input_data, 'images')
        labels_input = os.path.join(input_data, 'labels')
        for sub_dir in (image_input, labels_input):
            if not os.path.isdir(sub_dir):
                raise NotADirectoryError('{} is not a directory'.format(sub_dir))
        return image_input, labels_input
    if os.path.isfile(input_data) and input_data.endswith('.csv'):
        return input_data, None
    raise TypeError('input_data {} is neither a directory nor a csv file'.format(input_data))


def _default_num_threads(nb_cores):
    """Return the number of input-pipeline threads for ``nb_cores`` CPUs.

    Uses half the cores, capped at 16 and never fewer than 1. Falls back to 4
    when the core count is unknown (``os.cpu_count()`` may return ``None``).
    """
    if not nb_cores:
        return 4
    # max(1, ...) guards single-core machines, where nb_cores // 2 == 0.
    return max(1, min(nb_cores // 2, 16))


@ex.automain
def run(train_data, eval_data, model_output_dir, gpu, training_params, use_embeddings, embeddings_dim, _config):
    """Train the estimator, evaluating and exporting after every round of epochs.

    All arguments are injected by Sacred from the experiment configuration.

    :param train_data: directory (with ``images``/``labels``) or ``.csv`` file of training data.
    :param eval_data: validation data in the same format, or ``None`` to skip evaluation.
    :param model_output_dir: directory for checkpoints, ``config.json`` and ``export/``.
    :param gpu: value written to the session's ``gpu_options.visible_device_list``.
    :param training_params: dict accepted by ``utils.TrainingParams.from_dict``.
    :param use_embeddings: forwarded to the serving input function.
    :param embeddings_dim: forwarded to the serving input function.
    :param _config: the full Sacred configuration (also passed as estimator ``params``).
    :raises ValueError: if ``train_data`` or ``model_output_dir`` is not set.
    :raises FileExistsError: if ``model_output_dir`` exists and ``restore_model`` is false.
    """
    # Both default to None in the config; fail with a clear message instead of
    # an obscure TypeError from os.path.
    if train_data is None:
        raise ValueError('train_data must be set to a directory or a .csv file')
    if model_output_dir is None:
        raise ValueError('model_output_dir must be set')

    tf.set_random_seed(_config['seed'])

    # Create the output directory; an existing one is only accepted when resuming.
    if not os.path.isdir(model_output_dir):
        os.makedirs(model_output_dir)
    elif not _config.get('restore_model'):
        raise FileExistsError(
            '{0} already exists, you cannot use it as output directory. '
            'Set "restore_model=True" to continue training, or delete dir "rm -r {0}"'.format(model_output_dir))

    # Save config
    with open(os.path.join(model_output_dir, 'config.json'), 'w') as f:
        json.dump(_config, f, indent=4, sort_keys=True)

    # Export directory for saved models
    saved_model_dir = os.path.join(model_output_dir, 'export')
    os.makedirs(saved_model_dir, exist_ok=True)

    training_params = utils.TrainingParams.from_dict(training_params)

    session_config = tf.ConfigProto()
    session_config.gpu_options.visible_device_list = str(gpu)
    session_config.gpu_options.per_process_gpu_memory_fraction = 1.0
    estimator_config = tf.estimator.RunConfig().replace(session_config=session_config,
                                                        save_summary_steps=10,
                                                        keep_checkpoint_max=1,
                                                        tf_random_seed=_config['seed'])
    estimator = tf.estimator.Estimator(estimator_fn.model_fn, model_dir=model_output_dir,
                                       params=_config, config=estimator_config)

    train_input, train_labels_input = _get_dirs_or_files(train_data)
    if eval_data is not None:
        eval_input, eval_labels_input = _get_dirs_or_files(eval_data)

    # Configure exporter
    serving_input_fn = input.serving_input_filename(training_params.input_resized_size,
                                                    use_embeddings=use_embeddings,
                                                    embeddings_dim=embeddings_dim)
    # NOTE(review): the BestExporter is used even without eval data, in which
    # case it is given eval_result=None. A previously commented-out variant used
    # tf.estimator.LatestExporter(name='SimpleExporter', exports_to_keep=5) for
    # that case -- confirm which behaviour is intended.
    exporter = tf.estimator.BestExporter(serving_input_receiver_fn=serving_input_fn, exports_to_keep=2)

    num_threads = _default_num_threads(os.cpu_count())

    n_epochs = training_params.n_epochs
    epochs_per_round = training_params.evaluate_every_epoch
    for start_epoch in trange(0, n_epochs, epochs_per_round, desc='Evaluated epochs'):
        # The last round is shortened so the total never exceeds n_epochs when
        # n_epochs is not a multiple of evaluate_every_epoch.
        epochs_this_round = min(epochs_per_round, n_epochs - start_epoch)
        estimator.train(input.input_fn(train_input,
                                       input_label_dir=train_labels_input,
                                       num_epochs=epochs_this_round,
                                       batch_size=training_params.batch_size,
                                       data_augmentation=training_params.data_augmentation,
                                       make_patches=training_params.make_patches,
                                       image_summaries=True,
                                       params=_config,
                                       num_threads=num_threads,
                                       progressbar_description="Training",
                                       seed=_config['seed']))

        eval_result = None
        if eval_data is not None:
            eval_result = estimator.evaluate(input.input_fn(eval_input,
                                                            input_label_dir=eval_labels_input,
                                                            batch_size=1,
                                                            data_augmentation=False,
                                                            make_patches=False,
                                                            image_summaries=False,
                                                            params=_config,
                                                            num_threads=num_threads,
                                                            progressbar_description="Evaluation"))

        exporter.export(estimator, saved_model_dir, checkpoint_path=None, eval_result=eval_result,
                        is_the_final_export=False)