来了来了我们可亲可爱的朋友MNIST~
作为一个超级初级的手写数字数据集,含有60k个全损低清图像作为训练集,10k个作为验证集。每个图像是(28,28,1)
范围在[0,255]的整数数字灰阶图像,数字越大代表笔迹越重,即黑底白字。
介于MNIST的特性,想来试试FC和CNN之间的区别,当然训练速度仅供参考,模型不同怎么恋爱?我这个小1060也带不动两位巨佬神仙打架。
整理数据
环境
Python 3.7.2
Tensorflow 1.13 rc2 + Keras
CUDA 10.0 / CuDNN 7.4
为了引入gpu接口,我们需要在一开始为keras设置好设备:
1 | import tensorflow as tf |
导入MNIST数据集
导入方式有很多,这里简单阐述两种。
1. 官网
链接:http://yann.lecun.com/exdb/mnist/
界面中四个超显眼的文件就是我们需要下载的文件,点击下载。
这四个文件是纯二进制构成,所以还得编写一些preprocessing functions去读取内容。
下面这个函数的作用是读取path
下的文件并返回指定数据集(训练集、测试集)。
1 | def load_mnist(path, kind='train'): |
读入(28, 28, 1)
的张量,留了一个位置为keras的后续操作。
其实相当于对(28, 28)
的张量做一个dim_expand(axis=0)
。
2. Keras.dataset
详见官方文档:
1 | from keras.datasets import mnist |
输出的是(28, 28)
的张量,故需要dim_expand
回来。
调整输入输出
输入的数值为[0,255]整数,为简化计算与充分利用性能,消除偏移,可以将输入的值转为32位浮点并归一化,使范围控制在[-1,1]以内。
1 | train_X = (train_X.astype('float32') - 127) / 127 |
为提高分类问题效率,输出上往往使用__One-Hot__编码。我们将代表Lable的数字转写为长度10的列表:
1 | nb_classes = 10 |
构建模型并训练
模型依旧使用Sequence去构建。本文分别阐述FC与CNN。
FC
线性模型可谓最基础的深度学习模型,即数个全连接层。
1 | model = Sequential() |
这里输入的是一行完整的(28*28*1)
,故需要在前面加载数据集时设置好。
总权重数:1,333,770
下面开始训练~
1 | model.compile(optimizer='adam',# 优化器:ADAM |
keras中本身自带了一个将训练集与测试集分离的选项validation_split
,我们可以直接选用。
CNN
卷积大法好。
1 | model.add(Conv2D(28,(5,5),activation='relu',padding="same",input_shape=(28,28,1))) # 用5*5的卷积核扫一扫 |
总权重数:1,266,288
训练:
1 | model.compile(optimizer='adam',loss='categorical_crossentropy',metrics=['accuracy']) |
测试结果
FC最后一个epoch
用时: 5s
loss: 0.0603
acc: 0.9815
CNN最后一个epoch
用时: 12s
loss: 0.0337
acc: 0.9881
都是百万级的权重值,让我们来看看它们在实际中效果如何。
实际检测
在前面我们要将生成的模型保存起来以便在其他地方使用,使用函数model.save(path)
,保存的后缀设为h5,如”./model.h5“。
制作图片
这一步可以选用GAN生成,让两个模型分别作为对方的判别器……又是另外一个话题了。
打开我们的PS,新建图像28*28,背景为黑色,用白色画笔歪歪扭扭写一个数字,保存为jpg或其他图片格式。
看起来很模糊,但机器不那么认为。
导入模型
我们新建一个py文件,在其中导入模型:
1 | from keras.models import load_model |
导入图片
我们对导入的图片做与输入同样的处理:灰阶、张量形状、浮点、归一化:
1 | import cv2 |
在前后各加一个维度以适配Keras:
1 | img = np.expand_dims(img, axis=0) |
预测结果
接下来就可以用输入通过模型预测输出:
1 | result = model.predict(img) |
好了让我们来看看result长什么样:
1 | print(result) |
这是……啥?一共十个数数字组成的列表的列表!细心的水友可能发现了,第4个数字有些特别:1.0000000e+00,赫然是一个大1,这也意味着结果是3的可能性最大,与我们输入的图片相符!
对输出稍作处理,我们就能得到一个完整的手写识别处理模块:
1 | predicts = 0 |