0

Ich baue ein Encoder-Decoder-Modell in TensorFlow 1.0.1 die Legacy-Sequenz zu Sequenz Framework. Alles funktioniert wie es sollte, wenn ich eine Schicht von LSTMs im Encoder und Decoder habe. Allerdings, wenn ich mit> 1 Schichten von LSTMs in einem MultiRNNCell gewickelt versuchen, erhalte ich einen Fehler, wenn tf.contrib.legacy_seq2seq.rnn_decoder Aufruf.TensorFlow wirft nur Fehler, wenn MultiRNNCell mit

Der vollständige Fehler ist am Ende oben diesen Beitrag, aber kurz gesagt, ist es durch eine Linie

(c_prev, m_prev) = state 

in TensorFlow verursacht, die TypeError: 'Tensor' object is not iterable. wirft. Ich bin verwirrt, da der Anfangszustand, den ich an rnn_decoder übergebe, tatsächlich ein Tupel ist, wie es sein sollte. Soweit ich das beurteilen kann, besteht der einzige Unterschied zwischen der Verwendung von 1 oder> 1 Schichten darin, dass die Verwendung von MultiRNNCell verwendet wird. Gibt es einige API-Eigenarten, über die ich bei der Verwendung Bescheid wissen sollte?

Dies ist mein Code (basierend auf dem Beispiel in this GitHub Repo). Entschuldigung für seine Länge; Das ist so minimal, dass ich es schaffen kann, während es noch vollständig und nachprüfbar ist.

import tensorflow as tf 
import tensorflow.contrib.legacy_seq2seq as seq2seq 
import tensorflow.contrib.rnn as rnn 

seq_len = 50 
input_dim = 300 
output_dim = 12 
num_layers = 2 
hidden_units = 100 

sess = tf.Session() 

encoder_inputs = [] 
decoder_inputs = [] 

for i in range(seq_len): 
    encoder_inputs.append(tf.placeholder(tf.float32, shape=(None, input_dim), 
             name="encoder_{0}".format(i))) 

for i in range(seq_len + 1): 
    decoder_inputs.append(tf.placeholder(tf.float32, shape=(None, output_dim), 
             name="decoder_{0}".format(i))) 

if num_layers > 1: 
    # Encoder cells (bidirectional) 
    # Forward 
    enc_cells_fw = [rnn.LSTMCell(hidden_units) 
        for _ in range(num_layers)] 
    enc_cell_fw = rnn.MultiRNNCell(enc_cells_fw) 
    # Backward 
    enc_cells_bw = [rnn.LSTMCell(hidden_units) 
        for _ in range(num_layers)] 
    enc_cell_bw = rnn.MultiRNNCell(enc_cells_bw) 
    # Decoder cell 
    dec_cells = [rnn.LSTMCell(2*hidden_units) 
       for _ in range(num_layers)] 
    dec_cell = rnn.MultiRNNCell(dec_cells) 
else: 
    # Encoder 
    enc_cell_fw = rnn.LSTMCell(hidden_units) 
    enc_cell_bw = rnn.LSTMCell(hidden_units) 
    # Decoder 
    dec_cell = rnn.LSTMCell(2*hidden_units) 

# Make sure input and output are the correct dimensions 
enc_cell_fw = rnn.InputProjectionWrapper(enc_cell_fw, input_dim) 
enc_cell_bw = rnn.InputProjectionWrapper(enc_cell_bw, input_dim) 
dec_cell = rnn.OutputProjectionWrapper(dec_cell, output_dim) 

_, final_fw_state, final_bw_state = \ 
    rnn.static_bidirectional_rnn(enc_cell_fw, 
            enc_cell_bw, 
            encoder_inputs, 
            dtype=tf.float32) 

# Concatenate forward and backward cell states 
# (The state is a tuple of previous output and cell state) 
if num_layers == 1: 
    initial_dec_state = tuple([tf.concat([final_fw_state[i], 
              final_bw_state[i]], 1) 
           for i in range(2)]) 
