Ich versuche, Daten aus einer hdf5-Datei in die Warteschlange einzureihen. Da Tensorflow hdf5 nicht unterstützt, habe ich eine Python-Funktion erstellt, die Beispiele aus einer hdf5-Datei liest und tf.errors.OutOfRangeError
auslöst, wenn sie das Ende der Datei erreicht. Ich wickle dann diese Python-Funktion mit tf.py_func
und verwende sie als Enqueue-Op für meine Warteschlange.Tensorflow: Benutzerdefinierte Datenleser mit py_func
Dies ist mein Code:
import h5py
import tensorflow as tf
from tensorflow.python.framework import errors
import numpy as np
def read_from_hdf5(hdf5_file, batch_size):
h5py_handle = h5py.File(hdf5_file)
# Check shapes from the hdf5 file so that we can set the tensor shapes
feature_shape = h5py_handle['features'].shape[1:]
label_shape = h5py_handle['labels'].shape[1:]
#generator that produces examples for training. It will be wrapped by tf.pyfunc to simulate a reader
def example_generator(h5py_handle):
for i in xrange(0, h5py_handle['features'].shape[0]-batch_size+1, batch_size):
features = h5py_handle['features'][i:i+batch_size]
labels = h5py_handle['labels'][i:i+batch_size]
yield [features, labels]
raise errors.OutOfRangeError(node_def=None, op=None, message='completed all examples in %s'%hdf5_file)
[features_tensor, labels_tensor] = tf.py_func(
example_generator(h5py_handle).next,
[],
[tf.float32, tf.float32],
stateful=True)
# Set the shape so that we can infer sizes etc in later layers.
features_tensor.set_shape([batch_size, feature_shape[0], feature_shape[1], feature_shape[2]])
labels_tensor.set_shape([batch_size, label_shape[0]])
return features_tensor, labels_tensor
def load_data_from_filename_list(hdf5_files, batch_size, shuffle_seed=0):
example_list = [read_from_hdf5(hdf5_file, batch_size) for hdf5_file in hdf5_files]
min_after_dequeue = 10000
capacity = min_after_dequeue + (len(example_list)+1) * batch_size #min_after_dequeue + (num_threads + a small safety margin) * batch_size
features, labels = tf.train.shuffle_batch_join(example_list, batch_size, capacity=capacity, min_after_dequeue=min_after_dequeue, seed=shuffle_seed, enqueue_many=True)
return features, labels, metadata
ich erwartet, dass die tf.errors.OutOfRangeError
vom QueueRunner behandelt werden würde, aber ich folgende Fehlermeldung erhalten und das Programm stürzt ab. Ist es möglich, diese Art von Lesen von einem py_func zu machen, und wenn ja, was mache ich falsch? Wenn nicht, welchen Ansatz sollte ich stattdessen verwenden?
Es kann helfen, wenn Sie Ihren Code sowie die Fehlermeldung zeigen, damit Leute Ihr Problem reproduzieren können. – merlin2011
Danke für den Tipp. Ich habe meinen Code hinzugefügt. – navari