新闻动态

Pytorch DataLoader shuffle验证方式

发布日期:2022-03-26 15:20 | 文章来源:源码之家

shuffle = False时,不打乱数据顺序

shuffle = True,随机打乱

import numpy as np
import h5py
import torch
from torch.utils.data import DataLoader, Dataset  
h5f = h5py.File('train.h5', 'w');
data1 = np.array([[1,2,3],
[2,5,6],
  [3,5,6],
  [4,5,6]])
data2 = np.array([[1,1,1],
 [1,2,6],
[1,3,6],
[1,4,6]])
h5f.create_dataset(str('data'), data=data1)
h5f.create_dataset(str('label'), data=data2)
class Dataset(Dataset):
 def __init__(self):
  h5f = h5py.File('train.h5', 'r')
  self.data = h5f['data']
  self.label = h5f['label']
 def __getitem__(self, index):
  data = torch.from_numpy(self.data[index])
  label = torch.from_numpy(self.label[index])
  return data, label
 
 def __len__(self):
  assert self.data.shape[0] == self.label.shape[0], "wrong data length"
  return self.data.shape[0] 
 
dataset_train = Dataset()
loader_train = DataLoader(dataset=dataset_train,batch_size=2,shuffle = True)
 
for i, data in enumerate(loader_train):
 train_data, label = data
 print(train_data)
 

pytorch DataLoader使用细节

背景:

我一开始是对数据扩增这一块有疑问, 只看到了数据变换(torchvisiom.transforms),但是没看到数据扩增, 后来搞明白了, 数据扩增在pytorch指的是torchvisiom.transforms + torch.utils.data.DataLoader+多个epoch共同作用下完成的,

数据变换共有以下内容

composed = transforms.Compose([transforms.Resize((448, 448)), #  resize
transforms.RandomCrop(300), # random crop
transforms.ToTensor(),
transforms.Normalize(mean=[0.5, 0.5, 0.5],  # normalizestd=[0.5, 0.5, 0.5])])

简单的数据读取类, 进返回PIL格式的image:

class MyDataset(data.Dataset): 
 def __init__(self, labels_file, root_dir, transform=None):
  with open(labels_file) as csvfile:
self.labels_file = list(csv.reader(csvfile))
  self.root_dir = root_dir
  self.transform = transform
  
 def __len__(self):
  return len(self.labels_file)
 
 def __getitem__(self, idx):
  im_name = os.path.join(root_dir, self.labels_file[idx][0])
  im = Image.open(im_name)
  
  if self.transform:
im = self.transform(im)

  return im

下面是主程序

labels_file = "F:/test_temp/labels.csv"
root_dir = "F:/test_temp"
dataset_transform = MyDataset(labels_file, root_dir, transform=composed)
dataloader = data.DataLoader(dataset_transform, batch_size=1, shuffle=False)
"""原始数据集共3张图片, 以batch_size=1, epoch为2 展示所有图片(共6张)  """
for eopch in range(2):
 plt.figure(figsize=(6, 6)) 
 for ind, i in enumerate(dataloader):
  a = i[0, :, :, :].numpy().transpose((1, 2, 0))
  plt.subplot(1, 3, ind+1)
  plt.imshow(a)

从上述图片总可以看到, 在每个eopch阶段实际上是对原始图片重新使用了transform, , 这就造就了数据的扩增

以上为个人经验,希望能给大家一个参考,也希望大家多多支持本站。

香港服务器租用

版权声明:本站文章来源标注为YINGSOO的内容版权均为本站所有,欢迎引用、转载,请保持原文完整并注明来源及原文链接。禁止复制或仿造本网站,禁止在非www.yingsoo.com所属的服务器上建立镜像,否则将依法追究法律责任。本站部分内容来源于网友推荐、互联网收集整理而来,仅供学习参考,不代表本站立场,如有内容涉嫌侵权,请联系alex-e#qq.com处理。

相关文章

实时开通

自选配置、实时开通

免备案

全球线路精选!

全天候客户服务

7x24全年不间断在线

专属顾问服务

1对1客户咨询顾问

在线
客服

在线客服:7*24小时在线

客服
热线

400-630-3752
7*24小时客服服务热线

关注
微信

关注官方微信
顶部