2017-09-09 1 views
10

Keras 'fit_generator() Model-Methode erwartet einen Generator, der Tupel der Form (Eingabe, Ziele) erzeugt, wobei beide Elemente NumPy-Arrays sind. The documentation scheint zu implizieren, dass, wenn ich einfach einen Dataset iterator in einem Generator verpacke, und sicherstellen, dass die Tensors zu NumPy-Arrays konvertieren, sollte ich gut gehen. Dieser Code gibt mir jedoch einen Fehler:Wie können TensorFlows Dataset API und Keras richtig kombiniert werden?

import numpy as np 
import os 
import keras.backend as K 
from keras.layers import Dense, Input 
from keras.models import Model 
import tensorflow as tf 
from tensorflow.contrib.data import Dataset 

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 

with tf.Session() as sess: 
    def create_data_generator(): 
     dat1 = np.arange(4).reshape(-1, 1) 
     ds1 = Dataset.from_tensor_slices(dat1).repeat() 

     dat2 = np.arange(5, 9).reshape(-1, 1) 
     ds2 = Dataset.from_tensor_slices(dat2).repeat() 

     ds = Dataset.zip((ds1, ds2)).batch(4) 
     iterator = ds.make_one_shot_iterator() 
     while True: 
      next_val = iterator.get_next() 
      yield sess.run(next_val) 

datagen = create_data_generator() 

input_vals = Input(shape=(1,)) 
output = Dense(1, activation='relu')(input_vals) 
model = Model(inputs=input_vals, outputs=output) 
model.compile('rmsprop', 'mean_squared_error') 
model.fit_generator(datagen, steps_per_epoch=1, epochs=5, 
        verbose=2, max_queue_size=2) 

Hier ist der Fehler, den ich bekommen:

Using TensorFlow backend. 
Epoch 1/5 
Exception in thread Thread-1: 
Traceback (most recent call last): 
    File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 270, in __init__ 
    fetch, allow_tensor=True, allow_operation=True)) 
    File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 2708, in as_graph_element 
    return self._as_graph_element_locked(obj, allow_tensor, allow_operation) 
    File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 2787, in _as_graph_element_locked 
    raise ValueError("Tensor %s is not an element of this graph." % obj) 
ValueError: Tensor Tensor("IteratorGetNext:0", shape=(?, 1), dtype=int64) is not an element of this graph. 

During handling of the above exception, another exception occurred: 

Traceback (most recent call last): 
    File "/home/jsaporta/anaconda3/lib/python3.6/threading.py", line 916, in _bootstrap_inner 
    self.run() 
    File "/home/jsaporta/anaconda3/lib/python3.6/threading.py", line 864, in run 
    self._target(*self._args, **self._kwargs) 
    File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/keras/utils/data_utils.py", line 568, in data_generator_task 
    generator_output = next(self._generator) 
    File "./datagen_test.py", line 25, in create_data_generator 
    yield sess.run(next_val) 
    File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 895, in run 
    run_metadata_ptr) 
    File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1109, in _run 
    self._graph, fetches, feed_dict_tensor, feed_handles=feed_handles) 
    File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 413, in __init__ 
    self._fetch_mapper = _FetchMapper.for_fetch(fetches) 
    File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 233, in for_fetch 
    return _ListFetchMapper(fetch) 
    File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 340, in __init__ 
    self._mappers = [_FetchMapper.for_fetch(fetch) for fetch in fetches] 
    File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 340, in <listcomp> 
    self._mappers = [_FetchMapper.for_fetch(fetch) for fetch in fetches] 
    File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 241, in for_fetch 
    return _ElementFetchMapper(fetches, contraction_fn) 
    File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 277, in __init__ 
    'Tensor. (%s)' % (fetch, str(e))) 
ValueError: Fetch argument <tf.Tensor 'IteratorGetNext:0' shape=(?, 1) dtype=int64> cannot be interpreted as a Tensor. (Tensor Tensor("IteratorGetNext:0", shape=(?, 1), dtype=int64) is not an element of this graph.) 

Traceback (most recent call last): 
    File "./datagen_test.py", line 34, in <module> 
    verbose=2, max_queue_size=2) 
    File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/keras/legacy/interfaces.py", line 87, in wrapper 
    return func(*args, **kwargs) 
    File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/keras/engine/training.py", line 2011, in fit_generator 
    generator_output = next(output_generator) 
StopIteration 

Merkwürdigerweise eine Linie Zugabe enthält next(datagen) direkt nach dem ich initialisieren datagen den Code verursacht nur laufen gut, ohne Fehler.

Warum funktioniert mein ursprünglicher Code nicht? Warum fängt es an zu arbeiten, wenn ich diese Zeile zu meinem Code hinzufüge? Gibt es eine effizientere Möglichkeit, TensorFlows Dataset-API mit Keras zu verwenden, die keine Umwandlung von Tensoren in NumPy-Arrays und wieder zurück erfordert?

+0

