2016-01-26 6 views
16

Ich habe ein Setup, wo ich ein LSTM nach der Hauptinitialisierung initialisieren muss, die tf.initialize_all_variables() verwendet. I.e. Ich möchte tf.initialize_variables([var_list])Tensorflow: Wie bekomme ich alle Variablen von rnn_cell.BasicLSTM & rnn_cell.MultiRNNCell

nennen, ist es Weg, um alle internen trainierbar Variablen für beide zu sammeln:

  • rnn_cell.BasicLSTM
  • rnn_cell.MultiRNNCell

so dass ich GERADE initialisieren diese Parameter?

Der Hauptgrund, warum ich das möchte, ist, weil ich einige trainierte Werte von früher nicht neu initialisieren möchte.

Antwort

17

Der einfachste Weg zur Lösung Ihres Problems ist die Verwendung eines variablen Bereichs. Den Namen der Variablen in einem Bereich wird der Name vorangestellt. Hier ist ein kurzer Ausschnitt:

cell = rnn_cell.BasicLSTMCell(num_nodes) 

with tf.variable_scope("LSTM") as vs: 
    # Execute the LSTM cell here in any way, for example: 
    for i in range(num_steps): 
    output[i], state = cell(input_data[i], state) 

    # Retrieve just the LSTM variables. 
    lstm_variables = [v for v in tf.all_variables() 
        if v.name.startswith(vs.name)] 

# [..] 
# Initialize the LSTM variables. 
tf.initialize_variables(lstm_variables) 

Es wäre mit MultiRNNCell die gleiche Art und Weise arbeiten.

EDIT: geändert tf.trainable_variables zu tf.all_variables()

+0

Das ist perfekt, danke. Ich habe nicht gemerkt, dass 'tf.trainable_variables()' den Umfang respektiert, aber ich denke, im Nachhinein macht es Sinn! – bge0

+1

Möchte hinzufügen, dass 'tf.all_variables()' anstelle von 'tf.trainable_variables()' wäre eine bessere Wahl. Vor allem, weil es Dinge wie Optimierer gibt, die keine trainierbaren Variablen haben, die allerdings noch initialisiert werden müssten. – bge0

+1

Danke, Sie haben Recht. Ich habe den Code aktualisiert. –

11

Sie auch tf.get_collection() verwenden können:

cell = rnn_cell.BasicLSTMCell(num_nodes) 
with tf.variable_scope("LSTM") as vs: 
    # Execute the LSTM cell here in any way, for example: 
    for i in range(num_steps): 
    output[i], state = cell(input_data[i], state) 

    lstm_variables = tf.get_collection(tf.GraphKeys.VARIABLES, scope=vs.name) 

(zum Teil aus Rafal Antwort kopiert)

Beachten Sie, dass die letzte Zeile in die Liste Verständnis entspricht in Rafals Code.

Grundsätzlich speichert Tensorflow eine globale Sammlung von Variablen, die entweder durch tf.all_variables() oder tf.get_collection(tf.GraphKeys.VARIABLES) abgerufen werden können. Wenn Sie in der Funktion tf.get_collection()scope (Bereichsname) angeben, werden in der Auflistung nur Tensoren (Variablen in diesem Fall) abgerufen, deren Bereiche sich im angegebenen Bereich befinden.

EDIT: Sie können auch tf.GraphKeys.TRAINABLE_VARIABLES verwenden, um nur trainierbare Variablen zu erhalten. Aber da vanilla BasicLSTMCell keine nicht trainierbare Variable initialisiert, werden beide funktional äquivalent sein. Für eine vollständige Liste der Standard-Graph-Sammlungen, überprüfen Sie this heraus.

+0

konsistent zu sein Dies ist der bessere Weg als Rafals Lösung :-) –

+1

Genau wie ich oben erwähnt, sollten Sie vielleicht besser verwenden ' tf.get_collection (..., scope = vs.name + "/") 'weil es einen anderen Bereich namens" LSTM2 "geben könnte. – Albert

Verwandte Themen