2017-07-23 1 views
0

Ich arbeite an einem Trainingsscript in TensorFlow, um zwei verschiedene Arten von Bildern zu klassifizieren. Hier ist die Klasse, die ein Datensatzobjekt erstellt, das zum Generieren von Stapeln und zum Inkrementieren von Epochen verwendet wird. Es funktioniert gut, bis die erste Epoche abgeschlossen ist. Es scheitert dann an der Linie self._images = self._images[perm] innerhalb der next_batch Methode. Das ergibt für mich keinen Sinn, da Python keine self._images duplizieren sollte - nur die Daten neu zu mischen.Speicherleck in TensorFlow beim Start einer neuen Epoche

class DataSet(object): 
    def __init__(self, images, labels, norm=True): 
    assert images.shape[0] == labels.shape[0], (
     "images.shape: %s labels.shape: %s" % (images.shape, 
             labels.shape)) 
    self._num_examples = images.shape[0] 
    self._images = images 
    self._labels = labels 
    self._epochs_completed = 0 
    self._index_in_epoch = 0 
    self._norm = norm 
    # Shuffle the data right away 
    perm = np.arange(self._num_examples) 
    np.random.shuffle(perm) 
    self._images = self._images[perm] 
    self._labels = self._labels[perm] 
    @property 
    def images(self): 
    return self._images 
    @property 
    def labels(self): 
    return self._labels 
    @property 
    def num_examples(self): 
    return self._num_examples 
    @property 
    def epochs_completed(self): 
    return self._epochs_completed 
    def next_batch(self, batch_size): 
    """Return the next `batch_size` examples from this data set.""" 
    start = self._index_in_epoch 
    self._index_in_epoch += batch_size 
    if self._index_in_epoch > self._num_examples: 
     # Finished epoch 
     self._epochs_completed += 1 
     print("Completed epoch %d.\n"%self._epochs_completed) 
     # Shuffle the data 
     perm = np.arange(self._num_examples) 
     np.random.shuffle(perm) 
     self._images = self._images[perm] # this is where OOM happens 
     self._labels = self._labels[perm] 
     # Start next epoch 

Die Speicherbelegung steigt während normaler Trainingszyklen nicht an. Hier ist der Teil des Trainingscodes. data_train_norm ist ein DataSet Objekt.

batch_size = 300 
csv_plot = open("csvs/train_plot.csv","a") 
for i in range(3000): 
    batch = data_train_norm.next_batch(batch_size) 
    if i%50 == 0: 
      tce = cross_entropy.eval(feed_dict={x:batch[0],y_:batch[1],keep_prob:1.0},session=sess) 
      print("\nstep %d, train ce %g"%(i,tce)) 
      print datetime.datetime.now() 
      csv_plot.write("%d, %g\n"%(i,tce)) 

    train_step.run(feed_dict={x:batch[0],y_:batch[1],keep_prob:0.8},session=sess) 

version = 1 
saver.save(sess,'nets/cnn0nu_batch_gpu_roi_v%02d'%version) 
csv_plot.close() 

Antwort

1

Dies ist wahrscheinlich, weil dieses Stück Code, wo Sie auf das Graphen

for i in range(3000): 
    batch = data_train_norm.next_batch(batch_size) 

Verfahren einen neuen next_batch Betrieb hinzufügen data_train_norm.next_batch einen neuen TensorFlow Betrieb erstellen, so dass Sie es nur nennen sollten einmal und verwenden Sie die Operation erstellt (halten Sie in batch). Werfen Sie einen Blick auf die Beispiele in der doc, z.B .:

dataset = tf.contrib.data.Dataset.range(100) 
iterator = dataset.make_one_shot_iterator() 
next_element = iterator.get_next() 

for i in range(100): 
    value = sess.run(next_element) 
    assert i == value 

Auch wenn TensorFlow Speicherleck Debuggen, können Sie sess.graph.finalize()

Verwandte Themen