-
Notifications
You must be signed in to change notification settings - Fork 0
/
typing_trainer.py
149 lines (122 loc) · 4.03 KB
/
typing_trainer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
import os
from statistics import median
import sys
import time
import numpy as np
from getcharacter import _Getch
class Log:
    """One keystroke attempt: the key that was expected and the key typed.

    Persisted one record per line as ``timestamp, key, typed_key, elapsed``
    via ``serialize`` / ``deserialize``.
    """

    def __init__(self, timestamp, key, typed_key, elapsed):
        # Stored as whole seconds: int() truncates a time.time() float and
        # also parses the serialized string form.
        self.timestamp = int(timestamp)
        self.key = key  # the key the user was asked to type
        self.typed_key = typed_key  # the key actually received
        # Seconds taken for the keypress, rounded to two decimals.
        self.elapsed = round(float(elapsed), 2)

    @property
    def is_error(self):
        """True when the typed key differs from the expected key."""
        return self.key != self.typed_key

    @property
    def is_correct(self):
        """True when the typed key matches the expected key."""
        return not self.is_error

    def serialize(self):
        """Return the record as ``'timestamp, key, typed_key, elapsed'``."""
        return ', '.join([str(self.timestamp),
                          self.key,
                          self.typed_key,
                          str(self.elapsed)])

    @classmethod
    def deserialize(cls, string):
        """Parse a line produced by ``serialize`` back into a ``Log``.

        A trailing line terminator is ignored. The offending line is printed
        and the error re-raised: ``TypeError`` when the line does not split
        into exactly four fields, ``ValueError`` when the timestamp or the
        elapsed time is not numeric.
        """
        # In a well-formed record the last field is the numeric elapsed time,
        # so stripping CR/LF from the end can never remove a key character.
        fields = string.rstrip('\r\n').split(', ')
        try:
            return cls(*fields)
        except (TypeError, ValueError):
            print(f'Could not parse {string}')
            raise
class Stat:
    """Aggregated attempt statistics for a single character."""

    def __init__(self):
        self.count = 0  # number of logged attempts
        self.n_errors = 0  # attempts where the wrong key was typed
        self.times = []  # elapsed seconds of every attempt, in log order

    def update(self, log):
        """Fold one attempt (anything with ``is_error`` and ``elapsed``) in."""
        self.count += 1
        self.n_errors += int(log.is_error)
        self.times.append(log.elapsed)

    @property
    def median_time(self):
        """Median elapsed time over all attempts; 0.0 before any attempt.

        Computed on demand rather than re-sorted on every ``update`` and,
        unlike a plain attribute set in ``update``, defined for untested
        characters (which ``average_reward`` and status output still read).
        """
        if not self.times:
            return 0.0
        return median(self.times)

    @property
    def error_rate(self):
        """Fraction of attempts that were errors; 0 before any attempt."""
        if self.count == 0:
            return 0
        return self.n_errors / self.count

    @property
    def average_reward(self):
        """Hand-crafted difficulty score: weighted mean of error rate and median time.

        Larger values mean the character is typed more slowly or less
        accurately.
        """
        error_weight = 5
        time_weight = 1
        return ((error_weight * self.error_rate + time_weight * self.median_time)
                / (error_weight + time_weight))

    def UCB_score(self, total_counts):
        """Return ``average_reward`` plus the bonus ``sqrt(ln(total_counts) / (2 * count))``.

        Untested characters get ``inf`` so that they are selected first.
        """
        if self.count == 0:
            return np.inf
        return self.average_reward + np.sqrt(np.log(total_counts) / (2 * self.count))
def read_chars(path):
    """Return the practice characters listed in *path*, one per line.

    Only the newline is removed, so a line holding a single space still
    yields ``' '``. Empty lines are skipped: an empty target can never equal
    a typed key, so the trainer would wait on it forever.
    """
    with open(path, 'r') as f:
        return [char for char in (line.strip('\n') for line in f) if char]
