2017-04-07 5 views
1

Ich bin Anfänger mit maschinellem Lernen und Tensorflow, mit seinem Beispiel Tutorial Quellcode, das Modell geschult und gedruckt die Genauigkeit, aber es enthält keinen Quellcode, um das Modell zu exportieren und Variablen und Import für ein neues Bild vorhersagen.Tensorflow Bericht Fehler mit dem wiederhergestellten Trainingsmodell

Also habe ich den Quellcode überarbeitet, um das Modell zu exportieren, und ein neues Python-Skript erstellen, um mithilfe des Testdatensatzes vorherzusagen.

Hier ist der Quellcode des Trainingsmodell zu exportieren: Die y-Funktion

sess = tf.Session() 
saver = tf.train.import_meta_graph('result.meta') 
saver.restore(sess, tf.train.latest_checkpoint('./')) 
W = tf.get_collection('W')[0] 
b = tf.get_collection('b')[0] 
y = tf.get_collection('y')[0] 


mnist = input_data.read_data_sets('/tmp/tensorflow/mnist/input_data', one_hot=True) 
img = mnist.test.images[0] 
x = tf.placeholder(tf.float32, [None, 784]) 
sess.run(y, feed_dict={x: mnist.test.images}) 

Alles funktioniert

mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True) 
print("run here3") 
# Create the model 
x = tf.placeholder(tf.float32, [None, 784], name="x") 
W = tf.Variable(tf.zeros([784, 10]), name="W") 
b = tf.Variable(tf.zeros([10])) 
y = tf.matmul(x, W) + b 
saver = tf.train.Saver() 
sess = tf.InteractiveSession() 
... ignore the source code for the cost function definition and train the model 
#after the model get trained, save the variables and y 
tf.add_to_collection('W', W) 
tf.add_to_collection('b', b) 
tf.add_to_collection('y', y) 

saver.save(sess, 'result') 

Im neuen Python-Skript, ich versuche das Modell wiederherzustellen und erneut auszuführen korrekt, ich könnte den W- und b-Wert bekommen, wenn ich sie drucke, jedoch bekomme ich Fehler beim Ausführen der letzten Anweisung (run y function).

Caused by op u'x', defined at: 
File "predict.py", line 58, in <module> 
tf.app.run(main=main, argv=[sys.argv[0]] + unparsed) 
File "/Users/zhouqi/git/machine-learning/tensorflow/lib/python2.7/site-packages/tensorflow/python/platform/app.py", line 44, in run 
_sys.exit(main(_sys.argv[:1] + flags_passthrough)) 
File "predict.py", line 25, in main 
saver = tf.train.import_meta_graph('result.meta') 
File "/Users/zhouqi/git/machine-learning/tensorflow/lib/python2.7/site- packages/tensorflow/python/training/saver.py", line 1566, in import_meta_graph 
**kwargs) 
File "/Users/zhouqi/git/machine-learning/tensorflow/lib/python2.7/site-packages/tensorflow/python/framework/meta_graph.py", line 498, in import_scoped_meta_graph 
producer_op_list=producer_op_list) 
File "/Users/zhouqi/git/machine-learning/tensorflow/lib/python2.7/site-packages/tensorflow/python/framework/importer.py", line 288, in import_graph_def 
op_def=op_def) 
File "/Users/zhouqi/git/machine-learning/tensorflow/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 2327, in create_op 
original_op=self._default_original_op, op_def=op_def) 
File "/Users/zhouqi/git/machine-learning/tensorflow/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 1226, in __init__ 
self._traceback = _extract_stack() 

InvalidArgumentError (see above for traceback): You must feed a value for placeholder tensor 'x' with dtype float 
[[Node: x = Placeholder[dtype=DT_FLOAT, shape=[], _device="/job:localhost/replica:0/task:0/cpu:0"]()]] 

Es ist komisch, weil ich die gleiche Aussage verwenden x zu definieren, und führen Sie den gleichen Ansatz x Verwendung während y-Funktion ausgeführt wird, ich weiß nicht, warum es nicht funktioniert?

+0

Warum haben Sie diese Zeile: 'mnist = input_data.read_data_sets (FLAGS.data_dir, one_hot = True)' 'mit mnist = input_data.read_data_sets ('/ tmp/tensorflow/mnist/INPUT_DATA', one_hot = True)' ? – tagoma

+0

Oh, für das neue Skript, das ich erstellt habe, um das Modell wiederherzustellen, vereinfache ich es, den hartcodierten Datenordner zu verwenden, FLAGS.data_dir ist dasselbe wie/tmp/tensorflow/mnist/input_data. – mailme365

Antwort

1

Das Problem ist der neue Platzhalter:

x = tf.placeholder(tf.float32, [None, 784]) 

einen Platzhalter mit dem gleichen Namen zu schaffen, ist nicht genug. Sie benötigen den gleichen Platzhalter, den Sie beim Erstellen des Modells verwendet haben.

tf.add_to_collection('x', x) 

und es in der neuen Datei laden: Sie müssen daher auch x zu einer Sammelstelle in der ersten Datei hinzufügen

x = tf.get_collection('x')[0] 

anstatt einen neuen zu erstellen.

Verwandte Themen