Pytorch 简明教程

PyTorch - Datasets

在本章中,我们将重点介绍 torchvision.datasets 及其各种类型。PyTorch 包含以下数据集加载器 −

  1. MNIST

  2. COCO (Captioning and Detection)

数据集包括下面给出的两种主要类型的函数 −

  1. Transform − 一个获取图像并返回标准内容的修改版本的功能。这些可以与变换一起构成。

  2. Target_transform − 一个获取目标并进行转换的功能。例如,获取标题字符串并返回世界索引的张量。

MNIST

以下是由 MNIST 数据集生成的示例代码 -

dset.MNIST(root, train = TRUE, transform = NONE,
target_transform = None, download = FALSE)

参数如下:

  1. root - 数据集的根目录,已处理的数据位于此处。

  2. train - True = 训练集,False = 测试集

  3. download - True = 从互联网下载数据集并将它放在根目录中。

COCO

这需要安装 COCO API。以下示例演示如何使用 PyTorch 实现 COCO 数据集 -

import torchvision.dataset as dset
import torchvision.transforms as transforms
cap = dset.CocoCaptions(root = ‘ dir where images are’,
annFile = ’json annotation file’,
transform = transforms.ToTensor())
print(‘Number of samples: ‘, len(cap))
print(target)

达到的输出如下 -

Number of samples: 82783
Image Size: (3L, 427L, 640L)