如何用PyTorch实现从零开始的LeNet5图像分类?
摘要:本文介绍了如何从0开始构建 LeNet5 去识别手写数字(在MNIST数据集上)。代码包括三大部分:网络结构部分、训练部分、测试部分。在编LeNet5部分代码之前,本文详细地梳理了LeNet5的结构,对于初学者十分友好。训练和测试部分也都有
摘要:
本文介绍了如何从0开始构建 LeNet5 去识别手写数字(在MNIST数据集上)。代码包括三大部分:网络结构部分、训练部分、测试部分。在编LeNet5部分代码之前,本文详细地梳理了LeNet5的结构,对于初学者十分友好。训练和测试部分也都有详细的代码说明。
在实现 LeNet5 手写数字识别的同时,补充了很多CNN的基础概念和Python编程知识。包括:PyTorch中的常用库和其中的模块,特征图在卷积过程中尺寸如何变化,如何把数据加载进训练程序等。
本文不是通过复制粘贴代码介绍如何实现 LeNet5 的手写数字识别,而是通过内在逻辑,深层次地阐述这一过程,力求“知其然,知其所以然。”
完整代码已经上传至GitHub:https://github.com/TiezhuXing01/LeNet5_in_PyTorch.git
温馨提示:
(1)本文主要介绍如何从0实现LeNet5,注重编程思路的讲解,对于一些前置知识不做赘述。
(2)请确保你已经配置好并进入了深度学习环境。
(3)目录导航见页面右上角黄色方块。
前置知识:
(1)概念:PyTorch、卷积、池化、全连接、ReLU、前向传播、反向传播
(2)对python语法有基本的了解
在从0开始编程前,我们首先思考一下,一个由LeNet5完成的图像分类任务(如:手写数字识别),都需要哪些组成部分?
首先,肯定要有LeNet5网络结构的代码。
其次,还要有在训练集上训练的代码,让网络学习特征表示。
最后,训练完要在测试集上测试,不然咋知道训练得效果怎样呢?
1. 引入库(Import the Libraries)
“库”是一组已经写好的代码,可以理解为一个“工具箱”。对于某些功能的实现,开发者可以从引入的“工具箱”中拿出工具直接使用,而不是从头开始“造工具”(即编代码)。所以在所有工作之前,要把“工具箱”引入进来,以方便后续编程。但有时候,我们引入库时可能会漏掉一些库,在编程到后面才意识到。不用担心,我们再回到开头这里引入库就好了。
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
torch库提供了各种用于张量(tensor)操作、神经网络搭建、优化算法等方面的函数。
什么是张量?如果你对张量(tensor)不了解,不用担心,你只需要记住这是图片经过某种处理之后的一种形式,就像你已知的图片有.jpg和.png格式一样。但对于.jpg和.png格式的图像,神经网络并不喜欢,无法直接处理它们。而张量(tensor)这种形式,适用于神经网络的处理。
torch.nn库可以理解为大工具箱torch里面的小工具箱nn。这个模块里包含构建神经网络层和模型的类和函数。import torch.nn as nn表示,引入之后,模块torch.nn的名字就可以简称nn了。
torchvision库提供了一系列用于图像处理、计算机视觉数据集加载、图像变换、以及许多流行的计算机视觉模型的实现。
torchvision.transforms模块用于进行图像的变换和预处理。这个模块包含了一系列用于处理图像的转换函数,可用于数据增强、数据清理和准备图像数据以输入神经网络等任务。
2. 选择在哪个设备上训练模型(GPU或CPU)
通常情况下,我们都会选择在GPU上训练网络模型,因为神经网络的训练需要大量的计算,而英伟达的GPU提供了CUDA(一个加速计算库)。但如果你的电脑显卡是AMD的,那么有很大概率不支持使用CUDA,此时只能用CPU训练。但在CPU上训练模型是十分缓慢的。如果你暂时没法换电脑,那我建议你去租一个服务器。或者使用阿里云、百度飞桨、谷歌Colab等平台。
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
torch.cuda.is_available()函数的功能是检查系统中是否安装了可用的 CUDA并且 GPU 是可用的。如果 GPU 可用,返回 True,否则返回 False。
'cuda' if torch.cuda.is_available() else 'cpu'表示如果GPU可用,则返回字符串'cuda'。如果不可用,则返回字符串'cpu'。
如果函数torch.device(...)接收到的是'cuda',则选择在GPU上计算,如果接收到的是'cpu',则选择在CPU上进行计算。
