Ich bin völlig auf der Tensorflow Saver-Methode verloren.Tensorflow saver.restore() nicht wiederherstellen Netzwerk
Ich versuche, das grundlegende Tensorflow Deep Neural Network Model Tutorial zu folgen. Ich möchte herausfinden, wie das Netzwerk für einige Iterationen trainiert wird, und dann das Modell in einer anderen Sitzung laden.
with tf.Session() as sess:
graph = tf.Graph()
x = tf.placeholder(tf.float32,shape=[None,784])
y_ = tf.placeholder(tf.float32, shape=[None,10])
sess.run(global_variables_initializer())
#Define the Network
#(This part is all copied from the tutorial - not copied for brevity)
#See here: https://www.tensorflow.org/versions/r0.12/tutorials/mnist/pros/
Überspringen zum Training.
#Train the Network
train_step = tf.train.AdamOptimizer(1e-4).minimize(
cross_entropy,global_step=global_step)
correct_prediction = tf.equal(tf.argmax(y_conv,1), tf.argmax(y_,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))
saver = tf.train.Saver()
for i in range(101):
batch = mnist.train.next_batch(50)
if i%100 == 0:
train_accuracy = accuracy.eval(feed_dict=
{x:batch[0],y_:batch[1]})
print 'Step %d, training accuracy %g'%(i,train_accuracy)
train_step.run(feed_dict={x:batch[0], y_: batch[1]})
if i%100 == 0:
print 'Test accuracy %g'%accuracy.eval(feed_dict={x:
mnist.test.images, y_: mnist.test.labels})
saver.save(sess,'./mnist_model')
Die Konsole druckt:
Schritt 0, Trainingsgenauigkeit 0,16
Testgenauigkeit 0,0719
Schritt 100 Trainingsgenauigkeit 0,88
Testgenauigkeit 0,8734
Als nächstes möchte ich
das Modell ladenwith tf.Session() as sess:
saver = tf.train.import_meta_graph('mnist_model.meta')
saver.restore(sess,tf.train.latest_checkpoint('./'))
sess.run(tf.global_variables_initializer())
Jetzt habe ich, um zu sehen, erneut testen wollen, ob das Modell
geladenprint 'Test accuracy %g'%accuracy.eval(feed_dict={x:
mnist.test.images, y_: mnist.test.labels})
der Konsole druckt:
Prüfgenauigkeit 0,1111
Es scheint nicht, dass das Modell irgendwelche der Daten speichert? Was mache ich falsch?
Sie sollten nicht laufen 'sess.run (tf.global_variables_initializer())' nach dem Wiederherstellen von Gewichten. Dies wird alle Ihre Gewichte zurücksetzen – martianwars