forked from ggerganov/kbd-audio
-
Notifications
You must be signed in to change notification settings - Fork 0
/
common.h
158 lines (133 loc) · 4.02 KB
/
common.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
/*! \file common.h
* \brief Common types and functions
* \author Georgi Gerganov
*/
#pragma once
#include "audio_logger.h"
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <cstring>
#include <map>
#include <mutex>
#include <string>
#include <thread>
#include <tuple>
#include <vector>
// types
// Accumulator for a plain sum of samples (see calcSum / calcCC).
using TSum = double;
// Accumulator for sums of squared samples and of sample cross-products (see calcSum / calcCC).
using TSum2 = double;
// Value type stored in TKeyConfidenceMap.
using TConfidence = float;
// Correlation value as returned by calcCC and findBestCC.
using TValueCC = double;
// Sample offset as returned by findBestCC.
using TOffset = int;
// Key type of TKeyConfidenceMap; presumably a key code - confirm against callers.
using TKey = int;
// Sequence of audio samples.
using TKeyWaveform = std::vector<AudioLogger::Sample>;
// Collection of waveforms.
using TKeyHistory = std::vector<TKeyWaveform>;
// Ordered map from TKey to TConfidence.
using TKeyConfidenceMap = std::map<TKey, TConfidence>;
// helpers
/*! \brief Collect single-letter command line options into a map
 *
 * Every argument of the form "-Xvalue" (argv[0] is skipped) is stored under
 * the one-character key "X" with the remaining text as its value; a bare "-X"
 * maps to an empty string. Arguments that do not start with '-' and a lone
 * "-" are ignored. When an option letter repeats, the last occurrence wins.
 *
 * \param argc number of entries in argv
 * \param argv argument strings
 * \return map from option letter to option value
 */
static std::map<std::string, std::string> parseCmdArguments(int argc, char ** argv) {
    std::map<std::string, std::string> options;

    for (int idx = 1; idx < argc; ++idx) {
        const char * arg = argv[idx];

        // only "-X..." counts as an option: skip positionals and a lone "-"
        if (arg[0] != '-' || arg[1] == '\0') {
            continue;
        }

        // text following the option letter; empty when the argument is just "-X"
        options[std::string(1, arg[1])] = arg + 2;
    }

    return options;
}
static std::tuple<TSum, TSum2> calcSum(const TKeyWaveform & waveform, int is0, int is1) {
TSum sum = 0.0f;
TSum2 sum2 = 0.0f;
for (int is = is0; is < is1; ++is) {
auto a0 = waveform[is];
sum += a0;
sum2 += a0*a0;
}
return std::tuple<TSum, TSum2>(sum, sum2);
}
/*! \brief Pearson correlation coefficient between two equally long windows
 *
 * Correlates waveform0[is00 .. is00 + n) with waveform1[is0 .. is0 + n),
 * where n = is1 - is0.
 *
 * \param waveform0 first waveform
 * \param waveform1 second waveform
 * \param sum0 precomputed sum of the waveform0 window (e.g. from calcSum)
 * \param sum02 precomputed sum of squares of the waveform0 window (e.g. from calcSum)
 * \param is00 index of the first window sample in waveform0
 * \param is0 index of the first window sample in waveform1
 * \param is1 index one past the last window sample in waveform1
 * \return the correlation coefficient, or -1 when it is undefined (empty
 *         window, or a window with no variance)
 *
 * The indices are not bounds-checked unless MY_DEBUG is defined, in which
 * case out-of-range indices are reported on stdout (the access still happens).
 */
