TensorFlow [TensorFlow: Google Deep Learning Framework in Practice] Fully Connected Neural Networks and Visualization

1 Visualizing binary classification with a neural network

ConvNetJS (link)


layer_defs = [];
layer_defs.push({type:'input', out_sx:1, out_sy:1, out_depth:2});
layer_defs.push({type:'fc', num_neurons:6, activation: 'tanh'});
layer_defs.push({type:'fc', num_neurons:2, activation: 'tanh'});

layer_defs.push({type:'softmax', num_classes:2});

net = new convnetjs.Net();
net.makeLayers(layer_defs);

trainer = new convnetjs.SGDTrainer(net, {learning_rate:0.01, momentum:0.1, batch_size:10, l2_decay:0.001});

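Editing the layer definitions above changes the number of layers and the number of neurons per layer, so you can observe how the classification result changes.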
Different sample sets can also be selected in the demo.

2 Fully connected neural networks

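In a fully connected neural network, every node in one layer is connected by an edge to every node in the adjacent layer.
A convolutional neural network is also organized layer by layer, but only some of the nodes in two adjacent layers are connected. In the first few layers of a convolutional neural network, the nodes of each layer are arranged as a three-dimensional matrix, and each node is connected to only some of the nodes in the previous layer.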
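To make "every node is connected to every node" concrete, here is a minimal sketch of the forward pass of one fully connected layer in plain Python. The function name `dense_forward` and the toy numbers are illustrative assumptions, not part of the original article. A layer with n inputs and m outputs needs n × m weights because each output sums over all inputs.

```python
def dense_forward(inputs, weights, biases):
    """Forward pass of one fully connected layer (no activation).

    inputs:  list of n input values
    weights: n x m nested list; weights[i][j] connects input i to output j
    biases:  list of m bias values
    """
    outputs = []
    for j in range(len(biases)):
        # Every output node j sums over ALL input nodes i.
        total = biases[j]
        for i, x in enumerate(inputs):
            total += x * weights[i][j]
        outputs.append(total)
    return outputs


# Toy layer: 3 inputs fully connected to 2 outputs, so 3 * 2 = 6 weights.
toy_inputs = [1.0, 2.0, 3.0]
toy_weights = [[1.0, 0.0],
               [0.0, 1.0],
               [1.0, 1.0]]
toy_biases = [0.5, -0.5]
toy_outputs = dense_forward(toy_inputs, toy_weights, toy_biases)
print(toy_outputs)  # [4.5, 4.5]
```

In a convolutional layer, by contrast, the inner loop would run over only a small neighborhood of the inputs rather than all of them.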

3 Building a fully connected neural network with TensorFlow

The goal is to recognize the MNIST handwritten digits.

3.1 Loading the MNIST data

mnist = input_data.read_data_sets('C:/MNIST_data/', one_hot=True)

3.2 Creating placeholders

xs = tf.placeholder(tf.float32, [None, 784])
ys = tf.placeholder(tf.float32, [None, 10])

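xs holds the image pixel data. Each 28×28 image is flattened into a 1×784 vector; the number of images is not known in advance, so the shape is None×784.
ys holds the image labels. The ten digits 0-9 are one-hot encoded: only the position that corresponds to the digit is 1, and all other positions are 0.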
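The two encodings can be sketched in plain Python. The helper names `flatten_image` and `one_hot` and the all-zero placeholder image are assumptions made for illustration; the MNIST loader used in this article does both steps internally.

```python
def flatten_image(image):
    """Flatten a 2-D image (list of rows) into one row-major list."""
    return [pixel for row in image for pixel in row]


def one_hot(label, num_classes=10):
    """Return a list with 1.0 at index `label` and 0.0 elsewhere."""
    vector = [0.0] * num_classes
    vector[label] = 1.0
    return vector


blank_image = [[0.0] * 28 for _ in range(28)]  # placeholder 28x28 image
flat = flatten_image(blank_image)
print(len(flat))    # 784
print(one_hot(3))   # [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
```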

3.3 Building the model

weights = tf.Variable(tf.random_normal([784, 10]))  # weights
biases = tf.Variable(tf.zeros([1, 10]) + 0.1)  # biases
outputs = tf.matmul(xs, weights) + biases  # matrix product of xs and weights, plus biases
predictions = tf.nn.softmax(outputs)
cross_entropy = tf.reduce_mean(-tf.reduce_sum(ys * tf.log(predictions),
                                              reduction_indices=[1]))
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)  # optimizer

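tf.reduce_mean computes the mean of a tensor along the specified axis (one dimension of the tensor). It is mainly used to reduce dimensionality or to compute the mean of a tensor (for example, an image). Here it averages the per-example cross-entropy over the batch.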
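The loss expression, a sum over the class axis followed by a mean over the batch, can be reproduced by hand. The following is a plain-Python sketch under the assumption that every predicted probability is strictly positive; the function name `cross_entropy` and the toy batch are illustrative only.

```python
import math


def cross_entropy(labels, predictions):
    """Mean cross-entropy over a batch.

    labels:      batch of one-hot rows
    predictions: batch of probability rows (same shape as labels)
    """
    per_example = []
    for y_row, p_row in zip(labels, predictions):
        # Sum over the class axis (axis 1): -sum(y * log(p)).
        per_example.append(-sum(y * math.log(p) for y, p in zip(y_row, p_row)))
    # Average over the batch axis, collapsing the batch to one scalar.
    return sum(per_example) / len(per_example)


toy_labels = [[1.0, 0.0], [0.0, 1.0]]
toy_predictions = [[0.5, 0.5], [0.25, 0.75]]
loss = cross_entropy(toy_labels, toy_predictions)
print(loss)  # (-log(0.5) - log(0.75)) / 2
```

Because the labels are one-hot, only the log-probability of the correct class contributes to each row's sum.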
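The minimize call can be split into two steps:
① compute the gradients
② apply the computed gradients to update the variables
The benefit of splitting it is that the gradients can be constrained (for example, clipped) between the two steps, which is a common guard against exploding gradients.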
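The two-step pattern can be sketched without TensorFlow. In TensorFlow 1.x the optimizer exposes the two steps as `compute_gradients` and `apply_gradients`. The plain-Python version below uses a toy loss f(w) = w**4 and hypothetical helper names (`compute_gradient`, `clip`, `sgd_step`) purely for illustration.

```python
def compute_gradient(w):
    """Gradient of the toy loss f(w) = w**4 with respect to w."""
    return 4.0 * w ** 3


def clip(value, low, high):
    """Limit value to the closed interval [low, high]."""
    return max(low, min(high, value))


def sgd_step(w, learning_rate, clip_limit):
    grad = compute_gradient(w)                   # step 1: compute the gradient
    grad = clip(grad, -clip_limit, clip_limit)   # in between: constrain it
    return w - learning_rate * grad              # step 2: apply the update


w = 10.0  # the raw gradient here is 4000.0, which would overshoot badly
w_next = sgd_step(w, learning_rate=0.5, clip_limit=1.0)
print(w_next)  # 9.5
```

Without the clipping line, the same step would move w from 10.0 to -1990.0, which shows why constraining the gradient between the two steps can matter.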

The softmax function is used as the activation function of the output layer.
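For reference, softmax maps a vector of scores z to probabilities via softmax(z)_i = exp(z_i) / Σ_j exp(z_j). A minimal plain-Python sketch follows. Subtracting the maximum score before exponentiating is a standard numerical-stability step added here, not something the original code does explicitly.

```python
import math


def softmax(logits):
    """Map a list of real scores to probabilities that sum to 1."""
    # Subtracting the maximum leaves the result unchanged but avoids overflow in exp().
    largest = max(logits)
    exps = [math.exp(z - largest) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


probs = softmax([1.0, 2.0, 3.0])
print(probs)       # larger scores receive larger probabilities
print(sum(probs))  # approximately 1.0
```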

3.4 Accuracy

correct_predictions = tf.equal(tf.argmax(predictions, 1), tf.argmax(ys, 1))
accuracy = tf.reduce_mean(tf.cast(correct_predictions, tf.float32))
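
The two lines above compare the index of the largest predicted probability with the index of the 1 in the one-hot label, then average the matches. A plain-Python sketch of the same computation follows; the helper names and the toy batch are assumptions for illustration.

```python
def argmax(row):
    """Index of the largest value in a list."""
    return max(range(len(row)), key=lambda i: row[i])


def accuracy(predictions, labels):
    """Fraction of rows whose predicted class equals the labeled class."""
    matches = [argmax(p) == argmax(y) for p, y in zip(predictions, labels)]
    # Casting True/False to 1.0/0.0 and averaging gives the fraction correct.
    return sum(float(m) for m in matches) / len(matches)


toy_predictions = [[0.1, 0.9], [0.8, 0.2], [0.3, 0.7], [0.6, 0.4]]
toy_labels = [[0.0, 1.0], [1.0, 0.0], [1.0, 0.0], [1.0, 0.0]]
print(accuracy(toy_predictions, toy_labels))  # 0.75
```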

3.5 Training the model

with tf.Session() as sess:
    init = tf.global_variables_initializer()
    sess.run(init)
    for i in range(1000):
        batch_xs, batch_ys = mnist.train.next_batch(100)
        sess.run(train_step, feed_dict={
            xs: batch_xs,
            ys: batch_ys
        })
        if i % 50 == 0:
            print(sess.run(accuracy, feed_dict={
                xs: mnist.test.images,
                ys: mnist.test.labels
            }))

3.6 Complete code

from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf

# prepare data: load the MNIST dataset
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)

xs = tf.placeholder(tf.float32, [None, 784])
ys = tf.placeholder(tf.float32, [None, 10])

# the model of the fully-connected network
weights = tf.Variable(tf.random_normal([784, 10]))
biases = tf.Variable(tf.zeros([1, 10]) + 0.1)
outputs = tf.matmul(xs, weights) + biases
predictions = tf.nn.softmax(outputs)
cross_entropy = tf.reduce_mean(-tf.reduce_sum(ys * tf.log(predictions),
                                              reduction_indices=[1]))
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)

# compute the accuracy
correct_predictions = tf.equal(tf.argmax(predictions, 1), tf.argmax(ys, 1))
accuracy = tf.reduce_mean(tf.cast(correct_predictions, tf.float32))

with tf.Session() as sess:
    init = tf.global_variables_initializer()
    sess.run(init)
    for i in range(1000):
        batch_xs, batch_ys = mnist.train.next_batch(100)
        sess.run(train_step, feed_dict={
            xs: batch_xs,
            ys: batch_ys
        })
        if i % 50 == 0:
            print(sess.run(accuracy, feed_dict={
                xs: mnist.test.images,
                ys: mnist.test.labels
            }))

Output:

"C:\Program Files\Python36\pythonw.exe" C:/Users/88304/PycharmProjects/bianyiyuanli/cifafenxi.py
C:\Program Files\Python36\lib\site-packages\tensorflow\python\framework\dtypes.py:519: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
C:\Program Files\Python36\lib\site-packages\tensorflow\python\framework\dtypes.py:520: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
C:\Program Files\Python36\lib\site-packages\tensorflow\python\framework\dtypes.py:521: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
C:\Program Files\Python36\lib\site-packages\tensorflow\python\framework\dtypes.py:522: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
C:\Program Files\Python36\lib\site-packages\tensorflow\python\framework\dtypes.py:523: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
C:\Program Files\Python36\lib\site-packages\tensorflow\python\framework\dtypes.py:528: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
  np_resource = np.dtype([("resource", np.ubyte, 1)])
WARNING:tensorflow:From C:/Users/88304/PycharmProjects/bianyiyuanli/cifafenxi.py:5: read_data_sets (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.
WARNING:tensorflow:From C:\Program Files\Python36\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\mnist.py:260: maybe_download (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.
Instructions for updating:
Please write your own downloading logic.
WARNING:tensorflow:From C:\Program Files\Python36\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\mnist.py:262: extract_images (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use tf.data to implement this functionality.
Extracting C:/MNIST_data/train-images-idx3-ubyte.gz
Extracting C:/MNIST_data/train-labels-idx1-ubyte.gz
WARNING:tensorflow:From C:\Program Files\Python36\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\mnist.py:267: extract_labels (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use tf.data to implement this functionality.
WARNING:tensorflow:From C:\Program Files\Python36\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\mnist.py:110: dense_to_one_hot (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use tf.one_hot on tensors.
Extracting C:/MNIST_data/t10k-images-idx3-ubyte.gz
Extracting C:/MNIST_data/t10k-labels-idx1-ubyte.gz
WARNING:tensorflow:From C:\Program Files\Python36\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\mnist.py:290: DataSet.__init__ (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.
2020-04-11 18:28:52.546898: I T:\src\github\tensorflow\tensorflow\core\platform\cpu_feature_guard.cc:140] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2
0.1597
0.6417
0.7422
0.7792
0.8057
0.8136
0.832
0.8415
0.8437
0.8458
0.8548
0.8583
0.8613
0.8622
0.8678
0.8638
0.8685
0.8659
0.8738
0.8721

Process finished with exit code 0

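The test accuracy rises over the course of training, from 0.1597 at step 0 to 0.8721 at step 950, with small fluctuations along the way.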

4 References

https://morvanzhou.github.io/tutorials/machine-learning/tensorflow/5-01-classifier/
