2017-09-08 1 views
4

Tensorflow programmer's guide empfiehlt die Verwendung von Feedable-Iterator, um zwischen Trainings- und Validierungsdatensatz zu wechseln, ohne den Iterator neu zu initialisieren. Es erfordert hauptsächlich, den Griff zu füttern, um zwischen ihnen zu wählen.Wie kann der Feedable-Iterator von Tensorflow Dataset API zusammen mit MonitoredTrainingSession verwendet werden?

Wie man es zusammen mit tf.train.MonitoredTrainingSession verwenden?

Folgende Methode schlägt fehl mit "RuntimeError: Diagramm ist abgeschlossen und kann nicht geändert werden." Error.

with tf.train.MonitoredTrainingSession() as sess: 
    training_handle = sess.run(training_iterator.string_handle()) 
    validation_handle = sess.run(validation_iterator.string_handle()) 

Wie sowohl die Bequemlichkeit der MonitoredTrainingSession zu erreichen und Ausbildung und Validierung Datensätze iterieren gleichzeitig?

Antwort

2

bekam ich die Antwort aus der Tensorflow GitHub Ausgabe - https://github.com/tensorflow/tensorflow/issues/12859

Die Lösung ist die iterator.string_handle() vor dem Erstellen des MonitoredSession aufzurufen.

import tensorflow as tf 
from tensorflow.contrib.data import Dataset, Iterator 

dataset_train = Dataset.range(10) 
dataset_val = Dataset.range(90, 100) 

iter_train_handle = dataset_train.make_one_shot_iterator().string_handle() 
iter_val_handle = dataset_val.make_one_shot_iterator().string_handle() 

handle = tf.placeholder(tf.string, shape=[]) 
iterator = Iterator.from_string_handle(
    handle, dataset_train.output_types, dataset_train.output_shapes) 
next_batch = iterator.get_next() 

with tf.train.MonitoredTrainingSession() as sess: 
    handle_train, handle_val = sess.run([iter_train_handle, iter_val_handle]) 

    for step in range(10): 
     print('train', sess.run(next_batch, feed_dict={handle: handle_train})) 

     if step % 3 == 0: 
      print('val', sess.run(next_batch, feed_dict={handle: handle_val})) 

Output: 
('train', 0) 
('val', 90) 
('train', 1) 
('train', 2) 
('val', 91) 
('train', 3) 
0

@Michael Jais G Antwort ist korrekt. Es funktioniert jedoch nicht, wenn Sie auch bestimmte session_run_hooks verwenden möchten, die Teile des Diagramms wie z. LoggingTensorHook oder SummarySaverHook. Das folgende Beispiel wird einen Fehler verursachen:

import tensorflow as tf 

dataset_train = tf.data.Dataset.range(10) 
dataset_val = tf.data.Dataset.range(90, 100) 

iter_train_handle = dataset_train.make_one_shot_iterator().string_handle() 
iter_val_handle = dataset_val.make_one_shot_iterator().string_handle() 

handle = tf.placeholder(tf.string, shape=[]) 
iterator = tf.data.Iterator.from_string_handle(
    handle, dataset_train.output_types, dataset_train.output_shapes) 
feature = iterator.get_next() 

pred = feature * feature 
tf.summary.scalar('pred', pred) 
global_step = tf.train.create_global_step() 

summary_hook = tf.train.SummarySaverHook(save_steps=5, 
             output_dir="summaries", summary_op=tf.summary.merge_all()) 

with tf.train.MonitoredTrainingSession(hooks=[summary_hook]) as sess: 
    handle_train, handle_val = sess.run([iter_train_handle, iter_val_handle]) 

    for step in range(10): 
     feat = sess.run(feature, feed_dict={handle: handle_train}) 
     pred_ = sess.run(pred, feed_dict={handle: handle_train}) 
     print('train: ', feat) 
     print('pred: ', pred_) 

     if step % 3 == 0: 
      print('val', sess.run(feature, feed_dict={handle: handle_val})) 

Dies wird mit Fehler fehlschlagen:

InvalidArgumentError (see above for traceback): You must feed a value for placeholder tensor 'Placeholder' with dtype string 
    [[Node: Placeholder = Placeholder[dtype=DT_STRING, shape=[], _device="/job:localhost/replica:0/task:0/device:CPU:0"]()]] 
    [[Node: cond/Switch_1/_15 = _Recv[client_terminated=false, recv_device="/job:localhost/replica:0/task:0/device:GPU:0", send_device="/job:localhost/replica:0/task:0/device:CPU:0", send_device_incarnation=1, tensor_name="edge_18_cond/Switch_1", tensor_type=DT_INT64, _device="/job:localhost/replica:0/task:0/device:GPU:0"]()]] 

Der Grund dafür ist, dass der Haken versuchen, die Grafik bereits auf dem ersten session.run zu bewerten ([ iter_train_handle, iter_val_handle]), die offensichtlich noch kein Handle im feed_dict enthält. Die Problemumgehung besteht darin, die Hooks, die das Problem verursachen, zu überschreiben und den Code in before_run und after_run zu ändern, um nur bei session.run-Aufrufen mit dem Handle im feed_dict auszuwerten (Sie können auf das feed_dict der aktuellen session.run zugreifen.) Aufruf über das run_context-Argument von before_run und after_run)

Oder Sie können den neuesten Master von Tensorflow (post-1.4) verwenden, der eine Funktion run_step_fn zu MonitoredSession hinzufügt, mit der Sie die folgende step_fn angeben können, die den Fehler vermeidet (on die Kosten der Auswertung der if-Anweisung TrainingIteration Anzahl der Male ...)

def step_fn(step_context): 
    if handle_train is None: 
    handle_train, handle_val = sess.run([iter_train_handle, iter_val_handle]) 
    return step_context.run_with_hooks(fetches=..., feed_dict=...) 
Verwandte Themen