Whisper inference support in cpp runtime #2320
Conversation
```cpp
    weight = (right_mel - mel) / (right_mel - center_mel);
  } else if (mel_type == MelType::Slaney) {
    if (mel <= center_mel) {
      weight = (InverseMelScale(mel, mel_type) -
```
```cpp
      window_[i] = pow(0.5 - 0.5 * cos(a * i), 0.85);
    } else if (window_type == WindowType::Hanning) {
      // periodic hanning window
      double a = M_2PI / (frame_length_);
```
https://pytorch.org/docs/stable/generated/torch.hann_window.html#torch.hann_window:~:text=periodic%20(bool%2C%20optional)%20%E2%80%93%20If%20True%2C%20returns%20a%20window%20to%20be%20used%20as%20periodic%20function.%20If%20False%2C%20return%20a%20symmetric%20window. The default for `periodic` is `True`, meaning the period N is window_length + 1, so that the window can be used as a periodic function.
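As a concrete illustration, here is a minimal standalone sketch of such a periodic Hann window, equivalent to `torch.hann_window(N, periodic=True)` (`PeriodicHann` is an illustrative name, not the PR's API):

```cpp
#include <cmath>
#include <vector>

// Periodic Hann window of length N: the cosine period is N (not N - 1),
// i.e. a symmetric window of length N + 1 with the last sample dropped.
std::vector<double> PeriodicHann(int frame_length) {
  std::vector<double> window(frame_length);
  const double a = 2.0 * M_PI / frame_length;
  for (int i = 0; i < frame_length; ++i) {
    window[i] = 0.5 - 0.5 * std::cos(a * i);
  }
  return window;
}
```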
runtime/core/frontend/fbank.h (Outdated)
```diff
 }
 
-static inline float MelScale(float freq) {
-  return 1127.0f * logf(1.0f + freq / 700.0f);
+static inline float MelScale(float freq, MelType mel_type = MelType::HTK) {
```
```cpp
  }
}

static inline float InverseMelScale(float mel_freq,
```
This can be further optimized if needed; there are a lot of repeated computations, but they may already be optimized away by the compiler through constant propagation.
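For readers comparing the two scales, a hedged standalone sketch of the HTK and Slaney conventions (constants follow the librosa/HTK definitions; names are illustrative and not necessarily the PR's exact code):

```cpp
#include <cmath>

enum class MelType { HTK, Slaney };

// HTK: one log curve. Slaney (used by whisper/librosa): linear at 200/3 Hz
// per mel below 1 kHz, log-spaced above with step ln(6.4)/27 per mel.
inline float MelScale(float freq, MelType mel_type) {
  if (mel_type == MelType::HTK) {
    return 1127.0f * logf(1.0f + freq / 700.0f);
  }
  const float f_sp = 200.0f / 3.0f;
  const float min_log_hz = 1000.0f;
  const float min_log_mel = min_log_hz / f_sp;  // = 15
  const float logstep = logf(6.4f) / 27.0f;
  return freq < min_log_hz ? freq / f_sp
                           : min_log_mel + logf(freq / min_log_hz) / logstep;
}

inline float InverseMelScale(float mel, MelType mel_type) {
  if (mel_type == MelType::HTK) {
    return 700.0f * (expf(mel / 1127.0f) - 1.0f);
  }
  const float f_sp = 200.0f / 3.0f;
  const float min_log_mel = 15.0f;
  const float logstep = logf(6.4f) / 27.0f;
  return mel < min_log_mel ? f_sp * mel
                           : 1000.0f * expf(logstep * (mel - min_log_mel));
}
```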
runtime/core/frontend/fbank.h (Outdated)
```cpp
    if (scaled_float_as_input_) {
      for (int j = 0; j < frame_length_; ++j) {
        data[j] = data[j] / S16_TO_FLOAT_SCALE;
```
The data fed into this pipeline is int16 but converted to float without scaling, while whisper's training code loads audio as floats between -1 and 1.
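In other words, the fix is a plain division by 2^15; a minimal sketch of that step (the constant and function names here are illustrative, not the PR's exact code):

```cpp
// int16 samples arrive stored in floats; dividing by 32768 (1 << 15) maps
// them to the [-1, 1) range that whisper's training pipeline expects.
constexpr float kS16ToFloatScale = 32768.0f;

void ScaleToUnitRange(float* data, int frame_length) {
  for (int j = 0; j < frame_length; ++j) {
    data[j] /= kS16ToFloatScale;
  }
}
```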
Great job, it's clear.
Great job! I think
BTW, why does the CTC result contain <|notimestamps|>? The label provided to the CTC loss function doesn't include this tag (only the label given to the CE loss has it), so <|notimestamps|> shouldn't appear during CTC decoding.
Thanks for referencing this implementation, those are really helpful! Ya, they are doing basically the same thing that we want to do. I think we have two ways to support whisper inference in wenet: 1. add a separate WhisperFbankComputer, as kaldifeat does; 2. reuse the current FBank code with extra options.
For 1, since Fbank computation is not a lot of code, I think it makes sense to have a separate WhisperFbankComputer even without reusing the current FBank code. However, for 1, a) I don't like the fact that they hard-coded the weights there, so it's less flexible: https://github.com/csukuangfj/kaldifeat/blob/master/kaldifeat/csrc/whisper-mel-bank.h, and b) it computes the STFT using torch. Since wenet supports runtimes other than torch, we probably don't want to depend on torch in the feature extraction part. However, if we do think 1 is the preferred structure, we can reuse the current wenet FFT so b is not a problem, and we can reuse the code in this PR to generate the filters so a is not a problem either. Based on the above, I think it's more of a style thing; basically we need to decide if we want to reuse FBank or create another WhisperFbank.
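To make the trade-off concrete, a tiny sketch of what option 2 (reusing FBank) could look like; all names here are hypothetical, not the PR's actual API:

```cpp
enum class MelType { HTK, Slaney };
enum class WindowType { Povey, Hanning };

// One Fbank class, with whisper-specific behavior selected through
// configuration rather than a separate WhisperFbankComputer.
struct FbankOptions {
  MelType mel_type = MelType::HTK;               // whisper: MelType::Slaney
  WindowType window_type = WindowType::Povey;    // whisper: periodic Hanning
  bool scaled_float_as_input = false;            // whisper: true
};
```

Either way, the whisper-specific behavior reduces to a handful of switches, which is why this reads mostly as a style decision.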
Very good question, I am curious as well. I checked again and it does look like those special whisper tokens are not added in the CTC loss calculation, so it shouldn't be possible. Maybe something is wrong with my training setup? Currently my only assumption is that our model didn't converge well (and it generalized very badly on out-of-domain data). It looks like we did make a mistake in training by not setting
wenet/wenet/whisper/whisper.py Line 76 in 5d6ea3e
runtime/core/frontend/fbank.h (Outdated)
```diff
@@ -24,13 +24,43 @@
 #include "frontend/fft.h"
 #include "utils/log.h"
 
+#define S16_ABS_MAX (2 << 15)
```
Another concern is that this probably shouldn't be hard-coded here: scaling the input should be the responsibility of the wav reader, not of the feature extraction pipeline. Moving it to the wav reader would also allow the flexibility of audio encoded as pcm_s32 or pcm_s8 instead of fixing it to pcm_s16.
However, doing that would also require changes in the http server, websocket server and grpc server code, and possibly other places like the jni bindings for android. That feels like something the main maintainers of wenet should decide.
(We could do it in a hacky way, e.g. let the cli take in a scale factor and make it a parameter of the feature extraction pipeline, but that doesn't feel right: if the cli takes in a list of wav files encoded with different numbers of bits, it won't work, since different samples require different scaling factors.)
I will leave this as it is for now, but do let me know your thoughts if you want to update this in this PR. Or if you guys feel like making a separate PR fixing this, that also works.
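For what it's worth, a hedged sketch of what the wav-reader-side scaling could look like (function name and error handling are hypothetical, not existing wenet code):

```cpp
#include <stdexcept>

// Pick the float scale from the PCM sample width in the wav header instead
// of hard-coding pcm_s16 in the feature pipeline.
float PcmScaleFactor(int bits_per_sample) {
  switch (bits_per_sample) {
    case 8:  return 128.0f;         // pcm_s8
    case 16: return 32768.0f;       // pcm_s16
    case 32: return 2147483648.0f;  // pcm_s32
    default: throw std::runtime_error("unsupported sample width");
  }
}
// The reader would then emit sample / PcmScaleFactor(bits), so each file's
// scaling follows its own header rather than one global constant.
```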
Yes, I think the current implementation is okay. There is a lot of hard-coded mel weight data for whisper fbank in
Regarding the STFT difference, I think there is no way to make them match exactly. The reason is that when fft_length != window_length, torch STFT pads the window on both the left side and the right side: https://github.com/pytorch/pytorch/blob/2d7a360911fb7b27be82c51ca86b4b34b6f1b087/aten/src/ATen/native/SpectralOps.cpp#L936. Normally FFT energy doesn't depend on how you pad the input, but this is how torch separates audio into frames: https://github.com/pytorch/pytorch/blob/2d7a360911fb7b27be82c51ca86b4b34b6f1b087/aten/src/ATen/native/SpectralOps.cpp#L949. Because of this, padding the window in different places results in different parts of the raw signal being multiplied by different parts of the window, and therefore a different PSD result. But I think it's probably fine; the result will be the same if we shift the sequence by (fft_length - window_length) / 2:

```python
import torch
import torch.nn.functional as F

# wav, window and filters_512 are assumed to be defined as in the whisper
# feature extraction code; (512 - 400) / 2 = 56 samples of padding per side.
padded_wav = F.pad(wav, (56, 56), "constant", 0)  # pad the input so the window will match the audio the same way wenet does
stft = torch.stft(padded_wav,
                  512,
                  160,
                  window=window,
                  center=False,  # this is another trivial source of difference
                  win_length=400,
                  return_complex=True)
magnitudes = stft[..., :-1].abs()**2
mel_spec_512 = filters_512 @ magnitudes
log_spec_before_norm_512 = torch.clamp(mel_spec_512, min=1e-10).log10()
log_spec_before_norm_512 = torch.maximum(log_spec_before_norm_512, log_spec_before_norm_512.max() - 8.0)
log_spec_after_norm_512 = (log_spec_before_norm_512 + 4.0) / 4.0
```

and we will get almost the same result. I think this is a feature, not a bug, as the ASR result should not change even if we shift the input by some number of sampling points.
Thanks to everyone for the review and approval. May wenet keep getting stronger and gain more and more users!
Open source depends on everyone. Thanks for the contribution!
Just for people who are confused about <|notimestamps|>: it's because the token list that I used is wrong. That token id actually corresponds to the blank token in CTC, which won't appear in the final transcript. Related issue: #2329
torch.stft(..., n_fft=512, ...) matches the C++ result, but whisper's default torch.stft(..., n_fft=400, ...) doesn't seem to match the C++ output. Is this a bug?
We were trying to reproduce the steps to load whisper in wenet and finetune it for streaming on librispeech, following the steps in https://github.com/wenet-e2e/wenet/tree/main/examples/aishell/whisper. Finetuning did seem to work, but when we tried to load it in the cpp runtime, it only produced empty results. After further inspection, it looks like the feature extraction of whisper is very different from the one implemented in the wenet runtime, mainly regarding the mel scale (Slaney vs HTK), the window function (periodic Hann vs Povey), the input scaling (floats in [-1, 1] vs unscaled int16 values), and the log/normalization of the mel spectrogram.
- examples show it's working
- decoder main using whisper
- cli is still compatible with the original code, default behavior doesn't change
There are still a couple of places that are not perfect.
I know these are a lot of changes, so I am more than happy to change the structure into the style you guys prefer if you think there is a better way.
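For readers comparing the two pipelines, here is a hedged C++ sketch of the whisper-style log-mel postprocessing that the thread's Python snippet performs (clamp, log10, cap at max - 8, then (x + 4) / 4); names are illustrative, not the PR's actual code:

```cpp
#include <algorithm>
#include <cmath>
#include <limits>
#include <vector>

// Clamp to 1e-10, take log10, cap the dynamic range at (max - 8),
// then map through (x + 4) / 4, mirroring the Python snippet above.
void WhisperLogMelNorm(std::vector<float>* mel) {
  float max_val = -std::numeric_limits<float>::infinity();
  for (float& v : *mel) {
    v = std::log10(std::max(v, 1e-10f));
    max_val = std::max(max_val, v);
  }
  for (float& v : *mel) {
    v = (std::max(v, max_val - 8.0f) + 4.0f) / 4.0f;
  }
}
```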