PyTorch Learning

How to read your own dataset

Posted by GwanSiu on July 13, 2017

1. Organizing the data

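First, store your dataset on disk with a separate root folder for the training set and the test set (train_path and test_path below). Each root contains one subfolder per class, holding that class's images, e.g. train_path/<class_name>/<image>.jpg. This is the layout torchvision.datasets.ImageFolder expects, and the subfolder names become the class labels.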
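The sketch below builds a toy dataset in the one-subfolder-per-class layout that torchvision.datasets.ImageFolder expects, then reproduces in plain Python how ImageFolder derives labels from it: the class subfolders are sorted alphabetically and numbered from 0. The class names and file names are hypothetical, the files are empty placeholders, and index_split is a simplified stand-in (it skips ImageFolder's image-extension filtering), not torchvision code.

```python
import os
import tempfile

# Hypothetical split/class/file names, for illustration only
LAYOUT = {
    "train": {"cat": ["cat001.jpg", "cat002.jpg"], "dog": ["dog001.jpg"]},
    "test": {"cat": ["cat101.jpg"], "dog": ["dog101.jpg"]},
}

def build_toy_tree(root):
    """Create root/<split>/<class>/<file> with empty placeholder files."""
    for split, classes in LAYOUT.items():
        for cls, files in classes.items():
            class_dir = os.path.join(root, split, cls)
            os.makedirs(class_dir, exist_ok=True)
            for name in files:
                open(os.path.join(class_dir, name), "wb").close()

def index_split(split_root):
    """Simplified ImageFolder-style indexing: sorted class folders -> 0, 1, 2, ..."""
    with os.scandir(split_root) as entries:
        classes = sorted(e.name for e in entries if e.is_dir())
    class_to_idx = {cls: i for i, cls in enumerate(classes)}
    samples = []
    for cls in classes:
        class_dir = os.path.join(split_root, cls)
        for name in sorted(os.listdir(class_dir)):
            # Each sample is (file path, integer label of its folder)
            samples.append((os.path.join(class_dir, name), class_to_idx[cls]))
    return class_to_idx, samples

root = tempfile.mkdtemp()
build_toy_tree(root)
class_to_idx, samples = index_split(os.path.join(root, "train"))
print(class_to_idx)  # {'cat': 0, 'dog': 1}
for path, label in samples:
    print(os.path.relpath(path, root), label)
```

Because labels come from sorted folder names, the train and test roots should contain the same set of class folders so that a given class gets the same index in both.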

2. Data preprocessing

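Data preprocessing involves three main pieces: torchvision.transforms.Compose, torchvision.datasets and torch.utils.data.DataLoader, which respectively define the preprocessing pipeline, load the dataset, and wrap it in a data generator that yields batches.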

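torchvision.transforms.Compose: chains several image preprocessing operations together and applies them to each image in the order they are listed. Note: the image operations here work on PIL images (the format ImageFolder loads by default), so once they have been applied the image must be converted to a Tensor with transforms.ToTensor. Normalize operates on that Tensor, which is why it comes after ToTensor in the pipeline.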
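To make "applies them in order" concrete, here is a plain-Python sketch of the idea behind Compose. It is not torchvision's source code, and the toy transforms add_one and double are hypothetical stand-ins for image operations.

```python
class Compose:
    """Minimal sketch: feed the input through each callable, in list order."""

    def __init__(self, transforms):
        self.transforms = list(transforms)

    def __call__(self, x):
        for t in self.transforms:
            x = t(x)  # output of one step is the input of the next
        return x

# Hypothetical stand-ins for image operations
def add_one(x):
    return x + 1

def double(x):
    return x * 2

print(Compose([add_one, double])(3))  # (3 + 1) * 2 = 8
print(Compose([double, add_one])(3))  # 3 * 2 + 1 = 7
```

Swapping the two steps changes the result, which is why the order of the list passed to Compose matters.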

import torch
import torchvision
import torchvision.transforms as transforms

# Resize images to 224x224, convert them to Tensors, then normalize each channel
# (transforms.Resize was called transforms.Scale in older torchvision releases)
transform = {
    'train': transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))]),
    'test': transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))]),
}
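The choice mean = std = 0.5 is easiest to see with numbers. ToTensor turns 8-bit pixel values in [0, 255] into floats in [0, 1], and Normalize computes (x - mean) / std for each channel, so this setting maps every channel to [-1, 1]. The helpers below are hypothetical plain-Python functions, not torchvision calls; they apply the same arithmetic to single pixel values. The inverse, y / 2 + 0.5, is what imshow uses later to undo the normalization.

```python
def to_tensor_value(pixel):
    """ToTensor-style scaling of an 8-bit pixel value from [0, 255] to [0, 1]."""
    return pixel / 255.0

def normalize(x, mean=0.5, std=0.5):
    """Normalize-style arithmetic: (x - mean) / std."""
    return (x - mean) / std

def unnormalize(y, mean=0.5, std=0.5):
    """Inverse of normalize; with mean = std = 0.5 this equals y / 2 + 0.5."""
    return y * std + mean

for pixel in (0, 128, 255):
    y = normalize(to_tensor_value(pixel))
    # Print the original pixel, its normalized value, and the recovered pixel
    print(pixel, round(y, 4), round(unnormalize(y) * 255))
```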

Next, load the datasets and wrap them in data loaders:

# train_path and test_path are the root folders of the training and test sets (see section 1)
train_set = torchvision.datasets.ImageFolder(train_path, transform=transform['train'])
train_loader = torch.utils.data.DataLoader(train_set, batch_size=32, shuffle=True, num_workers=2)

test_set = torchvision.datasets.ImageFolder(test_path, transform=transform['test'])
# Shuffling is not needed for evaluation
test_loader = torch.utils.data.DataLoader(test_set, batch_size=32, shuffle=False, num_workers=2)
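The effect of DataLoader's batch_size and shuffle arguments can be pictured with the plain-Python sketch below. It is a simplified model, not PyTorch's implementation: it ignores num_workers (parallel loading in worker processes) and collation (stacking samples into one batched tensor), and iterate_batches and toy_dataset are hypothetical names.

```python
import random

def iterate_batches(dataset, batch_size, shuffle, seed=None):
    """Sketch of one pass over a dataset, single process, no collation."""
    indices = list(range(len(dataset)))
    if shuffle:
        # A fresh random permutation of the sample indices for this pass
        random.Random(seed).shuffle(indices)
    for start in range(0, len(indices), batch_size):
        # The last batch may be smaller, as with DataLoader's default drop_last=False
        yield [dataset[i] for i in indices[start:start + batch_size]]

# Hypothetical dataset of (sample, label) pairs
toy_dataset = [(f"img{i}", i % 2) for i in range(10)]
batches = list(iterate_batches(toy_dataset, batch_size=4, shuffle=True, seed=0))
print([len(b) for b in batches])  # [4, 4, 2]
```

Every sample appears exactly once per pass; shuffling only changes which samples end up in the same batch.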

Read the data by iterating over the loader:

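import numpy as np
import matplotlib.pyplot as plt

def imshow(img):
    # Undo Normalize(mean=0.5, std=0.5): x_norm = (x - 0.5) / 0.5, so x = x_norm / 2 + 0.5
    img = img / 2 + 0.5
    nimg = img.numpy()
    # Tensors are laid out as (C, H, W); matplotlib expects (H, W, C)
    plt.imshow(np.transpose(nimg, (1, 2, 0)))
    plt.show()

# Fetch one batch from the training loader
data_iter = iter(train_loader)
images, labels = next(data_iter)  # next(...) is the standard Python 3 way to advance an iterator
# make_grid tiles the batch of images into a single (C, H, W) image
imshow(torchvision.utils.make_grid(images))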
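The np.transpose(nimg, (1, 2, 0)) call in imshow reorders the axes from PyTorch's (channels, height, width) layout to the (height, width, channels) layout matplotlib expects. The plain-Python sketch below (chw_to_hwc is a hypothetical helper) performs the same reordering on nested lists, so that hwc[h][w][c] == chw[c][h][w].

```python
def chw_to_hwc(chw):
    """Reorder a nested list from (C, H, W) to (H, W, C): hwc[h][w][c] = chw[c][h][w]."""
    channels, height, width = len(chw), len(chw[0]), len(chw[0][0])
    return [[[chw[c][h][w] for c in range(channels)]
             for w in range(width)]
            for h in range(height)]

# A 3-channel 2x2 image; each value encodes its own position as 100*c + 10*h + w
chw = [[[100 * c + 10 * h + w for w in range(2)] for h in range(2)] for c in range(3)]
hwc = chw_to_hwc(chw)
print(hwc[1][0])  # row 1, column 0, all 3 channels: [10, 110, 210]
```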