emmm……这周最后一天了

0x00.前言

仍然是腾讯云开发者实验室(beta),今天换个实验做:

0x01.引用

1.0 TensorFlow 实现基于 CNN 数字识别的代码

1.1 前期准备

TensorFlow 相关 API 可以在实验「TensorFlow - 相关 API」中学习。
唔,这就尴尬了,这节课我还没看呢……

1.2 CNN模型构建

现在您可以在/home/ubuntu目录下创建源文件mnist_model.py,内容可参考:

1
#!/usr/bin/python
2
# -*- coding: utf-8 -*
3
4
from __future__ import absolute_import
5
from __future__ import division
6
from __future__ import print_function
7
8
import argparse
9
import sys
10
import tempfile
11
12
from tensorflow.examples.tutorials.mnist import input_data
13
14
import tensorflow as tf
15
16
FLAGS = None
17
18
19
def deepnn(x):
  """Build the CNN graph used to classify MNIST digits.

  Layer order: reshape -> conv1 (5x5, 32 feature maps) -> 2x2 max-pool ->
  conv2 (5x5, 64 feature maps) -> 2x2 max-pool -> fc1 (1024 units) ->
  dropout -> fc2 (10 outputs).

  Args:
    x: input tensor. It is reshaped to [-1, 28, 28, 1], so every example is
      expected to carry 28 * 28 = 784 values.

  Returns:
    A tuple (y_conv, keep_prob). y_conv is the output of the last fully
    connected layer (10 values per example, no softmax applied). keep_prob
    is the placeholder that must be fed with the dropout keep probability.
  """
  # Turn each flat input row into a 28x28 image with a single channel.
  with tf.name_scope('reshape'):
    images = tf.reshape(x, [-1, 28, 28, 1])

  # First convolutional layer: 5x5 kernels, 1 input channel -> 32 feature maps.
  with tf.name_scope('conv1'):
    conv1_weights = weight_variable([5, 5, 1, 32])
    conv1_biases = bias_variable([32])
    conv1_out = tf.nn.relu(conv2d(images, conv1_weights) + conv1_biases)

  # First pooling layer: down-sample by 2.
  with tf.name_scope('pool1'):
    pool1_out = max_pool_2x2(conv1_out)

  # Second convolutional layer: 5x5 kernels, 32 -> 64 feature maps.
  with tf.name_scope('conv2'):
    conv2_weights = weight_variable([5, 5, 32, 64])
    conv2_biases = bias_variable([64])
    conv2_out = tf.nn.relu(conv2d(pool1_out, conv2_weights) + conv2_biases)

  # Second pooling layer: down-sample by 2 again.
  with tf.name_scope('pool2'):
    pool2_out = max_pool_2x2(conv2_out)

  # First fully connected layer: the 7x7x64 pooled maps -> 1024 features.
  with tf.name_scope('fc1'):
    flat_size = 7 * 7 * 64
    fc1_weights = weight_variable([flat_size, 1024])
    fc1_biases = bias_variable([1024])

    pool2_flat = tf.reshape(pool2_out, [-1, flat_size])
    fc1_out = tf.nn.relu(tf.matmul(pool2_flat, fc1_weights) + fc1_biases)

  # Dropout: the caller feeds keep_prob (e.g. below 1.0 while training and
  # 1.0 while evaluating).
  with tf.name_scope('dropout'):
    keep_prob = tf.placeholder(tf.float32)
    fc1_dropped = tf.nn.dropout(fc1_out, keep_prob)

  # Second fully connected layer: 1024 features -> 10 outputs.
  with tf.name_scope('fc2'):
    fc2_weights = weight_variable([1024, 10])
    fc2_biases = bias_variable([10])

    y_conv = tf.matmul(fc1_dropped, fc2_weights) + fc2_biases
  return y_conv, keep_prob
64
65
#卷积
66
def conv2d(x, W):
  """Return the 2-D convolution of x with filter W (stride 1, 'SAME' padding)."""
  unit_strides = [1, 1, 1, 1]
  return tf.nn.conv2d(x, W, strides=unit_strides, padding='SAME')
68
69
#池化
70
def max_pool_2x2(x):
  """Max-pool x over 2x2 windows with stride 2 and 'SAME' padding."""
  window_size = [1, 2, 2, 1]
  window_strides = [1, 2, 2, 1]
  return tf.nn.max_pool(
      x, ksize=window_size, strides=window_strides, padding='SAME')
73
#权重
74
def weight_variable(shape):
  """Create a weight Variable of the given shape.

  Initial values are drawn from a truncated normal with stddev 0.1.
  """
  return tf.Variable(tf.truncated_normal(shape, stddev=0.1))
77
78
#偏置
79
def bias_variable(shape):
  """Create a bias Variable of the given shape, every entry initialised to 0.1."""
  return tf.Variable(tf.constant(0.1, shape=shape))

1.3 训练 CNN 模型

