Sie auch tf.get_collection()
verwenden können:
cell = rnn_cell.BasicLSTMCell(num_nodes)
with tf.variable_scope("LSTM") as vs:
# Execute the LSTM cell here in any way, for example:
for i in range(num_steps):
output[i], state = cell(input_data[i], state)
lstm_variables = tf.get_collection(tf.GraphKeys.VARIABLES, scope=vs.name)
(zum Teil aus Rafal Antwort kopiert)
Beachten Sie, dass die letzte Zeile in die Liste Verständnis entspricht in Rafals Code.
Grundsätzlich speichert Tensorflow eine globale Sammlung von Variablen, die entweder durch tf.all_variables()
oder tf.get_collection(tf.GraphKeys.VARIABLES)
abgerufen werden können. Wenn Sie in der Funktion tf.get_collection()
scope
(Bereichsname) angeben, werden in der Auflistung nur Tensoren (Variablen in diesem Fall) abgerufen, deren Bereiche sich im angegebenen Bereich befinden.
EDIT: Sie können auch tf.GraphKeys.TRAINABLE_VARIABLES
verwenden, um nur trainierbare Variablen zu erhalten. Aber da vanilla BasicLSTMCell keine nicht trainierbare Variable initialisiert, werden beide funktional äquivalent sein. Für eine vollständige Liste der Standard-Graph-Sammlungen, überprüfen Sie this heraus.
Das ist perfekt, danke. Ich habe nicht gemerkt, dass 'tf.trainable_variables()' den Umfang respektiert, aber ich denke, im Nachhinein macht es Sinn! – bge0
Möchte hinzufügen, dass 'tf.all_variables()' anstelle von 'tf.trainable_variables()' wäre eine bessere Wahl. Vor allem, weil es Dinge wie Optimierer gibt, die keine trainierbaren Variablen haben, die allerdings noch initialisiert werden müssten. – bge0
Danke, Sie haben Recht. Ich habe den Code aktualisiert. –