else: 
    initial_dec_state = tuple([tf.concat([final_fw_state[-1][i], 
              final_bw_state[-1][i]], 1) 
           for i in range(2)]) 

decoder = seq2seq.rnn_decoder(decoder_inputs, initial_dec_state, dec_cell) 

tf.global_variables_initializer().run(session=sess) 

Und das ist der Fehler:

Traceback (most recent call last): 
    File "example.py", line 67, in <module> 
    decoder = seq2seq.rnn_decoder(decoder_inputs, initial_dec_state, dec_cell) 
    File "/home/tao/.virtualenvs/example/lib/python2.7/site-packages/tensorflow/contrib/legacy_seq2seq/python/ops/seq2seq.py", line 150, in rnn_decoder 
    output, state = cell(inp, state) 
    File "/home/tao/.virtualenvs/example/lib/python2.7/site-packages/tensorflow/contrib/rnn/python/ops/core_rnn_cell_impl.py", line 426, in __call__ 
    output, res_state = self._cell(inputs, state) 
    File "/home/tao/.virtualenvs/example/lib/python2.7/site-packages/tensorflow/contrib/rnn/python/ops/core_rnn_cell_impl.py", line 655, in __call__ 
    cur_inp, new_state = cell(cur_inp, cur_state) 
    File "/home/tao/.virtualenvs/example/lib/python2.7/site-packages/tensorflow/contrib/rnn/python/ops/core_rnn_cell_impl.py", line 321, in __call__ 
    (c_prev, m_prev) = state 
    File "/home/tao/.virtualenvs/example/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 502, in __iter__ 
    raise TypeError("'Tensor' object is not iterable.") 
TypeError: 'Tensor' object is not iterable. 

Thank you!

Antwort

4

Das Problem liegt im Format des Ausgangszustands (initial_dec_state), übergeben an seq2seq.rnn_decoder.

Wenn Sie rnn.MultiRNNCell verwenden, erstellen Sie ein mehrschichtiges rekurrentes Netzwerk. Daher müssen Sie für diese Schichten einen Anfangszustand für jeweils angeben.

Daher sollten Sie eine Liste der Tupel als Anfangszustand angeben, wobei jedes Element der Liste der vorherige Zustand ist, der von der entsprechenden Schicht des wiederkehrenden Netzwerks kommt.

So Ihre initial_dec_state, wie dies initialisiert:

initial_dec_state = tuple([tf.concat([final_fw_state[-1][i], 
             final_bw_state[-1][i]], 1) 
          for i in range(2)]) 

stattdessen sollte wie folgt sein:

initial_dec_state = [ 
        tuple([tf.concat([final_fw_state[j][i],final_bw_state[j][i]], 1) 
          for i in range(2)]) for j in range(len(final_fw_state)) 
         ] 

, die eine Liste von Tupeln in dem Format erstellt:

[(state_c1, state_m1), (state_c2, state_m2) ...] 

In mehr Details, der 'Tensor' object is not iterable. Fehler passiert, weil seq2seq.rnn_decoder ruft intern Ihre rnn.MultiRNNCell (dec_cell), die den Ausgangszustand (initial_dec_state) zu.

rnn.MultiRNNCell.__call__ iteriert durch die Liste der Anfangszustände und extrahiert für jedes einzelne das Tupel (c_prev, m_prev) (in der Anweisung (c_prev, m_prev) = state).

Also, wenn Sie nur ein Tupel übergeben, wird rnn.MultiRNNCell.__call__ über sie iterieren, und sobald es die (c_prev, m_prev) = state erreicht wird es einen Tensor finden (was ein Tupel sein sollte) als state und die 'Tensor' object is not iterable. Fehler werfen.

Ein guter Weg zu wissen, welches Format des Ausgangszustandes seq2seq.rnn_decoder erwartet, ist dec_cell.zero_state(batch_size, dtype=tf.float32) zu rufen. Diese Methode gibt null-gefüllte Zustandstensor (en) in genau dem Format zurück, das benötigt wird, um das von Ihnen verwendete wiederkehrende Modul zu initialisieren.

+0

Große Antwort! Informativ und hilfreich. – AlVaz