现在您可以在/home/ubuntu目录下创建源文件train_mnist_model.py,内容可参考:

1
#!/usr/bin/python
2
# -*- coding: utf-8 -*
3
4
from __future__ import absolute_import
5
from __future__ import division
6
from __future__ import print_function
7
8
import argparse
9
import sys
10
import tempfile
11
12
from tensorflow.examples.tutorials.mnist import input_data
13
14
import tensorflow as tf
15
16
import mnist_model
17
18
FLAGS = None
19
20
21
def main(_):
  """Train the MNIST CNN, print its test accuracy and save a checkpoint.

  The dataset is read from FLAGS.data_dir. The graph definition is written
  to a fresh temporary directory so it can be inspected with TensorBoard,
  and the trained variables are saved to "mnist_cnn_model.ckpt" in the
  current working directory.

  Args:
    _: unused; present because tf.app.run passes argv to main.
  """
  mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True)

  # Input: flattened 28x28 MNIST images.
  x = tf.placeholder(tf.float32, [None, 784])

  # Target: one-hot labels over the ten digit classes 0-9.
  y_ = tf.placeholder(tf.float32, [None, 10])

  # Network: input -> conv1 -> pool1 -> conv2 -> pool2 -> fc1 -> fc2.
  y_conv, keep_prob = mnist_model.deepnn(x)

  # Softmax over the network output followed by cross-entropy against the
  # labels; this yields one loss value per example.
  with tf.name_scope('loss'):
    cross_entropy = tf.nn.softmax_cross_entropy_with_logits(labels=y_,
                                                            logits=y_conv)

  # Average the per-example losses into the scalar that is minimised.
  cross_entropy = tf.reduce_mean(cross_entropy)

  # Adam optimizer with learning rate 1e-4.
  with tf.name_scope('adam_optimizer'):
    train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)

  # Fraction of examples in the fed batch whose arg-max prediction matches
  # the label (used for both training batches and test batches below).
  with tf.name_scope('accuracy'):
    correct_prediction = tf.equal(tf.argmax(y_conv, 1), tf.argmax(y_, 1))
    correct_prediction = tf.cast(correct_prediction, tf.float32)
  accuracy = tf.reduce_mean(correct_prediction)

  # Dump the graph so the network structure can be viewed in a browser.
  graph_location = tempfile.mkdtemp()
  print('Saving graph to: %s' % graph_location)
  train_writer = tf.summary.FileWriter(graph_location)
  train_writer.add_graph(tf.get_default_graph())
  # Nothing else is written through this writer, so close it now to make
  # sure the graph event is flushed to disk and the file handle is released.
  train_writer.close()

  # Saver used to persist the trained variables.
  saver = tf.train.Saver()
  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for i in range(5000):
      batch = mnist.train.next_batch(50)
      if i % 100 == 0:
        # Report accuracy on the current batch with dropout disabled.
        train_accuracy = accuracy.eval(feed_dict={
            x: batch[0], y_: batch[1], keep_prob: 1.0})
        print('step %d, training accuracy %g' % (i, train_accuracy))
      train_step.run(feed_dict={x: batch[0], y_: batch[1], keep_prob: 0.5})

    # Average the accuracy over 200 test batches of 50 examples each.
    test_accuracy = 0
    for i in range(200):
      batch = mnist.test.next_batch(50)
      test_accuracy += accuracy.eval(
          feed_dict={x: batch[0], y_: batch[1], keep_prob: 1.0}) / 200

    print('test accuracy %g' % test_accuracy)

    saver.save(sess, "mnist_cnn_model.ckpt")
78
79
if __name__ == '__main__':
  # Parse only the flags this script defines; anything unrecognised is
  # forwarded untouched to tf.app.run together with the program name.
  arg_parser = argparse.ArgumentParser()
  arg_parser.add_argument(
      '--data_dir',
      type=str,
      default='/tmp/tensorflow/mnist/input_data',
      help='Directory for storing input data')
  FLAGS, remaining_args = arg_parser.parse_known_args()
  forwarded_argv = [sys.argv[0]] + remaining_args
  tf.app.run(main=main, argv=forwarded_argv)

然后执行:
cd /home/ubuntu
python train_mnist_model.py
P.S. 实验原文的代码里有一处行尾多余的分号(;),对 Python 没有影响,请无视……
训练的时间会较长,可以喝杯茶耐心等待。
喝茶,哈哈哈……看了下速度实在是太慢了,直接下一步吧……
执行结果:

step 3600, training accuracy 0.98
step 3700, training accuracy 0.98
step 3800, training accuracy 0.96
step 3900, training accuracy 1
step 4000, training accuracy 0.98
step 4100, training accuracy 0.96
step 4200, training accuracy 1
step 4300, training accuracy 1
step 4400, training accuracy 0.98
step 4500, training accuracy 0.98
step 4600, training accuracy 0.98
step 4700, training accuracy 1
step 4800, training accuracy 0.98
step 4900, training accuracy 1
test accuracy 0.9862

