Hello,

Thank you for developing BERTax! It looks like a really great tool for taxonomic classification of sequences that are typically difficult to classify with tools that rely on big databases.
I was interested to see whether BERTax could be used to classify metagenomic sequencing reads, but it seems it would be quite a bit slower than k-mer-based methods (Centrifuge, Kraken2), even with GPU acceleration: 6 reads/s on 16 CPU threads (Intel(R) Core(TM) i9-9980HK CPU @ 2.40GHz) and 20 reads/s on an Nvidia Quadro RTX 5000 (driver version 470.63.01, CUDA version 11.4).
Are there any plans to optimize BERTax for performing predictions on larger inputs?
I tried to modify the BERTax code to be a little more efficient on large inputs (reads in FASTQ) in PR peterk87#1, but I'm not familiar with Keras or TensorFlow, so I'm not sure how one would go about optimizing that code. The call to `model.predict` seems to be taking the most time by far.
For example, for a read of length 6092 split into 5 chunks:

- `seq2tokens`: 0.792363 ms
- `process_bert_tokens_batch`: 1.096281 ms
- `model.predict`: 67.773608 ms
- writing output: 1.32 ms

Total elapsed time: 70.986515 ms (timings obtained with `time.time_ns`). Although there may be optimizations possible for input processing and output formatting, most of the time (>95%) is spent in `model.predict`.
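For anyone wanting to reproduce this kind of breakdown, a minimal `time.time_ns`-based helper could look like the sketch below (the commented-out calls are simplified placeholders for the BERTax functions named above, not their exact signatures):

```python
import time

def timed(label, fn, *args, **kwargs):
    """Call fn(*args, **kwargs), print the elapsed wall-clock time in ms, and return the result."""
    start = time.time_ns()
    result = fn(*args, **kwargs)
    print(f"{label}: {(time.time_ns() - start) / 1e6:.6f} ms")
    return result

# Hypothetical usage (arguments simplified):
# tokens = timed("seq2tokens", seq2tokens, seq)
# inputs = timed("process_bert_tokens_batch", process_bert_tokens_batch, tokens)
# preds  = timed("model.predict", model.predict, inputs)
```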
I noticed that in the `bertax-visualize` script, the Keras model is converted into a PyTorch model:

https://github.com/f-kretschmer/bertax/blob/ae8cc568a2e66692e7663025906fda0016aa8b52/bertax/visualize.py#L29
I haven't tested whether using PyTorch and a converted model would help speed up predictions. Maybe the Keras model could instead be converted to a TensorFlow Lite model for less overhead per call to `model.predict`, as described in the following blog post:

https://micwurm.medium.com/using-tensorflow-lite-to-speed-up-predictions-a3954886eb98
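I haven't tried this, but based on that post the conversion might look roughly like the sketch below (assuming the BERTax model loads as a `tf.keras` model; whether the keras-bert custom layers survive conversion is exactly what would need testing, and the order of `input_details` may not match the Keras input order):

```python
import numpy as np
import tensorflow as tf

def to_tflite_predictor(keras_model):
    """Convert a Keras model to TF Lite and wrap the interpreter in a predict-like callable."""
    converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)
    tflite_model = converter.convert()

    interpreter = tf.lite.Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()

    def predict(*inputs):
        # Feed one array per model input (for a BERT-style model,
        # presumably token IDs and segment IDs).
        for detail, arr in zip(input_details, inputs):
            interpreter.set_tensor(detail["index"],
                                   np.asarray(arr, dtype=detail["dtype"]))
        interpreter.invoke()
        return [interpreter.get_tensor(d["index"]) for d in output_details]

    return predict
```

One caveat: the TF Lite interpreter works with fixed input shapes, so batching several chunks per call would probably require `interpreter.resize_tensor_input` (or converting with a fixed batch size).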
Unfortunately, I'm only familiar with NumPy, not with Keras, TensorFlow, or PyTorch. I have some experience using Cython and Numba to accelerate Python code, but those may not be appropriate in this case.
Any speed-ups (or ideas for how to achieve them) would be greatly appreciated and would allow BERTax to be used on a wider range of datasets!
Thanks!
Peter
Many thanks for your tests and suggestions! I haven't looked into runtime optimization that much so far, so I think there are definitely some improvements that can be made. I didn't know about TensorFlow Lite; that seems like a promising starting point, although I'm not sure how well custom models (keras-bert) can be converted.
Thanks again, I'll look into it!
Fleming