Ich versuche, die Pipeline zum Lesen von Bildern auf dem CNN zu verwenden. Ich habe string_input_producer()
verwendet, um die Warteschlange der Dateinamen zu erhalten, aber es scheint dort zu hängen, ohne etwas zu tun. Unten ist mein Code, bitte geben Sie mir einige Ratschläge, wie es funktioniert.TensorFlow Bildlese-Warteschlange leer
def read_image_file(filename_queue, labels):
reader = tf.WholeFileReader()
key, value = reader.read(filename_queue)
image = tf.image.decode_png(value, channels=3)
image = tf.cast(image, tf.float32)
resized_image = tf.image.resize_images(image, [224, 112])
with tf.Session() as sess:
label = getLabel(labels, key.eval())
return resized_image, label
def input_pipeline(filename_queue, queue_names, batch_size, num_epochs, labels):
image, label = read_image_file(filename_queue, labels)
min_after_dequeue = 10 * batch_size
capacity = 20 * batch_size
image_batch, label_batch = tf.train.shuffle_batch(
[image, label], batch_size=batch_size, num_threads=1, capacity=capacity,
min_after_dequeue=min_after_dequeue)
return image_batch, label_batch
train_queue = tf.train.string_input_producer(trainnames, shuffle=True, num_epochs=epochs)
train_batch, train_label = input_pipeline(train_queue, trainnames, batch_size, epochs, labels)
prediction = AlexNet(x)
#Training
with tf.name_scope("cost_function") as scope:
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=train_label, logits=prediction(train_batch)))
tf.summary.scalar("cost_function", cost)
train_step = tf.train.MomentumOptimizer(learning_rate, 0.9).minimize(cost)
#Accuracy
with tf.name_scope("accuracy") as scope:
correct_prediction = tf.equal(tf.argmax(prediction,1), tf.argmax(y_,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
tf.summary.scalar("accuracy", accuracy)
merged = tf.summary.merge_all()
#Session
with tf.Session() as sess:
print('started')
sess.run(tf.global_variables_initializer())
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord, start=True)
sess.run(threads)
try:
for step in range(steps_per_epch * epochs):
print('step: %d' %step)
sess.run(train_step)
except tf.errors.OutOfRangeError as ex:
pass
coord.request_stop()
coord.join(threads)
Ich habe 'getLabel' in meinem Code definiert, aber habe es hier nicht angehängt, es extrahiert im Prinzip die Bezeichnung aus dem Dateinamen (eine Zeichenkette), aber der Schlüssel ist ein Tensor. Also habe ich 'key.eval()' gemacht, um die Zeichenkette des Dateinamens zu erhalten. Jetzt scheint es nicht zu funktionieren, gibt es einen anderen Weg, um die Schnur vom Tensor zu bekommen? – ALeex
Sie müssen wahrscheinlich alle Ihre String-Operationen durch String-Tensor-Operationen ersetzen, so dass sie Teil des Graphen sind und zur Laufzeit ausgeführt werden. – npf
Schauen Sie auf https://www.tensorflow.org/api_guides/python/string_ops – npf