forked from hikopensource/DAVAR-Lab-OCR
-
Notifications
You must be signed in to change notification settings - Fork 2
/
res32_ace.py
101 lines (87 loc) · 2.74 KB
/
res32_ace.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
"""
##################################################################################################
# Copyright Info : Copyright (c) Davar Lab @ Hikvision Research Institute. All rights reserved.
# Filename : res32_ace.py
# Abstract : ACE recognition Model
# Current Version: 1.0.0
# Date : 2021-06-11
##################################################################################################
"""
# encoding=utf-8
# Base config this file builds on (MMCV-style `_base_` inheritance).
_base_ = ['./baseline.py']

# Path to the recognition dictionary file; shared by the ACE loss and the
# label converter configured in the model settings.
character = (
    "/data1/open-source/demo/text_recognition/__dictionary__/"
    "Scene_text_68.txt"
)
"""
1. Model Settings
Includes model-related settings, such as the model type, user-selected modules and parameters.
"""
# Model overrides for the ACE recognizer: the sequence module and the sequence
# head are both redefined here. Each carries `_delete_: True`, which under MMCV
# config semantics replaces the inherited dict instead of merging into it
# (NOTE(review): assumes this file is loaded through MMCV's Config — confirm).
model = {
    'sequence_module': {
        'type': 'CascadeRNN',
        'rnn_modules': [
            # Single bidirectional LSTM stage: 512 -> hidden 256 -> 512.
            {
                'type': 'BidirectionalLSTM',
                'input_size': 512,
                'hidden_size': 256,
                'output_size': 512,
                'with_linear': True,
                'bidirectional': True,
            },
        ],
        '_delete_': True,
    },
    'sequence_head': {
        'type': 'ACEHead',
        'embed_size': 512,
        'batch_max_length': 25,
        # Loss and label converter read the same dictionary file.
        'loss_ace': {
            'type': "ACELoss",
            'character': character,
        },
        'converter': {
            'type': 'ACELabelConverter',
            'character': character,
        },
        '_delete_': True,
    },
}
# Data settings: only the sampler is specified here.
# NOTE(review): the meaning of `mode` is defined by DistBatchBalancedSampler,
# which is not visible in this file — confirm before changing it.
data = {
    'sampler': {
        'type': 'DistBatchBalancedSampler',
        'mode': 0,
    },
}
# Checkpoint hook settings. Epoch-based and iteration-based saving are both
# switched on (`by_epoch` / `by_iter`); presumably `interval` counts epochs and
# `iter_interval` counts iterations — confirm against DavarCheckpointHook.
checkpoint_config = {
    'type': "DavarCheckpointHook",
    'interval': 1,
    'iter_interval': 5000,
    'by_epoch': True,
    'by_iter': True,
    # `{}` is a format placeholder filled in by the hook.
    'filename_tmpl': 'ckpt/res32_ace_e{}.pth',
    # Tracked metric and comparison rule.
    'metric': "accuracy",
    'rule': "greater",
    'save_mode': "lightweight",
    'init_metric': -1,
    # NOTE(review): semantics of `model_milestone` are hook-specific — confirm.
    'model_milestone': 0.5,
}
# Logging settings: a single text logger hook, with a logging interval of 50.
log_config = {
    'interval': 50,
    'hooks': [
        {'type': 'TextLoggerHook'},
    ],
}
# Evaluation settings. Epoch-based and iteration-based evaluation are both
# enabled; the best model is selected on "accuracy" with the "greater" rule.
# NOTE(review): `start` / `start_iter` are interpreted by the evaluation hook,
# which is not visible in this file — confirm their units before editing.
evaluation = {
    'start': 3,
    'start_iter': 0.5,
    'save_best': "accuracy",
    'interval': 1,
    'iter_interval': 5000,
    'model_type': "RECOGNIZOR",
    'eval_mode': "lightweight",
    'by_epoch': True,
    'by_iter': True,
    'rule': "greater",
    # Metrics reported at each evaluation.
    'metric': ['accuracy', 'NED'],
}
# Runner: epoch-based training, capped at 6 epochs.
runner = {'type': 'EpochBasedRunner', 'max_epochs': 6}

# Directory for this experiment's outputs.
work_dir = '/data1/workdir/davar_opensource/ace/'

# Distributed training: NCCL backend.
dist_params = {'backend': 'nccl'}