2017-11-21 1 views
4

Ich ändere meinen TensorFlow-Code von der alten Warteschlangenschnittstelle zur neuen Dataset-API. Mit der alten Schnittstelle könnte ich die tatsächlich gefüllte Warteschlangengröße überwachen, indem ich auf einen Rohzähler in der Grafik z. wie folgt:Zugriffsnummer der in der TensorFlow-Dataset-API eingereihten Elemente

queue = tf.train.shuffle_batch(..., name="training_batch_queue") 
queue_size_op = "training_batch_queue/random_shuffle_queue_Size:0" 
queue_size = session.run(queue_size_op) 

jedoch mit dem neuen Datensatz API kann ich nicht auf die Warteschlangen/Datensätze im Zusammenhang scheinen alle Variablen in der Grafik zu finden, so dass meine alten Code nicht mehr funktioniert. Gibt es eine Möglichkeit, die Anzahl der Elemente in der Warteschlange mithilfe der neuen Dataset-API (z. B. in der Warteschlange tf.Dataset.prefetch oder tf.Dataset.shuffle) zu erhalten?

Es ist wichtig für mich, die Anzahl der Elemente in der Warteschlange zu überwachen, da dies mir viel über das Verhalten der Vorverarbeitung in den Warteschlangen sagt, einschließlich ob die Vorverarbeitung oder der Rest (z Netzwerk) ist der Geschwindigkeitsengpass.

+0

bitte, schauen Sie sich [46444018] (https://stackoverflow.com/questions/46444018/meaning-of-buffer-size-in-dataset-map-dataset-prefetch-and-dataset-shuffle) an erhalten Sie eine bessere Vorstellung von dem zugrunde liegenden Verhalten der verschiedenen Arten von Shuffle-Argumenten –

+0

@maxF. Ja ich verstehe. Das Beispiel in meinem Post ist vielleicht nicht das beste, da es interessant ist, 'tf.train.shuffle_batch' in der alten Einstellung zu überwachen, aber es macht keinen Sinn,' tf.Dataset.shuffle' in der neuen Einstellung zu überwachen. Was zur Überwachung sinnvoll ist, ist die Größe von 'tf.Dataset.prefetch', um eine Vorstellung davon zu bekommen, ob die Vorverarbeitung oder das eigentliche Netzwerk der Flaschenhals ist. – CNugteren

Antwort

0

Als eine Arbeit herum ist es möglich, einen Zähler zu halten, um anzuzeigen, wie viele Elemente in der Warteschlange sind. Hier ist, wie der Zähler zu definieren:

queue_size = tf.get_variable("queue_size", initializer=0, 
           trainable=False, use_resource=True) 

Dann, wenn die Vorverarbeitung Daten (zB in der dataset.map-Funktion), wir, daß der Zähler erhöhen können:

def pre_processing(): 
    data_size = ... # compute this (could be just '1') 
    queue_size_op = tf.assign_add(queue_size, data_size) # adding items 
    with tf.control_dependencies([queue_size_op]): 
     # do the actual pre-processing here 

Wir können dann dekrementiert den Zähler every- Zeit laufen wir unser Modell mit einer Charge von Daten:

def model(): 
    queue_size_op = tf.assign_add(queue_size, -batch_size) # removing items 
    with tf.control_dependencies([queue_size_op]): 
     # define the actual model here 

Jetzt haben wir alle die queue_size Tensor in unserer Trainingsschleife, um herauszufinden, was die aktuellen q tun müssen, sind laufen ueue Größe ist, das heißt die Anzahl der Elemente in der Warteschlange in diesem Moment:

current_queue_size = session.run(queue_size) 

Es ist ein bisschen weniger elegant im Vergleich zu dem alten Weg (vor dem Dataset API), aber es funktioniert den Trick.

Verwandte Themen