2017-09-24 1 views
1

Tensorflow Neuling hier! Ich verstehe, dass Variablen im Laufe der Zeit trainiert werden, Platzhalter sind Eingabedaten, die sich nicht ändern, wenn Ihr Modell trainiert (wie Eingabebilder und Klassenbezeichnungen für diese Bilder).tf.zeros vs tf.placeholder als RNN ursprünglichen Zustand

Ich versuche, die Weiterleitung von RNN mit Tensorflow zu implementieren, und frage mich, welchen Typ ich die Ausgabe der RNN-Zelle speichern sollte. In numpy RNN Implementierung verwendet es

hiddenStates = np.zeros((T, self.hidden_dim)) #T is the length of the sequence

Dann iterativ es die Ausgabe in dem np.zeros Array speichert.

Bei TF, welchen soll ich verwenden, tf.zeros oder tf.placeholder?

Was ist die beste Vorgehensweise in diesem Fall? Ich denke, es sollte in Ordnung sein, tf.zeros zu verwenden, aber wollte überprüfen.

Antwort

2

Zunächst ist es wichtig für Sie zu verstehen, dass alles im Tensorflow ein Tensor ist. Wenn Sie also eine Art von Berechnung durchführen (z. B. eine rnn-Implementierung wie outputs = rnn(...)), wird die Ausgabe dieser Berechnung als Tensor zurückgegeben. Sie müssen es also nicht in irgendeiner Struktur speichern. Sie können es abrufen, indem Sie den entsprechenden Knoten (d. H. output) wie session.run(output, feed_dict) ausführen.

Sagte, ich denke, Sie müssen den endgültigen Zustand eines RNN nehmen und es als Ausgangszustand einer nachfolgenden Berechnung bereitstellen. Zwei Möglichkeiten:

A) Wenn Sie mit RNNCell Implementierungen Beim Bau des Modells können Sie den Null-Zustand wie folgt konstruieren:

cell = (some RNNCell implementation) 
initial_state = cell.zero_state(batch_size, tf.float32) 

B) Wenn Sie uimplementing Ihre eigenen Mitarbeiter definieren Sie den Zustand als Null Tensor:

initial_state = tf.zeros([batch_size, hidden_size]) 

Dann in beiden Fällen Sie so etwas wie haben:

output, final_state = rnn(input, initial_state) 

In Ihrer Ausführungsschleife können Sie Ihr Zustand zuerst initialisieren und dann bieten die final_state als initial_state in Ihrem feed_dict:

state = session.run(initial_state) 
for step in range(epochs): 

    feed_dict = {initial_state: state} 
    _, state = session.run((train_op,final_state), feed_dict) 

Wie Sie tatsächlich bauen Ihre feed_dict über die Umsetzung des RNN abhängt.

Für eine BasicLSTMCell zum Beispiel ein Zustand ist ein LSTMState Objekt und Sie müssen beide c bieten und h:

feed_dict = {initial_state.c=state.c, initial_state.h: state.h} 
Verwandte Themen