static TValueCC calcCC(
        const TKeyWaveform & waveform0,
        const TKeyWaveform & waveform1,
        TSum sum0, TSum2 sum02,
        int is00, int is0, int is1) {
    const int ncc = is1 - is0;

    TSum sum1 = 0.0;
    TSum2 sum12 = 0.0;
    TSum2 sum01 = 0.0;

    for (int is = 0; is < ncc; ++is) {
#ifdef MY_DEBUG
        // report bad indices *before* they are used, so the diagnostic is
        // printed even if the out-of-range read below crashes
        if (is00 + is < 0 || is00 + is >= (int) waveform0.size()) printf("BUG 0\n");
        if (is0 + is < 0 || is0 + is >= (int) waveform1.size()) {
            printf("BUG 1\n");
            printf("%d %d %d\n", is0, is, (int) waveform1.size());
        }
#endif
        const auto a0 = waveform0[is00 + is];
        const auto a1 = waveform1[is0 + is];

        sum1 += a1;
        sum12 += a1*a1;
        sum01 += a0*a1;
    }

    const double nom = sum01*ncc - sum0*sum1;
    const double den2a = sum02*ncc - sum0*sum0;
    const double den2b = sum12*ncc - sum1*sum1;
    const double den2 = den2a*den2b;

    // A zero, negative (rounding) or NaN denominator would produce NaN or
    // +-inf; +inf in particular would win any "best correlation" search.
    // Report the lowest possible correlation instead.
    if (!(den2 > 0.0)) {
        return -1.0;
    }

    return nom/sqrt(den2);
}
/*! \brief Find the shift of a waveform1 window that best correlates with the center of waveform0
 *
 * The reference window is (is1 - is0) samples long and centered in waveform0.
 * It is correlated (calcCC) with waveform1[is0 + o .. is1 + o) for every
 * offset o in [-alignWindow, alignWindow).
 *
 * \param waveform0 reference waveform
 * \param waveform1 waveform to align
 * \param is0 index of the first window sample in waveform1 (at offset 0)
 * \param is1 index one past the last window sample in waveform1 (at offset 0)
 * \param alignWindow half-width of the offset range to scan
 * \return tuple (best correlation, offset at which it occurs). If several
 *         offsets reach the same best correlation the smallest one is
 *         returned. If no offset yields a correlation above -1 the result
 *         is (-1, -1).
 *
 * The indices are not bounds-checked: the caller must make sure all shifted
 * windows lie inside the waveforms.
 *
 * Declared static like the other helpers in this header so that including
 * the header from several translation units does not produce duplicate
 * definitions.
 */
static std::tuple<TValueCC, TOffset> findBestCC(
        const TKeyWaveform & waveform0,
        const TKeyWaveform & waveform1,
        int is0, int is1,
        int alignWindow) {
    TOffset besto = -1;
    TValueCC bestcc = -1.0;

    // start of the reference window: (is1 - is0) samples centered in waveform0
    const int is00 = (int) (waveform0.size()/2) - (is1 - is0)/2;

    const auto ret = calcSum(waveform0, is00, is00 + is1 - is0);
    const auto sum0 = std::get<0>(ret);
    const auto sum02 = std::get<1>(ret);

    // Scans the offsets first, first + stride, ... (< alignWindow) and returns
    // the best (correlation, offset) pair among them; earlier offsets win ties.
    const auto scan = [&](int first, int stride) {
        TOffset cbesto = -1;
        TValueCC cbestcc = -1.0;
        for (int o = first; o < alignWindow; o += stride) {
            const auto cc = calcCC(waveform0, waveform1, sum0, sum02, is00, is0 + o, is1 + o);
            if (cc > cbestcc) {
                cbesto = o;
                cbestcc = cc;
            }
        }
        return std::tuple<TValueCC, TOffset>(cbestcc, cbesto);
    };

#ifdef __EMSCRIPTEN__
    std::tie(bestcc, besto) = scan(-alignWindow, 1);
#else
    // hardware_concurrency() may return 0 when the value cannot be
    // determined - always use at least one worker so the scan still runs
    const int nWorkers = std::max(1u, std::min(4u, std::thread::hardware_concurrency()));

    std::mutex mutex;
    std::vector<std::thread> workers;
    workers.reserve(nWorkers);

    // worker i handles the offsets -alignWindow + i + k*nWorkers
    for (int i = 0; i < nWorkers; ++i) {
        workers.emplace_back([&, i]() {
            const auto cur = scan(-alignWindow + i, nWorkers);
            const auto cbestcc = std::get<0>(cur);
            const auto cbesto = std::get<1>(cur);

            std::lock_guard<std::mutex> lock(mutex);
            // Exact ties are resolved in favor of the smaller offset, so the
            // result does not depend on the order in which workers finish
            // and matches a sequential scan.
            if (cbestcc > bestcc || (cbestcc == bestcc && cbestcc > -1.0 && cbesto < besto)) {
                bestcc = cbestcc;
                besto = cbesto;
            }
        });
    }

    for (auto & worker : workers) worker.join();
#endif

    return std::tuple<TValueCC, TOffset>(bestcc, besto);
}