Pytorch 简明教程
PyTorch - Loading Data
PyTorch 包含一个名为 torchvision 的包,用于加载和准备数据集。它包括两个基本函数:Dataset 和 DataLoader,它们有助于转换和加载数据集。
Dataset
Dataset 用于从给定数据集中读取和转换数据点。要实现的基本语法如下所示 −
trainset = torchvision.datasets.CIFAR10(root = './data', train = True,
download = True, transform = transform)
DataLoader 用于随机排列和批量处理数据。它可用于使用多处理工作程序并行加载数据。
trainloader = torch.utils.data.DataLoader(trainset, batch_size = 4,
shuffle = True, num_workers = 2)
Example: Loading CSV File
我们使用 Python 包 Panda 加载 csv 文件。原始文件具有以下格式:(图像名称,68 个地标 - 每个地标具有 x、y 坐标)。
landmarks_frame = pd.read_csv('faces/face_landmarks.csv')
n = 65
img_name = landmarks_frame.iloc[n, 0]
landmarks = landmarks_frame.iloc[n, 1:].as_matrix()
landmarks = landmarks.astype('float').reshape(-1, 2)