Ich habe auf die Wiederherstellung TF
Modelle und die Google
doc Seite auf exporting graphs ein paar posts gesehen, aber ich glaube, ich bin etwas fehlt.Wiederherstellen gespeichert TensorFlow Modell im Test zu bewerten gesetzt
Ich verwende den Code in diesem Gist das Modell speichern zusammen mit dieser Datei das Modell auf dem defines utils
Jetzt würde ich es wiederherstellen möchte und in eine bisher unerreichte Testdaten wie folgt durchgeführt:
def evaluate(X_data, y_data):
num_examples = len(X_data)
total_accuracy = 0
total_loss = 0
sess = tf.get_default_session()
acc_steps = len(X_data) // BATCH_SIZE
for i in range(acc_steps):
batch_x, batch_y = next_batch(X_val, Y_val, BATCH_SIZE)
loss, accuracy = sess.run([loss_value, acc], feed_dict={
images_placeholder: batch_x,
labels_placeholder: batch_y,
keep_prob: 0.5
})
total_accuracy += (accuracy * len(batch_x))
total_loss += (loss * len(batch_x))
return (total_accuracy/num_examples, total_loss/num_examples)
## re-execute the code that defines the model
# Image Tensor
images_placeholder = tf.placeholder(tf.float32, shape=[None, 32, 32, 3], name='x')
gray = tf.image.rgb_to_grayscale(images_placeholder, name='gray')
gray /= 255.
# Label Tensor
labels_placeholder = tf.placeholder(tf.float32, shape=(None, 43), name='y')
# dropout Tensor
keep_prob = tf.placeholder(tf.float32, name='drop')
# construct model
logits = inference(gray, keep_prob)
# calculate loss
loss_value = loss(logits, labels_placeholder)
# training
train_op = training(loss_value, 0.001)
# accuracy
acc = accuracy(logits, labels_placeholder)
with tf.Session() as sess:
loader = tf.train.import_meta_graph('gtsd.meta')
loader.restore(sess, tf.train.latest_checkpoint('./'))
sess.run(tf.initialize_all_variables())
test_accuracy = evaluate(X_test, y_test)
print("Test Accuracy = {:.3f}".format(test_accuracy[0]))
Ich bekomme eine Testgenauigkeit von nur 3%. Wenn ich jedoch das Notebook nicht schließe und den Testcode sofort nach dem Training des Modells abspiele, bekomme ich eine 95% Genauigkeit.
Das führt zu der Annahme, dass ich das Modell nicht korrekt lade?
Dank @mrry ich diesen jetzt –
versuchen wird, wie Sie erwartet, TF eine Fehlermeldung über einen nicht initialisierten Variable führt. Wenn ich diese Linie nach oben ziehe, wie Sie es vorgeschlagen haben, gibt es immer noch nur eine Genauigkeit von 2%, also beginnt sie mit dem Beginn. –
Oh, ich habe ein anderes Problem bemerkt! 'tf.train.import_meta_graph()' lädt eine ** zweite Kopie ** der Modellstruktur in das aktuelle Diagramm. Wenn der Code vor dem Erstellen der "tf.Session" eine Kopie des Diagramms erstellt (einschließlich aller Gewichtungen), bleiben * diese * Gewichtungen nicht initialisiert und nur die Gewichtungen in der zweiten Kopie werden wiederhergestellt. Es gibt zwei Möglichkeiten, dies zu umgehen: (1) Anstatt "tf.train.import_meta_graph()' 'zu verwenden, erstellen Sie direkt einen' tf.train.Saver' und verwenden Sie ihn, um den Checkpoint in die erste Kopie des Graphen wiederherzustellen; oder ... – mrry