-
Notifications
You must be signed in to change notification settings - Fork 43
/
decode_ctc.py
executable file
·177 lines (156 loc) · 5.55 KB
/
decode_ctc.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
#!/usr/bin/env python3
from typing import Any, Dict, List, Optional
import jsonargparse
import pytorch_lightning as pl
import laia.common.logging as log
from laia.callbacks import Decode, ProgressBar, Segmentation
from laia.common.arguments import CommonArgs, DataArgs, DecodeArgs, TrainerArgs
from laia.common.loader import ModelLoader
from laia.decoders import CTCGreedyDecoder, CTCLanguageDecoder
from laia.engine import Compose, DataModule, EvaluatorModule, ImageFeeder, ItemFeeder
from laia.scripts.htr import common_main
from laia.utils import ImageStats, SymbolsTable
def _build_decoder(decode):
    """Return ``(decoder, print_word_confidence_scores)`` for the decoding options.

    Word-level confidence scores are not supported by the language-model
    decoder, so the returned flag is forced to ``False`` in that case.
    ``decode`` itself is never modified.
    """
    if decode.use_language_model:
        decoder = CTCLanguageDecoder(
            language_model_path=decode.language_model_path,
            tokens_path=decode.tokens_path,
            lexicon_path=decode.lexicon_path,
            language_model_weight=decode.language_model_weight,
            blank_token=decode.blank_token,
            unk_token=decode.unk_token,
            sil_token=decode.input_space,
            temperature=decode.temperature,
        )
        # word-level confidence scores are not supported when using a language model
        return decoder, False
    decoder = CTCGreedyDecoder(
        temperature=decode.temperature,
    )
    return decoder, decode.print_word_confidence_scores


def run(
    syms: str,
    img_list: str,
    img_dirs: Optional[List[str]] = None,
    common: CommonArgs = CommonArgs(),
    data: DataArgs = DataArgs(),
    decode: DecodeArgs = DecodeArgs(),
    trainer: TrainerArgs = TrainerArgs(),
    num_workers: Optional[int] = None,
):
    """Decode the images listed in ``img_list`` with a trained CTC model.

    Args:
        syms: Symbols table (mapping from strings to integers), passed to
            ``SymbolsTable``.
        img_list: File listing the images to decode.
        img_dirs: Directories in which images from ``img_list`` are looked up.
        common: Model/checkpoint location options (``train_path``,
            ``model_filename``, ``checkpoint``, ``experiment_dirpath``,
            ``monitor``).
        data: Data options (``batch_size``, ``color_mode``).
        decode: Decoding and output-formatting options. It is read only and
            never modified, so the shared default instance stays intact.
        trainer: Options forwarded verbatim to ``pl.Trainer`` via ``vars``.
        num_workers: Number of data-loading workers, passed to ``DataModule``.

    Raises:
        AssertionError: If no model could be loaded.
    """
    loader = ModelLoader(
        common.train_path, filename=common.model_filename, device="cpu"
    )
    checkpoint = loader.prepare_checkpoint(
        common.checkpoint,
        common.experiment_dirpath,
        common.monitor,
    )
    model = loader.load_by(checkpoint)
    # Explicit check rather than ``assert`` so it survives ``python -O``;
    # the exception type is kept for backward compatibility.
    if model is None:
        raise AssertionError(
            "Could not find the model. Have you run pylaia-htr-create-model?"
        )

    # prepare the evaluator
    evaluator_module = EvaluatorModule(
        model,
        batch_input_fn=Compose([ItemFeeder("img"), ImageFeeder()]),
        batch_id_fn=ItemFeeder("id"),
    )

    # prepare the symbols
    syms = SymbolsTable(syms)

    # prepare the data
    im_stats = ImageStats(stage="test", img_list=img_list, img_dirs=img_dirs)
    data_module = DataModule(
        syms=syms,
        img_dirs=img_dirs,
        te_img_list=img_list,
        batch_size=data.batch_size,
        min_valid_size=model.get_min_valid_image_size(im_stats.max_width)
        if im_stats.is_fixed_height
        else None,
        color_mode=data.color_mode,
        stage="test",
        num_workers=num_workers,
    )

    # prepare the output callback: segmentation or transcription
    if decode.segmentation:
        output_callback = Segmentation(
            syms,
            segmentation=decode.segmentation,
            input_space=decode.input_space,
            separator=decode.separator,
            include_img_ids=decode.include_img_ids,
        )
    else:
        # The decoder (possibly loading a language model) is only needed here.
        decoder, print_word_confidence_scores = _build_decoder(decode)
        output_callback = Decode(
            decoder=decoder,
            syms=syms,
            use_symbols=decode.use_symbols,
            input_space=decode.input_space,
            output_space=decode.output_space,
            convert_spaces=decode.convert_spaces,
            join_string=decode.join_string,
            separator=decode.separator,
            include_img_ids=decode.include_img_ids,
            print_line_confidence_scores=decode.print_line_confidence_scores,
            print_word_confidence_scores=print_word_confidence_scores,
        )
    callbacks = [
        ProgressBar(refresh_rate=trainer.progress_bar_refresh_rate),
        output_callback,
    ]

    # prepare the trainer (distinct name so the ``trainer`` options are not shadowed)
    pl_trainer = pl.Trainer(
        default_root_dir=common.train_path,
        callbacks=callbacks,
        logger=False,
        **vars(trainer),
    )

    # decode!
    pl_trainer.test(evaluator_module, datamodule=data_module, verbose=False)
