TensorFlow之MNIST入门
· 阅读需 6 分钟
MNIST手写数字识别是机器学习中非常经典的问题,相当于编程语言界的“Hello World“。关于神经网络解决MNIST手写数字识别问题,可以参考这个视频:深度学习之神经网络的结构 Part 1 ver 2.0
视频中使用的是多层神经网络,为了简化问题,这里我们使用单层的网络结构。
参考之前的MNIST数据集解析,先对MNIST数据集进行解析:
import gzip
import struct
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
def load_images(image_gz):
with gzip.open(image_gz) as f:
buf = f.read()
num = int(struct.unpack_from('>i', buf, 4)[0])
return (np.array(struct.unpack_from('B'*num*28*28, buf, 16)
).reshape(num, 784)/255).astype(np.float32)
def load_labels(label_gz):
with gzip.open(label_gz) as f:
buf = f.read()
num = int(struct.unpack_from('>i', buf, 4)[0])
idx = 8
tmp = []
for i in range(num):
label = int(struct.unpack_from('B', buf, idx)[0])
idx += 1
# one-hot encoding
ohl = np.zeros(10, dtype=np.float32)
ohl[label] = 1.0
tmp.append(ohl)
return np.array(tmp)
train_images = load_images('train-images-idx3-ubyte.gz')
train_labels = load_labels('train-labels-idx1-ubyte.gz')
test_images = load_images('t10k-images-idx3-ubyte.gz')
test_labels = load_labels('t10k-labels-idx1-ubyte.gz')
在读取图片时,一次性读取二进制数据,这样可以大大提升效率。之后,为了使用的方便,将它变形为num*784大小,由于图片都是28*28大小,所以单张图片的像素数就是784。另外,还将像素值进行了归一化,因为,如果输入层的值很大,在反向传播时传递到输入层的梯度就会很大,如果梯度非常大,学习率就必须非常小,否则就会跳过局部最小(直接表现就是代价函数的值为nan)。因此,如果用梯度下降来训练模型一般都要在数据预处理步骤进行数据归一化。
对于离散的特征一般按照one-hot编码,该离散特征有多少取值,就用多少维度来表示该特征。在回归、分类、聚类等机器学习算法中,特征之间距离的计算或相似度的计算是非常重要的,使用one-hot编码,特征之间的距离更为合理。
基于树的方法不需要特征归一化,基于参数或距离的模型要进行特征归一化。
X = tf.placeholder(tf.float32, (None, 784))
Y = tf.placeholder(tf.float32, (None, 10))
W = tf.Variable(tf.truncated_normal((784, 10), stddev=0.01))
b = tf.Variable(tf.zeros((10,)))
y = tf.nn.softmax(tf.matmul(X, W) + b)
cost = -tf.reduce_sum(Y*tf.log(tf.clip_by_value(y, 1e-10, 1.0)))
Softmax函数一般用于多分类问题,可以对预测的标签进行归一化。计算公式为:
下面举例说明计算过程:
a = tf.constant(np.array([
[6., 1., 0.],
[0., 4., 2.]
]))
b = tf.nn.softmax(a)
with tf.Session() as sess:
print(sess.run(b))
# Output:
[[ 0.99086747 0.00667641 0.00245611]
[ 0.01587624 0.86681333 0.11731043]]
使用Linux自带的计算器bc进行手动计算的过程(其中e(x)
表示exp(x)
):
$ bc -lq
e(6)/(e(6)+e(1)+e(0))
.99086747258217259526
e(1)/(e(6)+e(1)+e(0))
.00667641251337645118
e(0)/(e(6)+e(1)+e(0))
.00245611490445095354
e(0)/(e(0)+e(4)+e(2))
.01587623997646676632
e(4)/(e(0)+e(4)+e(2))
.86681333219733487114
e(2)/(e(0)+e(4)+e(2))
.11731042782619836253