前言
由于前一段时间在搞定位算法,暂时没有时间来学习、记录pytorch的学习过程。在这里就先提及需要详细解释的一些部分
- 优化器
optimizer
的种类、各自的优势,以及我应该如何选取 - 正则化、过拟合、
batch normalization
的解释和应用 - 学习更多的深度学习算法来更好的应用在实际例子中
- 更多pytorch例程的研究记录
经过这几天的零碎学习,我真是觉得神经网络的搭建真的太复杂了,复杂到什么程度呢?就是一个概念不懂,找到解释后,发现解释里又包含着好几个你根本不知道的概念,不过还好我做了一些学习准备,也不算是完全没有头绪。但是目前暴露的缺陷有:
那么针对这些,我认为可以在学习一些实例、教程时碎片化的学习来完善自己的知识框架,所以有了以下的“深入了解”学习阶段
思考:
Dataset
与DataLoader
?Dataset
与DataLoader
就起到了必要的作用Dataset
与DataLoader
分别对应数据的读取和操作,他们是官网提供给我们处理数据的一个范例。Dataset
Dataset
位于torch.utils.data.Dataset
中,往往我们需要创建自己的获取数据集的Mydataset
类,他必须继承官方的Dataset
类,并且必须要实现其两个成员函数:__len__()
、__getitem__()
1 | import torch |
transform
操作(这个具体在后面说明),至于后面的test
参数也和transform有关重点看类里内置的方法
__getitem__(self,index)
:该方法支持从 0 到len(self)
的索引
在我们自己的Dataset中必须需要一个这个方法获取数据集中的每一组数据。比如在此例中,由于在建立dataset实例时出入的data
数据是所有的.jpg
文件,那么当我们需要获取图片进行训练或测试时,需要利用opencv的方法读取图片数据、将图片转换成RGB格式,把每个图片的尺寸统一为设定的模板尺寸,最后将图片信息数据转换为tensor
张量形式
__len__(self)
这个方法比较简单,用起来也很方便。目的就是,返回数据集的长度。但这个方法必须有。
DataLoader
torch.utils.data
已经提供的类:Dataset
,但是通过这种方式只能一个个的数据的把数据全部读出来,定义了数据读取的方式,不能实现批量的把数据读取出来,为此pytorch有提供了一个方法:DataLoader()
DataLoader
将Dataset
中的数据集以每次Batch_size
个大小的组,读取出来,以供训练Dataloader
本质是一个可迭代对象,使用iter()
访问,不能使用next()
访问,由于它本身就是一个可迭代对象,可以使用for inputs, labels in dataloaders
进行可迭代对象的访问DataLoader
的方法了,只需要在构造函数中指定相应的参数即可,比如常见的batch_size
,shuffle
等等参数。所以使用DataLoader十分简洁方便。1 | #自己定义的函数:获取数据并将数据分为训练集和测试集两类 |
dataset
创建了训练和测试两个实例,训练需要数据增强而测试不需要,所以两个实例除数据不同,tranform
也不同DataLoader
的参数X_dataset
:数据集的实例batch_size
:训练的提取数据的单位内所含个数shuffle
:布尔值True或者是False ,表示每一个epoch之后是否对样本进行随机打乱,默认是Falsecollate_fn
:这个是一个问题,因为通常由自己确定。是一个自己对于数据选取的方法,但是我不知道他有什么意义,或者优势pin_memory
:如果设置为True,那么data loader将会在返回它们之前,将tensors拷贝到CUDA中的固定内存(CUDA pinned memory)中num_workers
:这个参数决定了有几个进程来处理data loading。0意味着所有的数据都会被load进主进程。(默认为0)1 | import torch |
Dataset
与DataLoader
的使用,这里很简单,只是看一个思路。Batch_size
个数的数据,但是你的神经网络的输入层的结构不会因为batch_size
的变化受到影响batch_size
对CNN结果的影响,点击这里了解collate_fn
如何在加载数据时起作用?collate_fn
上collate_fn
默认是等于default_collate
1 | def __next__(self): #这是DataLoader的关于迭代的方法的源码 |
collate_fn
这个函数的输入就是一个list,list的长度是一个batch size
,list中的每个元素都是__getitem__
得到的结果,想要知道详解点击这里,那么batch
中的每一个元素其实就是Dataset
中__getitem__
方法得到的比如(图片信息img、标签label)那么如果理解了以上的意思,再看collate_fn
就能较快地上手了
那什么时候该使用DataLoader的collate_fn
这个参数?
当定义DataSet
类中的__getitem__
函数的时候,由于每次返回的是一组类似于(x,y)
的样本,但是如果在返回的每一组样本x,y中出现什么错误,或者是还需要进一步对x,y进行一些处理的时候,我们就需要再定义一个collate_fn
函数来实现这些功能。当然我也可以自己在实现__getitem__
的时候就实现这些后处理也是可以的。
collate_fn
,中单词collate
的含义是:核对,校勘,对照,整理。顾名思义,这就是一个对每一组样本数据进行一遍“核对和重新整理”,现在可能更好理解一些
torchvision
是什么?torchvision
是PyTorch中专门用来处理图像的库,PyTorch官网的安装教程,也会让你安装上这个包。torchvision.datasets
torchvision.models
torchvision.transforms
torchvision.utils
torchvision.datasets
torchvision.datasets
是用来进行数据加载的,PyTorch团队在这个包中帮我们提前处理好了很多很多图片数据集。
MNISTCOCO
Captions
Detection
LSUN
ImageFolder
Imagenet-12
CIFAR
STL10
SVHN
PhotoTour
我们可以直接使用,示例如下:
1 | import torchvision |
torchvision.models
torchvision.models
中为我们提供了已经训练好的模型,让我们可以加载之后,直接使用
torchvision.models
模块的子模块中包含以下模型结构。
AlexNet
VGG
ResNet
SqueezeNet
DenseNet
我们可以直接使用如下代码来快速创建一个权重随机初始化的模型
1 | import torchvision.models as models |
也可以通过使用pretrained=True
来加载一个别人预训练好的模型
1 | import torchvision.models as models |
torchvision.transforms
transforms
提供了一般的图像转换操作类,这就是我在一开始Dataset
中提到了transform
,还记得么?哈哈,可以翻一翻上面的例程中在创建Mydataset
类的getitem
方法中是如何使用transforms的
其实我猜测例程中,作者想在获取每一个图像信息时,除了进行opencv
的操作外,训练时还要对图像数据信息做更多处理,而测试时当然就不需要这种处理
1 | # 我们这里还是对MNIST进行处理,初始的MNIST是 28 * 28,我们把它处理成 96 * 96 的torch.Tensor的格式 |
torchvision.datasets
中获取的图像数据做transform,这样更简单,那如何在自己定义的Dataset类中加入Transform呢?DataSet
,并重写__getitem__
,在里面实现关键的transform
操作transforms
,通过Compose
类来实现,注意这个类的返回值哦!DataSet
的对象,将组合起来的transform
传递进去,这样就会对每一个batch_size
的图像都进行相关的数据增强操作了numpy
矩阵类型,最后又要转换成tensor
张量类型,这个要之后细究Resnet
?👴也不知道,👴正在学
torch.normal(means, std, out=None)
means
和std
分别给出均值和标准差torch.rand(*sizes, out=None) → Tensor
torch.randn(*sizes, out=None) → Tensor
torch.squeeze(a,N)
a.squeeze(N)
也是去掉a中指定的维数为一的维度,去掉的维度数为Ntorch.unsqueeze()
a.squeeze(N)
可以关注一下python自带的
numpy
数字计算库和torch
的数据处理,其实相似度很高
torch.cat((A,B),0)
np.hstack((A,B,C..))
& np.vstack((A,B,C..))
torch.stack([tensor1, tensor2, tensor3…], dim=0)
np.random.shuffle(x)
np.transpose(x)
torch.div(input, value, out=None)
plt.scatter(x,y,c='r',s=20,marker=".",lw=2)
matplotlib
库中画散点图的函数,注意c
这个参数的特别用法Update your browser to view this website correctly. Update my browser now