def get_args(argv: Optional[List[str]] = None) -> Dict[str, Any]:
    """Build the command-line parser for CTC decoding and parse ``argv``.

    ``argv`` is handed to ``parser.parse_args``; when it is ``None`` the
    parser presumably falls back to the process arguments, as argparse does.

    Returns:
        The parsed arguments as a dict. The ``common``, ``data``, ``decode``
        and ``trainer`` entries are converted into ``CommonArgs``,
        ``DataArgs``, ``DecodeArgs`` and ``TrainerArgs`` instances. The
        ``logging`` entry holds the parsed options registered from
        ``log.config``.
    """
    syms_help = (
        "Mapping from strings to integers. "
        "The CTC symbol must be mapped to integer 0"
    )
    img_list_help = (
        "File containing the images to decode. Each image is expected to be in one "
        'line. Lines starting with "#" will be ignored. Lines can be filepaths '
        '(e.g. "/tmp/img.jpg") or filenames of images present in --img_dirs (e.g. '
        "img.jpg). The filename extension is optional and case insensitive"
    )
    img_dirs_help = (
        "Directories containing word images. "
        "Optional if `img_list` contains filepaths"
    )

    parser = jsonargparse.ArgumentParser(parse_as_dict=True)
    parser.add_argument(
        "--config", action=jsonargparse.ActionConfigFile, help="Configuration file"
    )
    parser.add_argument("syms", type=str, help=syms_help)
    parser.add_argument("img_list", type=str, help=img_list_help)
    parser.add_argument(
        "--img_dirs", type=Optional[List[str]], default=None, help=img_dirs_help
    )

    # Option groups, kept in their original registration order.
    parser.add_class_arguments(CommonArgs, "common")
    parser.add_class_arguments(DataArgs, "data")
    parser.add_function_arguments(log.config, "logging")
    parser.add_class_arguments(DecodeArgs, "decode")
    parser.add_class_arguments(TrainerArgs, "trainer")

    parsed = parser.parse_args(argv, with_meta=False)
    # Turn the plain dicts of each group into their argument-class instances.
    for group, args_cls in (
        ("common", CommonArgs),
        ("data", DataArgs),
        ("decode", DecodeArgs),
        ("trainer", TrainerArgs),
    ):
        parsed[group] = args_cls(**parsed[group])
    return parsed
def main():
    """Script entry point: parse the CLI, pass it through ``common_main``, then decode."""
    run(**common_main(get_args()))
# Run the decoder when this file is executed directly (it carries a python3 shebang).
if __name__ == "__main__":
    main()