1.4 测试 CNN 模型

下载测试图片
下载test_num.zip
cd /home/ubuntu
wget http://tensorflow-1253902462.cosgz.myqcloud.com/test_num.zip
解压测试图片包
解压test_num.zip,其中1-9.png为1-9数字图片。
unzip test_num.zip
实现predict代码
现在您可以在/home/ubuntu目录下创建源文件predict_mnist_model.py,内容可参考:

1
#!/usr/bin/python
2
# -*- coding: utf-8 -*
3
4
from __future__ import absolute_import
5
from __future__ import division
6
from __future__ import print_function
7
8
import argparse
9
import sys
10
import tempfile
11
12
from tensorflow.examples.tutorials.mnist import input_data
13
14
import tensorflow as tf
15
16
import mnist_model
17
from PIL import Image, ImageFilter
18
19
def load_data(argv):
    """Load an image file and convert it into a flat MNIST-style input vector.

    The image is converted to grayscale, resized so that its longer side is
    20 pixels (the shorter side is scaled proportionally, but never below
    1 pixel), sharpened, and pasted centred onto a white 28x28 canvas.
    Pixel values are then inverted and scaled so that white (255) becomes
    0.0 and black (0) becomes 1.0.

    Args:
        argv: path of the image file; passed straight to ``Image.open``.

    Returns:
        A list of 28 * 28 = 784 floats in the range [0.0, 1.0], in the order
        produced by ``getdata()`` on the canvas.
    """
    grayimage = Image.open(argv).convert('L')
    width = float(grayimage.size[0])
    height = float(grayimage.size[1])
    # White 28x28 canvas that the scaled digit is pasted onto.
    newImage = Image.new('L', (28, 28), (255))

    # Image.LANCZOS is the same filter as the old Image.ANTIALIAS alias,
    # which newer Pillow releases no longer provide.
    if width > height:
        # Wider than tall: fix the width at 20 and scale the height.
        nheight = int(round((20.0 / width * height), 0))
        if nheight == 0:
            nheight = 1  # never resize to a zero-pixel height
        img = grayimage.resize((20, nheight), Image.LANCZOS).filter(ImageFilter.SHARPEN)
        # Centre vertically; the 20-pixel width leaves a 4-pixel margin.
        wtop = int(round(((28 - nheight) / 2), 0))
        newImage.paste(img, (4, wtop))
    else:
        # Taller than wide (or square): fix the height at 20, scale the width.
        nwidth = int(round((20.0 / height * width), 0))
        if nwidth == 0:
            nwidth = 1  # never resize to a zero-pixel width
        img = grayimage.resize((nwidth, 20), Image.LANCZOS).filter(ImageFilter.SHARPEN)
        # Centre horizontally; the 20-pixel height leaves a 4-pixel margin.
        wleft = int(round(((28 - nwidth) / 2), 0))
        newImage.paste(img, (wleft, 4))

    # Invert and normalise: 255 (white) -> 0.0, 0 (black) -> 1.0.
    tv = list(newImage.getdata())
    tva = [(255 - x) * 1.0 / 255.0 for x in tv]
    return tva
44
45
def main(argv):
    """Predict the digit in an image file with the saved CNN and print it.

    The network is rebuilt with ``mnist_model.deepnn`` and its variables are
    restored from the checkpoint "mnist_cnn_model.ckpt" in the current
    working directory.

    Args:
        argv: path of the image file to classify; forwarded to load_data.
    """
    imvalue = load_data(argv)

    # Rebuild the same graph that was used for training.
    x = tf.placeholder(tf.float32, [None, 784])
    y_conv, keep_prob = mnist_model.deepnn(x)

    # Index of the largest softmax output per example.
    y_predict = tf.nn.softmax(y_conv)
    prediction = tf.argmax(y_predict, 1)

    saver = tf.train.Saver()
    with tf.Session() as sess:
        # Restoring assigns every variable, so no separate initialisation
        # step is needed beforehand.
        saver.restore(sess, "mnist_cnn_model.ckpt")
        # Single-image batch, dropout disabled for inference.
        predint = prediction.eval(feed_dict={x: [imvalue], keep_prob: 1.0}, session=sess)
        print(predint[0])
62
63
if __name__ == "__main__":
    # The image path is mandatory; fail with a usage hint instead of an
    # IndexError traceback when it is missing.
    if len(sys.argv) < 2:
        sys.exit('Usage: python predict_mnist_model.py <image_file>')
    main(sys.argv[1])

然后执行:
cd /home/ubuntu
python predict_mnist_model.py 1.png
执行结果:
1

你可以将1.png修改为1-9.png中任意一个
既然都这么说了,那我就全部试一下……

emmm……我要看下原图

0x02.后记

文件浏览器的刷新按钮好像坏掉了,实验做到最后文件也没改变……手动点下上层文件夹就好了
我感觉我又水了一篇文章……
未完待续……