-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
132 lines (113 loc) · 4.91 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
import argparse
import os
from ngram_model import NgramModel
from genetic_algorithm import GeneticAlgorithm
from processing import split_syllables, get_top_symbols, split_symbols
from beam_search import BeamSearch
# Default values for the command-line options declared in parse_arguments().

# Order passed to NgramModel (--n-gram).
DEFAULT_N_GRAM = 3
# Number of symbols requested from get_top_symbols() for both the source
# and the target text (--num-symbols).
DEFAULT_NUM_SYMBOLS = 50
# Beam-search only: "nodes" argument of BeamSearch (--nodes).
DEFAULT_NODES = 10
# Beam-search only: "beam width" argument of BeamSearch (--beam-width).
DEFAULT_BEAM_WIDTH = 100
# Genetic-algorithm only: population size (--pop-size).
DEFAULT_POP_SIZE = 2000
# Genetic-algorithm only: number of parents (--num-parents).
DEFAULT_NUM_PARENTS = 1000
# Genetic-algorithm only: mutation rate (--mutation-rate).
DEFAULT_MUTATION_RATE = 0.5
# Genetic-algorithm only: crossover rate (--crossover-rate).
DEFAULT_CROSSOVER_RATE = 0.8
# Value handed to solver.run() for whichever algorithm is selected
# (--generations).
DEFAULT_GENERATIONS = 200
# Solver selection: 'bs' builds a BeamSearch, 'ga' a GeneticAlgorithm
# (--algorithm).
DEFAULT_ALGORITHM = 'bs'
# Forwarded to GeneticAlgorithm as n_cores (--n-cores).
# NOTE(review): what 0 means is decided inside GeneticAlgorithm, which is not
# visible from this file — confirm there.
DEFAULT_N_CORES = 0
def parse_arguments() -> argparse.Namespace:
    """
    Parse command line arguments.

    Two positional arguments (source and target file) are required; every
    other option falls back to the matching DEFAULT_* constant or to None.

    Returns:
        Namespace with the attributes source, target, algorithm, n_gram,
        num_symbols, ignore, nodes, beam_width, pop_size, num_parents,
        mutation_rate, crossover_rate, generations, n_cores, eval and output.

    Raises:
        SystemExit: Raised by argparse on invalid or missing arguments
            (after printing a usage message).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('source', type=str, help='Source file')
    parser.add_argument('target', type=str, help='Target file')
    # Only 'ga' and 'bs' are handled by main(); restricting the choices here
    # turns a typo into an immediate usage error instead of a late failure
    # after the input files have already been processed.
    parser.add_argument(
        '-a', '--algorithm', type=str, default=DEFAULT_ALGORITHM,
        choices=('ga', 'bs'), help='Algorithm',
    )
    parser.add_argument('-ng', '--n-gram', type=int, default=DEFAULT_N_GRAM, help='N-gram')
    parser.add_argument('-s', '--num-symbols', type=int, default=DEFAULT_NUM_SYMBOLS, help='Number of symbols')
    # Comma-separated list; main() splits it on ','.
    parser.add_argument('-i', '--ignore', type=str, help='Ignore symbols')
    parser.add_argument('-n', '--nodes', type=int, default=DEFAULT_NODES, help='Number of nodes (BS)')
    parser.add_argument('-bw', '--beam-width', type=int, default=DEFAULT_BEAM_WIDTH, help='Beam width (BS)')
    parser.add_argument('-p', '--pop-size', type=int, default=DEFAULT_POP_SIZE, help='Population size (GA)')
    parser.add_argument('-np', '--num-parents', type=int, default=DEFAULT_NUM_PARENTS, help='Number of parents')
    parser.add_argument('-m', '--mutation-rate', type=float, default=DEFAULT_MUTATION_RATE, help='Mutation rate')
    parser.add_argument('-c', '--crossover-rate', type=float, default=DEFAULT_CROSSOVER_RATE, help='Crossover rate')
    parser.add_argument('-g', '--generations', type=int, default=DEFAULT_GENERATIONS, help='Number of generations')
    parser.add_argument('-nc', '--n-cores', type=int, default=DEFAULT_N_CORES, help='Number of cores')
    parser.add_argument('-e', '--eval', type=str, help='Evaluation file')
    parser.add_argument('-o', '--output', type=str, help='Output folder')
    return parser.parse_args()
def main(args: argparse.Namespace) -> None:
    """
    Main function.

    Read the source and target texts, fit an n-gram model on the target
    text, run the selected solver, and optionally write the solver's results
    to an output folder and score its best key against an evaluation file.

    Problems that prevent the run from continuing (unknown algorithm,
    unreadable input file, output folder that cannot be created) are
    reported on stdout and end the function early.

    Args:
        args: Parsed command line options; the attributes read here are
            source, target, n_gram, num_symbols, ignore, algorithm, nodes,
            beam_width, pop_size, num_parents, mutation_rate,
            crossover_rate, n_cores, generations, output and eval.
    """
    source_file = args.source
    target_file = args.target
    n_gram = args.n_gram
    num_symbols = args.num_symbols

    # Fail fast: without this check an unsupported value would leave
    # `solver` unbound and crash with a NameError only after the input files
    # had been read and the model fitted.
    if args.algorithm not in ('ga', 'bs'):
        print(f'Error: Unknown algorithm {args.algorithm}')
        return

    # '?' is always ignored; user-supplied symbols (comma-separated) are
    # appended to it.
    ignore = ['?']
    if args.ignore:
        ignore += args.ignore.split(',')

    try:
        with open(source_file, 'r') as file:
            source_text = file.read().splitlines()
    except IOError as e:
        print(f'Error: Could not read source file {source_file}: {e}')
        return

    try:
        with open(target_file, 'r') as file:
            target_text = file.read().splitlines()
    except IOError as e:
        print(f'Error: Could not read target file {target_file}: {e}')
        return

    # The two texts are tokenised differently: symbols for the source,
    # syllables for the target.
    source_text_split = [split_symbols(line) for line in source_text]
    target_text_split = [split_syllables(line) for line in target_text]

    ngram_model = NgramModel(n_gram)
    ngram_model.fit(target_text_split)

    # The ignore list applies to the source symbols only.
    source_symbols = get_top_symbols(source_text_split, num_symbols, ignore)
    target_symbols = get_top_symbols(target_text_split, num_symbols)

    if args.algorithm == 'ga':
        solver = GeneticAlgorithm(
            source_symbols,
            target_symbols,
            source_text_split,
            ngram_model,
            args.pop_size,
            args.num_parents,
            args.mutation_rate,
            args.crossover_rate,
            n_cores=args.n_cores
        )
    else:  # 'bs' — the only other value that passed the check above
        solver = BeamSearch(
            source_symbols,
            target_symbols,
            source_text_split,
            ngram_model,
            args.nodes,
            args.beam_width
        )

    solver.run(args.generations)

    if args.output:
        output_folder = args.output
        try:
            # exist_ok avoids the check-then-create race and still raises
            # (FileExistsError, an OSError) when the path is a regular file.
            os.makedirs(output_folder, exist_ok=True)
        except OSError as e:
            print(f'Error: Could not create output directory {args.output}: {e}')
            return
        solver.write_result(os.path.join(output_folder, 'best_key.txt'))
        solver.plot(os.path.join(output_folder, 'plot.png'))

    if args.eval:
        # Keep the try body to the I/O that can actually raise IOError.
        try:
            with open(args.eval, 'r') as file:
                eval_text = file.read().splitlines()
        except IOError as e:
            print(f'Error: Could not read file {args.eval}: {e}')
            return

        # Each usable line contributes "first token -> second token".
        # Blank lines and lines with a single token carry no mapping and are
        # skipped instead of raising IndexError.
        eval_map = {}
        for line in eval_text:
            fields = line.split()
            if len(fields) >= 2:
                eval_map[fields[0]] = fields[1]

        correct_count = sum(
            1 for k, v in solver.best_key.items()
            if k in eval_map and eval_map[k] == v
        )
        total_count = len(solver.best_key)
        # An empty key would otherwise divide by zero.
        accuracy = correct_count / total_count * 100 if total_count else 0.0
        print(f'Correct symbols: {correct_count} / {total_count} ({accuracy:.2f}%)')
# Script entry point: parse the command line and hand the options to main().
# Skipped when the module is imported.
if __name__ == '__main__':
    main(parse_arguments())