本文介绍了tensorflow中next_batch的具体使用,分享给大家,具体如下:
此处给出了几种不同的next_batch方法,该文章只是做出代码片段的解释,以备以后查看:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
|
def next_batch( self , batch_size, fake_data = False ): """Return the next `batch_size` examples from this data set.""" if fake_data: fake_image = [ 1 ] * 784 if self .one_hot: fake_label = [ 1 ] + [ 0 ] * 9 else : fake_label = 0 return [fake_image for _ in xrange (batch_size)], [ fake_label for _ in xrange (batch_size) ] start = self ._index_in_epoch self ._index_in_epoch + = batch_size if self ._index_in_epoch > self ._num_examples: # epoch中的句子下标是否大于所有语料的个数,如果为True,开始新一轮的遍历 # Finished epoch self ._epochs_completed + = 1 # Shuffle the data perm = numpy.arange( self ._num_examples) # arange函数用于创建等差数组 numpy.random.shuffle(perm) # 打乱 self ._images = self ._images[perm] self ._labels = self ._labels[perm] # Start next epoch start = 0 self ._index_in_epoch = batch_size assert batch_size < = self ._num_examples end = self ._index_in_epoch return self ._images[start:end], self ._labels[start:end] |
该段代码摘自mnist.py文件,从代码第12行start = self._index_in_epoch开始解释,_index_in_epoch-1是上一次batch个图片中最后一张图片的下边,这次epoch第一张图片的下标是从 _index_in_epoch开始,最后一张图片的下标是_index_in_epoch+batch, 如果 _index_in_epoch 大于语料中图片的个数,表示这个epoch是不合适的,就算是完成了语料的一遍的遍历,所以应该对图片洗牌然后开始新一轮的语料组成batch开始
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
|
def ptb_iterator(raw_data, batch_size, num_steps): """Iterate on the raw PTB data. This generates batch_size pointers into the raw PTB data, and allows minibatch iteration along these pointers. Args: raw_data: one of the raw data outputs from ptb_raw_data. batch_size: int, the batch size. num_steps: int, the number of unrolls. Yields: Pairs of the batched data, each a matrix of shape [batch_size, num_steps]. The second element of the tuple is the same data time-shifted to the right by one. Raises: ValueError: if batch_size or num_steps are too high. """ raw_data = np.array(raw_data, dtype = np.int32) data_len = len (raw_data) batch_len = data_len / / batch_size #有多少个batch data = np.zeros([batch_size, batch_len], dtype = np.int32) # batch_len 有多少个单词 for i in range (batch_size): # batch_size 有多少个batch data[i] = raw_data[batch_len * i:batch_len * (i + 1 )] epoch_size = (batch_len - 1 ) / / num_steps # batch_len 是指一个batch中有多少个句子 #epoch_size = ((len(data) // model.batch_size) - 1) // model.num_steps # // 表示整数除法 if epoch_size = = 0 : raise ValueError( "epoch_size == 0, decrease batch_size or num_steps" ) for i in range (epoch_size): x = data[:, i * num_steps:(i + 1 ) * num_steps] y = data[:, i * num_steps + 1 :(i + 1 ) * num_steps + 1 ] yield (x, y) |
第三种方式:
1
2
3
4
5
6
7
8
9
10
11
12
13
|
def next ( self , batch_size): """ Return a batch of data. When dataset end is reached, start over. """ if self .batch_id = = len ( self .data): self .batch_id = 0 batch_data = ( self .data[ self .batch_id: min ( self .batch_id + batch_size, len ( self .data))]) batch_labels = ( self .labels[ self .batch_id: min ( self .batch_id + batch_size, len ( self .data))]) batch_seqlen = ( self .seqlen[ self .batch_id: min ( self .batch_id + batch_size, len ( self .data))]) self .batch_id = min ( self .batch_id + batch_size, len ( self .data)) return batch_data, batch_labels, batch_seqlen |
第四种方式:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
|
def batch_iter(sourceData, batch_size, num_epochs, shuffle = True ): data = np.array(sourceData) # 将sourceData转换为array存储 data_size = len (sourceData) num_batches_per_epoch = int ( len (sourceData) / batch_size) + 1 for epoch in range (num_epochs): # Shuffle the data at each epoch if shuffle: shuffle_indices = np.random.permutation(np.arange(data_size)) shuffled_data = sourceData[shuffle_indices] else : shuffled_data = sourceData for batch_num in range (num_batches_per_epoch): start_index = batch_num * batch_size end_index = min ((batch_num + 1 ) * batch_size, data_size) yield shuffled_data[start_index:end_index] |
迭代器的用法,具体学习Python迭代器的用法
另外需要注意的是,前三种方式只是所有语料遍历一次,而最后一种方法是,所有语料遍历了num_epochs次
以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持服务器之家。
原文链接:http://blog.csdn.net/appleml/article/details/57413615