shuffle = False时,不打乱数据顺序
shuffle = True,随机打乱
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
|
import numpy as np
import h5py
import torch
from torch.utils.data import DataLoader, Dataset

# Build a tiny HDF5 file with four 3-element samples and matching labels.
# The ``with`` block closes the file, so everything is flushed to disk
# before the same path is reopened for reading below.
data1 = np.array([[1, 2, 3], [2, 5, 6], [3, 5, 6], [4, 5, 6]])
data2 = np.array([[1, 1, 1], [1, 2, 6], [1, 3, 6], [1, 4, 6]])
with h5py.File('train.h5', 'w') as h5f:
    h5f.create_dataset('data', data=data1)
    h5f.create_dataset('label', data=data2)


# NOTE: the class keeps its original name, which rebinds the imported
# ``Dataset`` base class once this definition has executed.
class Dataset(Dataset):
    """Map-style dataset backed by the 'data' and 'label' arrays of train.h5."""

    def __init__(self):
        # The file is left open on purpose: the h5py datasets are read
        # row by row in __getitem__.
        h5f = h5py.File('train.h5', 'r')
        self.data = h5f['data']
        self.label = h5f['label']

    def __getitem__(self, index):
        """Return the (data, label) rows at ``index`` as torch tensors."""
        data = torch.from_numpy(self.data[index])
        label = torch.from_numpy(self.label[index])
        return data, label

    def __len__(self):
        """Return the number of samples (rows of 'data')."""
        assert self.data.shape[0] == self.label.shape[0], "wrong data length"
        return self.data.shape[0]


dataset_train = Dataset()
loader_train = DataLoader(dataset=dataset_train, batch_size=2, shuffle=True)

# With shuffle=True the order of the printed rows changes between runs.
for train_data, _label in loader_train:
    print(train_data)
pytorch DataLoader使用细节
背景:
我一开始对数据扩增这一块有疑问:只看到了数据变换(torchvision.transforms),却没看到数据扩增。后来搞明白了,在pytorch中,数据扩增是由 torchvision.transforms + torch.utils.data.DataLoader + 多个epoch 共同作用完成的。
数据变换共有以下内容
1
2
3
4
5
|
# Preprocessing pipeline applied to every image each time it is loaded.
composed = transforms.Compose([
    transforms.Resize((448, 448)),               # resize
    transforms.RandomCrop(300),                  # random crop
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5],   # normalize
                         std=[0.5, 0.5, 0.5]),
])
简单的数据读取类, 仅返回PIL格式的image:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
|
class MyDataset(data.Dataset):
    """Image dataset driven by a CSV file whose first column is a file name.

    Args:
        labels_file: Path to the CSV file; each row's first field is an
            image file name relative to ``root_dir``.
        root_dir: Directory that contains the image files.
        transform: Optional callable applied to the opened image before
            it is returned.
    """

    def __init__(self, labels_file, root_dir, transform=None):
        # newline="" is what the csv module expects for files it parses.
        with open(labels_file, newline="") as csvfile:
            self.labels_file = list(csv.reader(csvfile))
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        """Return the number of rows read from the CSV file."""
        return len(self.labels_file)

    def __getitem__(self, idx):
        """Open image ``idx`` and return it, transformed if a transform is set."""
        # Use the instance's root_dir; the original read a module-level
        # ``root_dir`` and broke when no such global existed.
        im_name = os.path.join(self.root_dir, self.labels_file[idx][0])
        im = Image.open(im_name)
        if self.transform:
            im = self.transform(im)
        return im
下面是主程序
1
2
3
4
5
6
7
8
9
10
11
|
labels_file = "F:/test_temp/labels.csv"
root_dir = "F:/test_temp"
dataset_transform = MyDataset(labels_file, root_dir, transform=composed)
dataloader = data.DataLoader(dataset_transform, batch_size=1, shuffle=False)

# The original dataset holds 3 images; with batch_size=1 and 2 epochs
# every image is displayed twice (6 panels in total).
n_images = len(dataset_transform)
for epoch in range(2):
    # One figure per epoch, one panel per image.
    plt.figure(figsize=(6, 6))
    for ind, batch in enumerate(dataloader):
        # Take the single sample of the batch, map it back from [-1, 1]
        # to [0, 1] so imshow does not clip it, and move channels last.
        # Assumes `composed` normalizes with mean=0.5 and std=0.5.
        img = (batch[0] * 0.5 + 0.5).numpy().transpose((1, 2, 0))
        plt.subplot(1, n_images, ind + 1)
        plt.imshow(img)
plt.show()
从上述图片中可以看到, 在每个epoch阶段实际上是对原始图片重新使用了transform, 这就造就了数据的扩增
以上为个人经验,希望能给大家一个参考,也希望大家多多支持服务器之家。
原文链接:https://blog.csdn.net/qq_35752161/article/details/110875040