2016-04-12 3 views
1

Ich möchte den Wert des Verlusts drucken, der von einem Optimierer minimiert wurde. Hier ein Beispiel:Den Wert des Verlusts drucken, der von einem Optimierer in TensorFlow minimiert wird

LEARNING_RATE = 0.0001 
MOMENTUM = 0.999 

mean_squared_error = tf.reduce_mean(tf.square(tf.sub(predictions, training_outputs))) 
train_step = tf.train.MomentumOptimizer(LEARNING_RATE, MOMENTUM).minimize(mean_squared_error) 

# Load data 
features = ... 
labels = ... 

# Launch TensorFlow session 
with tf.Session() as session: 
    session.run(initialize) 

    print("Begin training...") 
    session.run(train_step, feed_dict={training_inputs: features, training_outputs: labels}) 
    print("Finished training! The mean squared error is: _____") 

Nun, da ich mean_squared_error minimiert haben, wie kann ich seine minimierten Wert drucken?

Antwort

2

Der einfachste Weg, um einen Verlust zu visualisieren ist eine skalare Zusammenfassung davon zu erstellen:

mean_squared_error = tf.reduce_mean(tf.square(tf.sub(predictions, training_outputs))) 
loss_summ = tf.scalar_summary("loss", mean_squared_error) 

Sie dann erstellen Sie ein Schriftsteller in der TensorFlow Sitzung und die Zusammenfassung loss_summ zum sess.run() Anruf hinzuzufügen. Sie erhalten dann den Wert zurück in mse_val und können es ausdrucken.

with tf.Session() as session: 
    writer = tf.train.SummaryWriter("log", session.graph_def) 
    session.run(initialize) 

    print("Begin training...") 
    _, mse_val, summ = session.run([train_step, mean_squared_error, loss_summ], feed_dict={training_inputs: features, training_outputs: labels}) 
    writer.add_summary(summ) 
    print("Finished training! The mean squared error is: %f" % mse_val) 

Als Bonus können Sie auch die Entwicklung des Verlustes in TensorBoard, visualisieren tensorboard --logdir log durch Ausführen (lesen this tutorial für weitere Details).

P.S: Ihr Code läuft nur 1 Iteration des Trainings, möchten Sie vielleicht eine Schleife hinzufügen.

+0

Ich habe Ihren Kommentar aktualisiert, und ich werde ihn akzeptieren, sobald ich Ihre Vorschläge ausprobiere und finde, dass sie meine Bedürfnisse erfüllen. Vielen Dank für Ihre Hilfe! Ich schätze es wirklich. –

+0

Ich habe Ihren Code ausgeführt, und ich erhielt einen TypeError: 'Traceback (letzter Aufruf zuletzt): Datei" /home/me/PycharmProjects/Test/Prediction.py ", Zeile 77, in drucken (" Fertig Training! Der mittlere quadratische Fehler ist:% f "% mean_squared_error_value) TypeError: Float-Argument erforderlich, nicht str'. Ich habe "% f" durch "% s" ersetzt, aber stattdessen wurde Kauderwelsch gedruckt ... –

+0

Entschuldigung, Sie sollten 'mean_squared_error' in' sess.run() 'hinzufügen, um seinen Wert zu erhalten. Ich habe meine erste Antwort mit den Änderungen bearbeitet, danke! –

Verwandte Themen