2016-12-20 5 views
1

Ich versuche, einige Variablen zu speichern und sehen, ob ich es später wiederherstellen könnte. Hier ist mein Sparcode:Tensorflow speichert nur den initialisierten Wert

import tensorflow as tf; 
    my_a = tf.Variable(2,name = "my_a"); 
    my_b = tf.Variable(3,name = "my_b"); 
    my_c = tf.Variable(4,name = "my_c"); 
    my_c = tf.add(my_a,my_b); 

    with tf.Session() as sess: 
     init = tf.initialize_all_variables(); 
     sess.run(init); 
     print("my_c = ",sess.run(my_c)); 
     saver = tf.train.Saver(); 
     saver.save(sess,"test.ckpt"); 

Dies gibt:

my_c = 5 

Wie ich es wiederherstellen:

import tensorflow as tf; 
    c = tf.Variable(3100,dtype = tf.int32); 
    with tf.Session() as sess: 
     sess.run(tf.initialize_all_variables()); 
     saver = tf.train.Saver({"my_c":c}); 
     saver.restore(sess, "test.ckpt"); 
     cc= sess.run(c); 
     print(cc); 

Das gibt mir:

4 

Die restaurierte Wert von my_c sollte sei 5, da es die Summe von my_a und my_b ist. Jedoch gibt es mir 4, was der initialisierte Wert von my_c ist. Könnte jemand erklären, warum das passiert und wie man die Änderungen in einer Variablen speichert?

Antwort

2

In Ihrem ursprünglichen Code haben Sie die Variable my_c (beachten Sie, TensorFlow name) nicht wirklich zu my_a + my_b zugewiesen.

Durch das Schreiben my_c = tf.add(my_a,my_b) unterscheidet sich die Python-Variable my_c jetzt von der mit name='my_c'.

Wenn Sie sess.run() ausführen, führen Sie nur die Operation aus und aktualisieren diese Variable nicht.

Wenn Sie diesen Code wollen korrekt funktionieren, verwenden Sie diese stattdessen - (siehe die Kommentare für Änderungen)

import tensorflow as tf 
my_a = tf.Variable(2,name = "my_a") 
my_b = tf.Variable(3,name = "my_b") 
my_c = tf.Variable(4,name="my_c") 
# Use the assign() function to set the new value 
add = my_c.assign(tf.add(my_a,my_b)) 

with tf.Session() as sess: 
    init = tf.initialize_all_variables() 
    sess.run(init) 
    # Execute the add operator 
    sess.run(add) 
    print("my_c = ",sess.run(my_c)) 
    saver = tf.train.Saver() 
    saver.save(sess,"test.ckpt") 
+0

Danke für die Hilfe. – lina

+0

Mit Ihrem Code meldet er den Fehler beim Laden des Modells: NotFoundError (siehe oben für Traceback): Tensor Name "my_c" nicht in Checkpoint-Dateien gefunden test.ckpt – lina

+0

Entschuldigung, überprüfen Sie die neue Antwort. Meine alte Antwort bestand nicht darin, eine Variable zu erstellen. Nur Variablen werden von TensorFlow gespeichert, nicht Tensoren – martianwars

Verwandte Themen