猫狗识别(使用VGG16神经网络)-微调 作者:马育民 • 2020-02-06 15:15 • 阅读:10511 上接:[猫狗识别(使用VGG16神经网络)](https://www.malaoshi.top/show_1EF4vi5xOeDl.html "猫狗识别(使用VGG16神经网络)") # 概述 从 [猫狗识别(使用VGG16神经网络)](https://www.malaoshi.top/show_1EF4vi5xOeDl.html "猫狗识别(使用VGG16神经网络)") 可知,训练第8次时,发生过拟合。在数据集只有2000张图片的情况下,准确率已经达到92% 本节 通过微调,让准确率再次提升 ### 注意 本文介绍的方法,是在 [猫狗识别(使用VGG16神经网络)](https://www.malaoshi.top/show_1EF4vi5xOeDl.html "猫狗识别(使用VGG16神经网络)") 基础之上 并且一定是发生 **过拟合** 后,再继续训练 **否则无效** # 微调方法 ### 前提 通过 [猫狗识别(使用VGG16神经网络)](https://www.malaoshi.top/show_1EF4vi5xOeDl.html "猫狗识别(使用VGG16神经网络)") 文章中的方法,进行训练,直到 发生 **过拟合** 此时已经 将 **全连接层** 和 **输出层** 训练好 >回忆上文中的模型,用到VGG16的 **卷积层** 和 **池化层**,且冻结它们的权重,所以训练时,只会改变 **全连接层** 和 **输出层** 的权重 ### 设置后面的层可训练权重 冻结VGG16前面的m层,但可以训练后面的n层 **例子:** ``` covn_base.trainable=True for item in covn_base.layers[:-3]: item.trainable=False ``` ### 训练 在之前过拟合的基础之上,再次训练 此时:vgg16的 **后面n层**,**全连接层** 和 **输出层**,都可以训练 学习速率要降低,如:```0.00005``` **例子:** ``` model.compile(optimizer=tf.keras.optimizers.Adam(lr=0.00005),loss="binary_crossentropy",metrics=["acc"]) ``` ### 原理 在卷积神经网络中,越是前面的层,图片越大,卷积核较小(一般3x3),提取的是纹理,通常提取的是 **通用** 特征 随着经过池化层,其特征图会减小,也就是说:越到后面的层,特征图越小,卷积核大小未改变(仍是3x3),此时 卷积核的 [视野](https://www.malaoshi.top/show_1EF5YLFg71ax.html "视野") 相对越来越大,提取的是更加完整的特征。如果是猫的图片,最后提取的就是猫的抽象特征 [![](https://www.malaoshi.top/upload/0/0/1EF4cZeMoySG.jpg)](https://www.malaoshi.top/upload/0/0/1EF4cZeMoySG.jpg) # 代码 接 [猫狗识别(使用VGG16神经网络)](https://www.malaoshi.top/show_1EF4vi5xOeDl.html "猫狗识别(使用VGG16神经网络)") 文章后面写,而且 **必须在训练过拟合后**,**继续训练**,否则无效 ### 只允许后面3层可训练 只允许后面3层的权重可训练 ``` covn_base.trainable=True for item in covn_base.layers[:-3]: item.trainable=False ``` ### 编译 **注意:** 学习速度要很小 ``` model.compile(optimizer=tf.keras.optimizers.Adam(lr=0.00005),loss="binary_crossentropy",metrics=["acc"]) ``` ### 训练 ``` history=model.fit(train_ds2,epochs=25,validation_data=test_ds2,initial_epoch=15) ``` **注意:** 由于是 **过拟合** 后,再次训练,所以要设置 ```initial_epoch```参数,表示从第16次开始训练,共训练 `25-15=10`次 [![](https://www.malaoshi.top/upload/0/0/1EF4vyRrM1kT.png)](https://www.malaoshi.top/upload/0/0/1EF4vyRrM1kT.png) ### 查看图片 ``` plt.plot(history.epoch,history.history["acc"],label="acc") plt.plot(history.epoch,history.history["val_acc"],label="val_acc") plt.legend() ``` [![](https://www.malaoshi.top/upload/0/0/1EF4vyTH9AQZ.png)](https://www.malaoshi.top/upload/0/0/1EF4vyTH9AQZ.png) 原文出处:http://malaoshi.top/show_1EF4vyjtBcY5.html