def read_logs(path):
    """Load every ``Log`` stored in *path*; return [] if the file is missing.

    Blank lines are ignored rather than handed to ``Log.deserialize``.
    """
    try:
        # newline='\n' makes only '\n' end a record. In the default universal
        # newline mode a logged typed key of '\r' (e.g. Enter read in raw
        # terminal mode) would split one record in two and make every later
        # start-up fail to parse the file. Windows '\r\n' endings still work:
        # the leftover '\r' sits after the numeric elapsed field.
        f = open(path, 'r', newline='\n')
    except FileNotFoundError:
        return []
    with f:
        return [Log.deserialize(line) for line in f if line.strip()]
def compute_stats(logs, chars):
    """Build a ``Stat`` for every character in *chars* from *logs*.

    Logs whose expected key is not one of *chars* are ignored.
    """
    per_char = {}
    for c in chars:
        per_char[c] = Stat()
    for entry in logs:
        target = per_char.get(entry.key)
        if target is not None:
            target.update(entry)
    return per_char
def write_logs(logs, path):
    """Overwrite *path* with one serialized log per line, in list order."""
    with open(path, 'w') as out:
        out.writelines(entry.serialize() + '\n' for entry in logs)
def choose_characters(stats, n=1):
    """Return up to *n* characters with the highest ``UCB_score``, best first.

    *stats* maps each character to an object exposing ``count`` and
    ``UCB_score(total_counts)``. The total passed to every score is the sum
    of all counts. Returns [] when *stats* is empty; previously that case
    crashed while unpacking ``zip(*stats.items())``.
    """
    if not stats:
        return []
    total_counts = sum(stat.count for stat in stats.values())
    chars = list(stats)
    scores = [stats[char].UCB_score(total_counts) for char in chars]
    # argsort is ascending; reverse it so the highest scores come first.
    ranked = np.argsort(scores)[::-1]
    return [chars[idx] for idx in ranked[:n]]
current_script_dir = os.path.dirname(os.path.realpath(__file__))
logs_path = os.path.join(current_script_dir, 'logs.txt')
chars_path = os.path.join(current_script_dir, 'chars.txt')
logs = read_logs(logs_path)
chars = read_chars(chars_path)
getch = _Getch()
print('Welcome to typing trainer!')
# input() only returns after Enter, so the prompt says Enter, not "any key".
input('Press Enter to continue')
print('-' * 50)
# Iterate over batches
n_batches = 4
batch_size = 5
# choose_characters never returns more keys than there are distinct
# characters, so cap the per-batch count to keep the "(i/N)" counter accurate.
n_samples = n_batches * min(batch_size, len(set(chars)))
sample_idx = 0
try:
    for _ in range(n_batches):
        # Recomputed per batch so attempts made this session affect the picks.
        stats = compute_stats(logs, chars)
        keys = choose_characters(stats, n=batch_size)
        # Iterate over keys in current batch
        for key in keys:
            sample_idx += 1
            msg = f'\r({sample_idx}/{n_samples}) Type: {key}\t'
            msg += f'(Attempts: {stats[key].count}, '
            msg += f'Errors: {stats[key].n_errors}, '
            msg += f'Median time: {stats[key].median_time:.2f}, '
            msg += f'Reward: {stats[key].average_reward:.2f})'
            sys.stdout.write(msg)
            # Keep prompting until the expected key is typed; every attempt,
            # right or wrong, is logged.
            while True:
                start = time.time()
                typed_key = getch()
                elapsed = time.time() - start
                # NOTE(review): if _Getch reads the terminal in raw mode,
                # Ctrl+C arrives as '\x03' instead of raising, leaving no way
                # to quit. Treat it as an interrupt (and don't log it).
                # Confirm against getcharacter._Getch.
                if typed_key == '\x03':
                    raise KeyboardInterrupt
                log = Log(start, key, typed_key, elapsed)
                logs.append(log)
                if log.is_correct:
                    break
            sys.stdout.write('\n')
except KeyboardInterrupt:
    # Quitting early is allowed: end the prompt line and keep what was typed.
    sys.stdout.write('\n')
finally:
    # Persist on normal completion, interruption, or an unexpected error so
    # this session's attempts are never discarded.
    write_logs(logs, logs_path)
stats = compute_stats(logs, chars)
# NOTE(review): stats and stats_path are not used after this point in the
# visible code.
stats_path = os.path.join(current_script_dir, 'stats.txt')