前言
由于前一段时间在搞定位算法,暂时没有时间来学习、记录pytorch的学习过程。在这里就先提及需要详细解释的一些部分
- 优化器
optimizer的种类、各自的优势,以及我应该如何选取 - 正则化、过拟合、
batch normalization的解释和应用 - 学习更多的深度学习算法来更好的应用在实际例子中
- 更多pytorch例程的研究记录
经过这几天的零碎学习,我真是觉得神经网络的搭建真的太复杂了,复杂到什么程度呢?就是一个概念不懂,找到解释后,发现解释里又包含着好几个你根本不知道的概念,不过还好我做了一些学习准备,也不算是完全没有头绪。但是目前暴露的缺陷有:
那么针对这些,我认为可以在学习一些实例、教程时碎片化的学习来完善自己的知识框架,所以有了以下的“深入了解”学习阶段
思考:
Dataset与DataLoader?Dataset与DataLoader就起到了必要的作用Dataset与DataLoader分别对应数据的读取和操作,他们是官网提供给我们处理数据的一个范例。DatasetDataset位于torch.utils.data.Dataset中,往往我们需要创建自己的获取数据集的Mydataset类,他必须继承官方的Dataset类,并且必须要实现其两个成员函数:__len__()、__getitem__()1 | import torch |
transform操作(这个具体在后面说明),至于后面的test参数也和transform有关重点看类里内置的方法
__getitem__(self,index):该方法支持从 0 到len(self)的索引
在我们自己的Dataset中必须需要一个这个方法获取数据集中的每一组数据。比如在此例中,由于在建立dataset实例时出入的data数据是所有的.jpg文件,那么当我们需要获取图片进行训练或测试时,需要利用opencv的方法读取图片数据、将图片转换成RGB格式,把每个图片的尺寸统一为设定的模板尺寸,最后将图片信息数据转换为tensor张量形式
__len__(self)
这个方法比较简单,用起来也很方便。目的就是,返回数据集的长度。但这个方法必须有。
DataLoadertorch.utils.data已经提供的类:Dataset,但是通过这种方式只能一个个的数据的把数据全部读出来,定义了数据读取的方式,不能实现批量的把数据读取出来,为此pytorch有提供了一个方法:DataLoader()DataLoader将Dataset中的数据集以每次Batch_size个大小的组,读取出来,以供训练Dataloader本质是一个可迭代对象,使用iter()访问,不能使用next()访问,由于它本身就是一个可迭代对象,可以使用for inputs, labels in dataloaders进行可迭代对象的访问DataLoader的方法了,只需要在构造函数中指定相应的参数即可,比如常见的batch_size,shuffle等等参数。所以使用DataLoader十分简洁方便。1 | #自己定义的函数:获取数据并将数据分为训练集和测试集两类 |
dataset创建了训练和测试两个实例,训练需要数据增强而测试不需要,所以两个实例除数据不同,tranform也不同DataLoader的参数X_dataset:数据集的实例batch_size:训练的提取数据的单位内所含个数shuffle :布尔值True或者是False ,表示每一个epoch之后是否对样本进行随机打乱,默认是Falsecollate_fn:这个是一个问题,因为通常由自己确定。是一个自己对于数据选取的方法,但是我不知道他有什么意义,或者优势pin_memory:如果设置为True,那么data loader将会在返回它们之前,将tensors拷贝到CUDA中的固定内存(CUDA pinned memory)中num_workers:这个参数决定了有几个进程来处理data loading。0意味着所有的数据都会被load进主进程。(默认为0)1 | import torch |
Dataset与DataLoader的使用,这里很简单,只是看一个思路。Batch_size个数的数据,但是你的神经网络的输入层的结构不会因为batch_size的变化受到影响batch_size对CNN结果的影响,点击这里了解collate_fn如何在加载数据时起作用?collate_fn上collate_fn默认是等于default_collate1 | def __next__(self): #这是DataLoader的关于迭代的方法的源码 |
collate_fn这个函数的输入就是一个list,list的长度是一个batch size,list中的每个元素都是__getitem__得到的结果,想要知道详解点击这里,那么batch中的每一个元素其实就是Dataset中__getitem__方法得到的比如(图片信息img、标签label)那么如果理解了以上的意思,再看collate_fn就能较快地上手了
那什么时候该使用DataLoader的collate_fn这个参数?
当定义DataSet类中的__getitem__函数的时候,由于每次返回的是一组类似于(x,y)的样本,但是如果在返回的每一组样本x,y中出现什么错误,或者是还需要进一步对x,y进行一些处理的时候,我们就需要再定义一个collate_fn函数来实现这些功能。当然我也可以自己在实现__getitem__的时候就实现这些后处理也是可以的。
collate_fn,中单词collate的含义是:核对,校勘,对照,整理。顾名思义,这就是一个对每一组样本数据进行一遍“核对和重新整理”,现在可能更好理解一些
torchvision是什么?torchvision是PyTorch中专门用来处理图像的库,PyTorch官网的安装教程,也会让你安装上这个包。torchvision.datasetstorchvision.modelstorchvision.transformstorchvision.utilstorchvision.datasetstorchvision.datasets 是用来进行数据加载的,PyTorch团队在这个包中帮我们提前处理好了很多很多图片数据集。
MNISTCOCOCaptionsDetectionLSUNImageFolderImagenet-12CIFARSTL10SVHNPhotoTour我们可以直接使用,示例如下:
1 | import torchvision |
torchvision.modelstorchvision.models中为我们提供了已经训练好的模型,让我们可以加载之后,直接使用
torchvision.models模块的子模块中包含以下模型结构。
AlexNetVGGResNetSqueezeNetDenseNet我们可以直接使用如下代码来快速创建一个权重随机初始化的模型
1 | import torchvision.models as models |
也可以通过使用pretrained=True来加载一个别人预训练好的模型
1 | import torchvision.models as models |
torchvision.transformstransforms提供了一般的图像转换操作类,这就是我在一开始Dataset中提到了transform,还记得么?哈哈,可以翻一翻上面的例程中在创建Mydataset类的getitem方法中是如何使用transforms的
其实我猜测例程中,作者想在获取每一个图像信息时,除了进行opencv的操作外,训练时还要对图像数据信息做更多处理,而测试时当然就不需要这种处理
1 | # 我们这里还是对MNIST进行处理,初始的MNIST是 28 * 28,我们把它处理成 96 * 96 的torch.Tensor的格式 |
torchvision.datasets中获取的图像数据做transform,这样更简单,那如何在自己定义的Dataset类中加入Transform呢?DataSet,并重写__getitem__,在里面实现关键的transform操作transforms,通过Compose类来实现,注意这个类的返回值哦!DataSet的对象,将组合起来的transform传递进去,这样就会对每一个batch_size的图像都进行相关的数据增强操作了numpy矩阵类型,最后又要转换成tensor张量类型,这个要之后细究Resnet?👴也不知道,👴正在学
torch.normal(means, std, out=None)means和std分别给出均值和标准差torch.rand(*sizes, out=None) → Tensortorch.randn(*sizes, out=None) → Tensortorch.squeeze(a,N)a.squeeze(N) 也是去掉a中指定的维数为一的维度,去掉的维度数为Ntorch.unsqueeze()a.squeeze(N)可以关注一下python自带的
numpy数字计算库和torch的数据处理,其实相似度很高
torch.cat((A,B),0)np.hstack((A,B,C..)) & np.vstack((A,B,C..))torch.stack([tensor1, tensor2, tensor3…], dim=0)np.random.shuffle(x)np.transpose(x)torch.div(input, value, out=None)plt.scatter(x,y,c='r',s=20,marker=".",lw=2)matplotlib库中画散点图的函数,注意c这个参数的特别用法Update your browser to view this website correctly. Update my browser now