2017-12-11 3 views
3

standardmäßig funktionieren dynamic_rnn Ausgänge nur versteckte Zustände (wie m bekannt) für jeden Zeitpunkt die wie folgt erhältlich:Tensorflow: Wie man intermediate Zellenzustände (c) von LSTMCell mit dynamic_rnn erhalten?

cell = tf.contrib.rnn.LSTMCell(100) 
rnn_outputs, _ = tf.nn.dynamic_rnn(cell, 
            inputs=inputs, 
            sequence_length=sequence_lengths, 
            dtype=tf.float32) 

Gibt es eine Möglichkeit Zwischen bekommen (nicht abschliessend) Zellzustände (c) in Zusatz?

A tensorflow Beitrag mentions, dass es mit einer Zelle Wrapper erfolgen:

class Wrapper(tf.nn.rnn_cell.RNNCell): 
    def __init__(self, inner_cell): 
    super(Wrapper, self).__init__() 
    self._inner_cell = inner_cell 
    @property 
    def state_size(self): 
    return self._inner_cell.state_size 
    @property 
    def output_size(self): 
    return (self._inner_cell.state_size, self._inner_cell.output_size) 
    def call(self, input, state) 
    output, next_state = self._inner_cell(input, state) 
    emit_output = (next_state, output) 
    return emit_output, next_state 

aber es scheint nicht zu funktionieren. Irgendwelche Ideen?

Antwort

2

Die vorgeschlagene Lösung funktioniert für mich, aber Layer.call Methode Spezifikation ist allgemeiner, so dass die folgenden Wrapper sollte robuster API Änderungen sein. dies Thy:

class Wrapper(tf.nn.rnn_cell.RNNCell): 
    def __init__(self, inner_cell): 
    super(Wrapper, self).__init__() 
    self._inner_cell = inner_cell 

    @property 
    def state_size(self): 
    return self._inner_cell.state_size 

    @property 
    def output_size(self): 
    return (self._inner_cell.state_size, self._inner_cell.output_size) 

    def call(self, input, *args, **kwargs): 
    output, next_state = self._inner_cell(input, *args, **kwargs) 
    emit_output = (next_state, output) 
    return emit_output, next_state 

Hier ist der Test:

n_steps = 2 
n_inputs = 3 
n_neurons = 5 

X = tf.placeholder(dtype=tf.float32, shape=[None, n_steps, n_inputs]) 
basic_cell = Wrapper(tf.nn.rnn_cell.LSTMCell(num_units=n_neurons, state_is_tuple=False)) 
outputs, states = tf.nn.dynamic_rnn(basic_cell, X, dtype=tf.float32) 
print(outputs, states) 

X_batch = np.array([ 
    # t = 0  t = 1 
    [[0, 1, 2], [9, 8, 7]], # instance 0 
    [[3, 4, 5], [0, 0, 0]], # instance 1 
    [[6, 7, 8], [6, 5, 4]], # instance 2 
    [[9, 0, 1], [3, 2, 1]], # instance 3 
]) 

with tf.Session() as sess: 
    sess.run(tf.global_variables_initializer()) 
    outputs_val = outputs[0].eval(feed_dict={X: X_batch}) 
    print(outputs_val) 

Zurück outputs ist das Tupel von (?, 2, 10) und (?, 2, 5) Tensoren, die alle LSTM Zustände und Ausgänge sind. Beachten Sie, dass ich die "abgestufte" Version von LSTMCell, von Paket, nicht tf.contrib.rnn verwende. Beachten Sie auch state_is_tuple=True, um den Umgang mit LSTMStateTuple zu vermeiden.

0

Basierend auf Maxims Idee, landete ich mit folgenden Lösung:

class StatefulLSTMCell(LSTMCell): 
    def __init__(self, *args, **kwargs): 
     super(StatefulLSTMCell, self).__init__(*args, **kwargs) 

    @property 
    def output_size(self): 
     return (self.state_size, super(StatefulLSTMCell, self).output_size) 

    def call(self, input, state): 
     output, next_state = super(StatefulLSTMCell, self).call(input, state) 
     emit_output = (next_state, output) 
     return emit_output, next_state 
Verwandte Themen