Ich arbeite an einem Trainingsscript in TensorFlow, um zwei verschiedene Arten von Bildern zu klassifizieren. Hier ist die Klasse, die ein Datensatzobjekt erstellt, das zum Generieren von Stapeln und zum Inkrementieren von Epochen verwendet wird. Es funktioniert gut, bis die erste Epoche abgeschlossen ist. Es scheitert dann an der Linie self._images = self._images[perm]
innerhalb der next_batch
Methode. Das ergibt für mich keinen Sinn, da Python keine self._images duplizieren sollte - nur die Daten neu zu mischen.Speicherleck in TensorFlow beim Start einer neuen Epoche
class DataSet(object):
def __init__(self, images, labels, norm=True):
assert images.shape[0] == labels.shape[0], (
"images.shape: %s labels.shape: %s" % (images.shape,
labels.shape))
self._num_examples = images.shape[0]
self._images = images
self._labels = labels
self._epochs_completed = 0
self._index_in_epoch = 0
self._norm = norm
# Shuffle the data right away
perm = np.arange(self._num_examples)
np.random.shuffle(perm)
self._images = self._images[perm]
self._labels = self._labels[perm]
@property
def images(self):
return self._images
@property
def labels(self):
return self._labels
@property
def num_examples(self):
return self._num_examples
@property
def epochs_completed(self):
return self._epochs_completed
def next_batch(self, batch_size):
"""Return the next `batch_size` examples from this data set."""
start = self._index_in_epoch
self._index_in_epoch += batch_size
if self._index_in_epoch > self._num_examples:
# Finished epoch
self._epochs_completed += 1
print("Completed epoch %d.\n"%self._epochs_completed)
# Shuffle the data
perm = np.arange(self._num_examples)
np.random.shuffle(perm)
self._images = self._images[perm] # this is where OOM happens
self._labels = self._labels[perm]
# Start next epoch
Die Speicherbelegung steigt während normaler Trainingszyklen nicht an. Hier ist der Teil des Trainingscodes. data_train_norm
ist ein DataSet
Objekt.
batch_size = 300
csv_plot = open("csvs/train_plot.csv","a")
for i in range(3000):
batch = data_train_norm.next_batch(batch_size)
if i%50 == 0:
tce = cross_entropy.eval(feed_dict={x:batch[0],y_:batch[1],keep_prob:1.0},session=sess)
print("\nstep %d, train ce %g"%(i,tce))
print datetime.datetime.now()
csv_plot.write("%d, %g\n"%(i,tce))
train_step.run(feed_dict={x:batch[0],y_:batch[1],keep_prob:0.8},session=sess)
version = 1
saver.save(sess,'nets/cnn0nu_batch_gpu_roi_v%02d'%version)
csv_plot.close()