Skip to content

ybai001/AndroidMnistWithTFLite

 
 

Repository files navigation

Android MNIST With TFLite

Python部分

训练模型来源于TensorFlow的basic_classification示例,使用TensorFlow Keras API。

为了能够更好的在Android手机上呈现并供用户测试,训练模型里使用MNIST,而非basic classification示例里的Fashion MNIST。

本项目python源码位于根目录python_code路径下。

python keras_mnist_train.py

注意: 考虑到网络问题,请自行下载MNIST数据,并配置好路径 训练时会先将图像数据数值范围从0-255转为0-1,预测时需要对待测数据做同样处理。

# you can download mnist from http://yann.lecun.com/exdb/mnist/
train_images = read_local_mnist.load_train_images('input_data/train-images.idx3-ubyte')
train_labels = read_local_mnist.load_train_labels('input_data/train-labels.idx1-ubyte')
test_images = read_local_mnist.load_test_images('input_data/t10k-images.idx3-ubyte')
test_labels = read_local_mnist.load_test_labels('input_data/t10k-labels.idx1-ubyte')

训练得到keras_mnist_model.h5训练结果,验证h5是否有效

python eveluate.py keras_mnist_model.h5

将h5结果转化为tflite

python convert.py keras_mnist_model.h5

注意: 由于TensorFlow版本的持续更新,运行时可能会报TFLiteConverter Not Found等问题,建议使用TensorFlow Nightly,或者在Google Colab上进行。

Android部分

UI逻辑来源于MindOrksAndroidTensorFlowMNISTExample

核心代码就是以下一小段:

Interpreter mInterpreter = new Interpreter(loadModelFile(mContext));
float[][] labelProbArray = new float[1][10];
//Get input pixels from DrawView.
mInterpreter.run(userInputPixels, labelProbArray);
return getMax(labelProbArray[0]);

最终呈现结果如下:

ui_interface

About

No description, website, or topics provided.

Resources

Stars

Watchers

Forks

Releases

No releases published

Packages

No packages published

Languages

  • Java 63.2%
  • Python 36.8%