2016-04-29 13 views
3

Ich versuche, ein möglichst einfaches LSTM-Netzwerk aufzubauen. Ich will nur, dass es den nächsten Wert in der Sequenz np_input_data vorhersagt.tensorflow: shared Variablen Fehler mit einfachen LSTM-Netzwerk

import tensorflow as tf 
from tensorflow.python.ops import rnn_cell 
import numpy as np 

num_steps = 3 
num_units = 1 
np_input_data = [np.array([[1.],[2.]]), np.array([[2.],[3.]]), np.array([[3.],[4.]])] 

batch_size = 2 

graph = tf.Graph() 

with graph.as_default(): 
    tf_inputs = [tf.placeholder(tf.float32, [batch_size, 1]) for _ in range(num_steps)] 

    lstm = rnn_cell.BasicLSTMCell(num_units) 
    initial_state = state = tf.zeros([batch_size, lstm.state_size]) 
    loss = 0 

    for i in range(num_steps-1): 
     output, state = lstm(tf_inputs[i], state) 
     loss += tf.reduce_mean(tf.square(output - tf_inputs[i+1])) 

with tf.Session(graph=graph) as session: 
    tf.initialize_all_variables().run() 

    feed_dict={tf_inputs[i]: np_input_data[i] for i in range(len(np_input_data))} 

    loss = session.run(loss, feed_dict=feed_dict) 

    print(loss) 

Die Dolmetscher kehrt:

ValueError: Variable BasicLSTMCell/Linear/Matrix already exists, disallowed. Did you mean to set reuse=True in VarScope? Originally defined at: 
    output, state = lstm(tf_inputs[i], state) 

Was kann ich tun, falsch?

Antwort

5

Der Aufruf lstm hier:

for i in range(num_steps-1): 
    output, state = lstm(tf_inputs[i], state) 

werden versuchen, Variablen mit dem gleichen Namen jeder Iteration zu erstellen, wenn Sie es anders sagen. Sie können diese tf.variable_scope

with tf.variable_scope("myrnn") as scope: 
    for i in range(num_steps-1): 
    if i > 0: 
     scope.reuse_variables() 
    output, state = lstm(tf_inputs[i], state)  

Die erste Iteration erzeugt die Verwendung von Variablen tun, die Ihre LSTM Parameter und jede nachfolgende Iteration darstellen (nach dem Aufruf reuse_variables) wird sie nur namentlich im Rahmen nachschlagen.

1

Verwenden Sie tf.nn.rnn oder tf.nn.dynamic_rnn, die dies tun, und viele andere nette Dinge für Sie.

5

Ich stieß auf ein ähnliches Problem in TensorFlow v1.0.1 unter Verwendung tf.nn.dynamic_rnn. Es stellte sich heraus, dass der Fehler nur dann auftrat, wenn ich mitten im Training neu trainieren oder abbrechen musste und meinen Trainingsprozess neu starten musste. Grundsätzlich wurde der Graph nicht zurückgesetzt.

Lange Geschichte kurz, werfen Sie eine tf.reset_default_graph() am Anfang Ihres Codes und es sollte helfen. Zumindest bei Verwendung von tf.nn.dynamic_rnn und Umschulung.

Verwandte Themen