2017-06-05 8 views
1

Ich bin völlig auf der Tensorflow Saver-Methode verloren.Tensorflow saver.restore() nicht wiederherstellen Netzwerk

Ich versuche, das grundlegende Tensorflow Deep Neural Network Model Tutorial zu folgen. Ich möchte herausfinden, wie das Netzwerk für einige Iterationen trainiert wird, und dann das Modell in einer anderen Sitzung laden.

with tf.Session() as sess: 
    graph = tf.Graph() 
    x = tf.placeholder(tf.float32,shape=[None,784]) 
    y_ = tf.placeholder(tf.float32, shape=[None,10]) 

    sess.run(global_variables_initializer()) 

    #Define the Network 
    #(This part is all copied from the tutorial - not copied for brevity) 
    #See here: https://www.tensorflow.org/versions/r0.12/tutorials/mnist/pros/ 

Überspringen zum Training.

#Train the Network 
    train_step = tf.train.AdamOptimizer(1e-4).minimize(
        cross_entropy,global_step=global_step) 
    correct_prediction = tf.equal(tf.argmax(y_conv,1), tf.argmax(y_,1)) 
    accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32)) 

    saver = tf.train.Saver() 

    for i in range(101): 
     batch = mnist.train.next_batch(50) 
     if i%100 == 0: 
     train_accuracy = accuracy.eval(feed_dict= 
          {x:batch[0],y_:batch[1]}) 
     print 'Step %d, training accuracy %g'%(i,train_accuracy) 
      train_step.run(feed_dict={x:batch[0], y_: batch[1]}) 
     if i%100 == 0: 
      print 'Test accuracy %g'%accuracy.eval(feed_dict={x: 
         mnist.test.images, y_: mnist.test.labels}) 

     saver.save(sess,'./mnist_model') 

Die Konsole druckt:

Schritt 0, Trainingsgenauigkeit 0,16

Testgenauigkeit 0,0719

Schritt 100 Trainingsgenauigkeit 0,88

Testgenauigkeit 0,8734

Als nächstes möchte ich

das Modell laden
with tf.Session() as sess: 
    saver = tf.train.import_meta_graph('mnist_model.meta') 
    saver.restore(sess,tf.train.latest_checkpoint('./')) 
    sess.run(tf.global_variables_initializer()) 

Jetzt habe ich, um zu sehen, erneut testen wollen, ob das Modell

geladen
print 'Test accuracy %g'%accuracy.eval(feed_dict={x: 
         mnist.test.images, y_: mnist.test.labels}) 

der Konsole druckt:

Prüfgenauigkeit 0,1111

Es scheint nicht, dass das Modell irgendwelche der Daten speichert? Was mache ich falsch?

+0

Sie sollten nicht laufen 'sess.run (tf.global_variables_initializer())' nach dem Wiederherstellen von Gewichten. Dies wird alle Ihre Gewichte zurücksetzen – martianwars

Antwort

0

Wenn Sie Ihre Modelle speichern, werden normalerweise alle globalen Variablen in externen Dateien gespeichert, lokale Variablen dagegen nicht. Sie können sich diesen answer ansehen, um den Unterschied zu verstehen.

Der Fehler in Ihrem Wiederherstellungscode lautet tf.global_variable_initializer()nachsaver.restore(). Die saver.restore docs erwähnen,

Die Variablen wiederherstellen müssen nicht initialisiert wurden, als die Wiederherstellung selbst eine Art und Weise ist Variablen zu initialisieren.

daher versuchen, die Linie zu entfernen,

sess.run(tf.global_variables_initializer()) 

Sie sollten es im Idealfall ersetzen,

sess.run(tf.local_variables_initializer()) 
+1

Danke, das scheint sicherlich mein Problem gelöst zu haben! Wenn in den Dokumenten steht, dass 'saver.restore()' ein Initialisierungsprozess ist, wird 'sess.run (tf.local_variables_initializer()) 'irgendeinen Zweck erfüllen? Dies scheint auch darauf hinzuweisen, dass Tutorials wie [Eine schnelle vollständige Anleitung zum Speichern und Wiederherstellen Tensorflow-Modelle] (http://cv-tricks.com/tensorflow-tutorial/save-restore-tensorflow-models-quick-complete- Tutorial /) zeigt falsche Verwendung, nicht wahr? –

+0

Sie sollten ['tf.local_variables()'] (https://www.tensorflow.org/versions/r1.0/api_docs/python/tf/local_variables) überprüfen. Es ist erforderlich, wenn diese Liste nicht leer ist – martianwars

Verwandte Themen