2016-11-25 4 views
0

So habe ich dieses Spielzeug Beispielcode;Tensorflow Multithreading Bild wird geladen

import glob 
from tqdm import tqdm 
import tensorflow as tf 

imgPaths = glob.glob("/home/msmith/imgs/*/*") # Some images 

filenameQ = tf.train.string_input_producer(imgPaths) 
reader = tf.WholeFileReader() 
key, value = reader.read(filenameQ) 

img = tf.image.decode_jpeg(value) 
init_op = tf.initialize_all_variables() 

with tf.Session() as sess: 
    sess.run(init_op) 
    coord = tf.train.Coordinator() 
    threads = tf.train.start_queue_runners(coord=coord) 
    for i in tqdm(range(10000)): 
     img.eval().mean() 

lädt Bilder und druckt den Mittelwert von jedem. Wie ich es bearbeite, damit es den ladenden Teil der Bilder multithreading ist, der im Moment mein Flaschenhals auf meinen tf Bildskripten ist.

+0

ich einen Blick auf [QueueRunner] (https://www.tensorflow.org/versions/r0.11/how_tos/threading_and_queues/index.html#queuerunner) Klasse nehmen würde, obwohl es mir nicht klar ist, wie man es mit einem vorgebauten Leser verbindet. – sygi

Antwort

2

EDIT (2018/3/5): Es ist jetzt einfacher, die gleichen Ergebnisse mit der tf.data API zu erhalten.

import glob 
from tqdm import tqdm 
import tensorflow as tf 

imgPaths = glob.glob("/home/msmith/imgs/*/*") # Some images 

dataset = (tf.data.Dataset.from_tensor_slices(imgPaths) 
      .map(lambda x: tf.reduce_mean(tf.decode_jpeg(tf.read_file(x))), 
       num_parallel_calls=16) 
      .prefetch(128)) 

iterator = dataset.make_one_shot_iterator() 
next_mean = iterator.get_next() 

with tf.Session() as sess: 
    for i in tqdm(range(10000)): 
     sess.run(next_mean) 

sygi Wie in their comment schon sagt, kann ein tf.train.QueueRunner einige ops definiert werden, die in einem separaten Thread ausführen und (typischerweise) enqueue Werte in eine TensorFlow Warteschlange.

import glob 
from tqdm import tqdm 
import tensorflow as tf 

imgPaths = glob.glob("/home/msmith/imgs/*/*") # Some images 

filenameQ = tf.train.string_input_producer(imgPaths) 

# Define a subgraph that takes a filename, reads the file, decodes it, and                      
# enqueues it.                                     
filename = filenameQ.dequeue() 
image_bytes = tf.read_file(filename) 
decoded_image = tf.image.decode_jpeg(image_bytes) 
image_queue = tf.FIFOQueue(128, [tf.uint8], None) 
enqueue_op = image_queue.enqueue(decoded_image) 

# Create a queue runner that will enqueue decoded images into `image_queue`.                     
NUM_THREADS = 16 
queue_runner = tf.train.QueueRunner(
    image_queue, 
    [enqueue_op] * NUM_THREADS, # Each element will be run from a separate thread.                      
    image_queue.close(), 
    image_queue.close(cancel_pending_enqueues=True)) 

# Ensure that the queue runner threads are started when we call                        
# `tf.train.start_queue_runners()` below.                              
tf.train.add_queue_runner(queue_runner) 

# Dequeue the next image from the queue, for returning to the client.                       
img = image_queue.dequeue() 

init_op = tf.global_variables_initializer() 

with tf.Session() as sess: 
    sess.run(init_op) 
    coord = tf.train.Coordinator() 
    threads = tf.train.start_queue_runners(sess=sess, coord=coord) 
    for i in tqdm(range(10000)): 
     img.eval().mean() 
+0

Das ist großartig. Ein paar mehr Dinge; Wenn ich Preprocessing machen möchte, mache ich das vor image_queue.dequeue()? Auch wann kann ich herausfinden, ob die Threads die Liste der Eingaben vollständig durchlaufen haben? – mattdns

+1

Für die Vorverarbeitung würden Sie dies vor 'image_queue.dequeue()' tun, aber Sie könnten eine weitere Queue/'QueueRunner' hinzufügen, wenn Sie möchten, dass eine andere Gruppe von Threads dies parallel zur Analyse durchführt. Wenn die Bilder alle die gleiche Größe haben, finden Sie vielleicht ['tf.train.batch()'] (https://www.tensorflow.org/versions/r0.12/api_docs/python/io_ops.html#batch) nützlich dafür. Der einfachste Weg zu sagen, wann die Threads fertig sind, ist die Verwendung von 'while not coord.should_stop():' anstelle der 'for'-Schleife. – mrry

+0

Ausgezeichnet. Das Etikett des Bildes ist in der Zeichenkette des Dateinamens kodiert, wenn ich das in einen OH-Vektor umwandeln kann und ich den richtigen Vektor zur richtigen Zeit raushaben möchte ... tue ich das, indem ich ein weiteres '' 'enqueue_op'' hinzufüge 'Tensor für den Klassenvektor in diese Liste' '' [enqueue_op] '' '? Nebenbei kann ich das Kopfgeld noch 2 Stunden nicht bezahlen. – mattdns