Ich tat eine ähnliche Sache wie Sie beabsichtigen zu tun. Ich habe auch das gleiche Skript zum Erstellen von Bilddaten verwendet. Mein Code zum Lesen und Trainieren der Daten lautet
import tensorflow as tf
height = 28
width = 28
tfrecords_train_filename = 'train-00000-of-00001'
tfrecords_test_filename = 'test-00000-of-00001'
def read_and_decode(filename_queue):
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)
features = tf.parse_single_example(
serialized_example,
features={
'image/height': tf.FixedLenFeature([], tf.int64),
'image/width': tf.FixedLenFeature([], tf.int64),
'image/colorspace': tf.FixedLenFeature([], dtype=tf.string, default_value=''),
'image/channels': tf.FixedLenFeature([], tf.int64),
'image/class/label': tf.FixedLenFeature([], tf.int64),
'image/class/text': tf.FixedLenFeature([], dtype=tf.string, default_value=''),
'image/format': tf.FixedLenFeature([], dtype=tf.string, default_value=''),
'image/filename': tf.FixedLenFeature([], dtype=tf.string, default_value=''),
'image/encoded': tf.FixedLenFeature([], dtype=tf.string, default_value='')
})
image_buffer = features['image/encoded']
image_label = tf.cast(features['image/class/label'], tf.int32)
# Decode the jpeg
with tf.name_scope('decode_jpeg', [image_buffer], None):
# decode
image = tf.image.decode_jpeg(image_buffer, channels=3)
# and convert to single precision data type
image = tf.image.convert_image_dtype(image, dtype=tf.float32)
image = tf.image.rgb_to_grayscale(image)
image_shape = tf.stack([height, width, 1])
image = tf.reshape(image, image_shape)
return image, image_label
def inputs(filename, batch_size, num_epochs):
if not num_epochs: num_epochs = None
with tf.name_scope('input'):
filename_queue = tf.train.string_input_producer([filename], num_epochs=None)
image, label = read_and_decode(filename_queue)
# Shuffle the examples and collect them into batch_size batches.
images, sparse_labels = tf.train.shuffle_batch(
[image, label], batch_size=batch_size, num_threads=2,
capacity=1000 + 3 * batch_size,
min_after_dequeue=1000)
return images, sparse_labels
image, label = inputs(filename=tfrecords_train_filename, batch_size=200, num_epochs=None)
image = tf.reshape(image, [-1, 784])
label = tf.one_hot(label - 1, 10)
# Create the model
x = tf.placeholder(tf.float32, [None, 784])
W = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros([10]))
y = tf.matmul(x, W) + b
y_ = tf.placeholder(tf.float32, [None, 10])
cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y))
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(coord=coord)
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
for i in range(1000):
img, lbl = sess.run([image, label])
sess.run(train_step, feed_dict={x: img, y_: lbl})
img, lbl = sess.run([image, label])
print(sess.run(accuracy, feed_dict={x: img, y_: lbl}))
coord.request_stop()
coord.join(threads)
Dies ist ein super einfaches Modell zur Klassifizierung von Mnist. Ich denke jedoch, dass es auch eine erweiterbare Antwort für das Trainieren mit TFRecord-Dateien ist. Die Evaluierungsdaten werden noch nicht berücksichtigt, da hierfür mehr Koordination erforderlich ist.
In diesem Handbuch (https://www.tensorflow.org/programmers_guide/datasets) finden Sie Beispiele, wie Sie Daten aus TFRecord-Dateien und GT-Tensoren mit den Daten laden können. Dann ist es nur eine Frage der Weitergabe dieser Daten als Eingabe an Ihr Netzwerk anstelle von jedem Eingang, den das Netzwerk im Moment erhält. – GPhilo
@GPhilo Ich habe meinen Datensatz als "images: Images. 4D tensor der Größe [batch_size, FLAGS.image_size, image_size, 3]. Etiketten: 1-D ganzzahlige Tensor von [FLAGS.batch_size]. ", Aber ich sehe nicht t.estimator.inputs eine Funktion zu nehmen, was ich geladen habe. – Eejin
tf.estimator.inputs verfügt über Komfortfunktionen, um Daten, die noch nicht im Tensorformat vorliegen, in etwas umzuwandeln, das das Netzwerk annehmen kann. Sie müssen das 'input_fn' neu schreiben. Ich kenne diese High-level-API nicht, aber aus der [Estimator-Dokumentation] (https://www.tensorflow.org/api_docs/python/tf/estimator/Estimator) muss man wohl eine 'input_fn' definieren das gibt ein dict '{'images' zurück: your_image_tensor, 'labels': your_label_tensor}'. – GPhilo