Ich bin mir nicht sicher, ob das der Grund ist, aber ich finde es wirklich seltsam, dass Sie eine Funktion innerhalb eines 'with' Blocks definieren. –

+0

Offensichtlich macht das Setzen des 'with' Blocks innerhalb der Generatordefinition den Code sowohl mit als auch ohne die zusätzliche Zeile funktionieren, obwohl ich hätte schwören können, dass ich es zuerst auf diese Weise versucht habe. In Anbetracht der (meiner Meinung nach) Arbeit von TensorFlow 'Session sehe ich jedoch nicht, warum es einen Unterschied machen sollte. Ein anderes Geheimnis. – Jason

+0

Schließt der with-Block nicht die Sitzung am Ende? Ich denke, es sollte wirklich keine Definitionen enthalten, die außerhalb davon verwendet werden ....Wenn ich das als Antwort auf die Frage poste, würde es als beantwortet markiert werden? –

Antwort

7

Es gibt in der Tat eine effizientere Möglichkeit, Dataset zu verwenden, ohne die Tensoren in numplige Arrays konvertieren zu müssen. Allerdings steht es (noch?) Nicht auf der offiziellen Dokumentation. Aus der Release Note ist eine Funktion, die in Keras 2.0.7 eingeführt wurde. Möglicherweise müssen Sie keras> = 2.0.7 installieren, um es zu verwenden.

x = np.arange(4).reshape(-1, 1).astype('float32') 
ds_x = Dataset.from_tensor_slices(x).repeat().batch(4) 
it_x = ds_x.make_one_shot_iterator() 

y = np.arange(5, 9).reshape(-1, 1).astype('float32') 
ds_y = Dataset.from_tensor_slices(y).repeat().batch(4) 
it_y = ds_y.make_one_shot_iterator() 

input_vals = Input(tensor=it_x.get_next()) 
output = Dense(1, activation='relu')(input_vals) 
model = Model(inputs=input_vals, outputs=output) 
model.compile('rmsprop', 'mse', target_tensors=[it_y.get_next()]) 
model.fit(steps_per_epoch=1, epochs=5, verbose=2) 

Mehrere Unterschiede:

  1. Versorgung der tensor Argument der Input Schicht. Keras liest Werte von diesem Tensor und verwendet sie als Eingabe für das Modell.
  2. Geben Sie das target_tensors-Argument an Model.compile().
  3. Denken Sie daran, sowohl x als auch y in float32 zu konvertieren. Bei normaler Verwendung wird Keras diese Konvertierung für Sie vornehmen. Aber jetzt musst du es selbst machen.
  4. Die Losgröße wird bei der Konstruktion Dataset angegeben. Verwenden Sie steps_per_epoch und epochs, um zu steuern, wann die Anpassung des Modells gestoppt werden soll.

kurz, benutze Input(tensor=...), model.compile(target_tensors=...) und model.fit(x=None, y=None, ...), wenn Ihre Daten werden von Tensoren gelesen werden.

+3

Es sieht so aus, als ob es nicht einmal notwendig ist, zwei separate Iteratoren zu haben. Sie können einfach die beiden Datensätze zusammenfügen, einen Knoten wie 'next_val = it.get_next()' erstellen und die Elemente seiner Ausgabe den 'Input()' - und 'Model.compile()' -Funktionen zur Verfügung stellen. – Jason

+0

Was ist mit Iterator-Initialisierung? Kann ich Keras irgendwie sagen, es mit jeder Epoche zu initialisieren? Oder muss ich noch Sitzung erstellen und manuell tun und dann jedes Mal nur eine Epoche ausführen? – backman

1

Neben @ Yu-Yang Antwort, können Sie auch tf.data.Dataset modifizieren einen Generator zu werden für fit_generator folgend

def tfdata_generator(images, labels, batch_size=128, shuffle=True,): 
    def map_func(image, label): 
     '''A transformation function 

     ''' 
     x_train = tf.reshape(tf.cast(image, tf.float32), image_shape) 
     y_train = tf.one_hot(tf.cast(label, tf.uint8), num_classes) 
     return [x_train, y_train] 

    dataset = tf.data.Dataset.from_tensor_slices((images, labels)) 
    dataset = dataset.map(map_func) 
    dataset = dataset.shuffle().batch(batch_size).repeat() 
    iterator = dataset.make_one_shot_iterator() 

    next_batch = iterator.get_next() 
    while True: 
     yield K.get_session().run(next_batch) 

Jetzt haben Sie es als Generator aufrufen können. In diesem Beispiel. Ich habe den mnist-Datensatz verwendet.

from tensorflow.contrib.learn.python.learn.datasets import mnist 

data = mnist.load_mnist() 
model = # your Keras model 

model.fit_generator(generator = tfdata_generator(data.train.images, data.train.labels), 
        steps_per_epoch=200, 
        workers = 0 , # This is important 
        verbose = 1) 
+0

Dies ist AFAIK die einzige Möglichkeit, Keras Validierungsdaten mit dem Parameter validation_data von fit_generator zur Verfügung zu stellen – Warrick

Verwandte Themen