2017-03-01 5 views
3

In meinem aktuellen Projekt trainiere ich ein Modell und speichere Prüfpunkte alle 100 Iterationsschritte. Die Prüfpunktdateien werden alle im selben Verzeichnis gespeichert (model.ckpt-100, model.ckpt-200, model.ckpt-300 usw.). Und danach möchte ich das Modell basierend auf Validierungsdaten für alle gespeicherten Checkpoints evaluieren, nicht nur das neueste.Tensorflow: Modellbewertung über mehrere Prüfpunkte ausführen

Derzeit Code mein Stück die Prüfpunktdatei für die Wiederherstellung sieht wie folgt aus:

ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir) 
ckpt_list = saver.last_checkpoints 
print(ckpt_list) 
if ckpt and ckpt.model_checkpoint_path: 
    print("Reading model parameters from %s" % ckpt.model_checkpoint_path) 
    saver.restore(sess, ckpt.model_checkpoint_path) 
    # extract global_step from it. 
    global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1] 
    print('Succesfully loaded model from %s at step=%s.' % 
      (ckpt.model_checkpoint_path, global_step)) 
else: 
    print('No checkpoint file found') 
    return 

Dies ist jedoch nur die zuletzt gespeicherte Datei Prüfpunkt wieder her. Wie schreibe ich eine Schleife über alle gespeicherten Prüfpunktdateien? Ich habe versucht, eine Liste der Checkpoint-Dateien mit saver.last_checkpoints zu erhalten, aber die zurückgegebene Liste ist leer.

Jede Hilfe wäre sehr geschätzt, danke im Voraus!

+0

Wie speichern Sie das Modell genau? Bilden Sie den Namen für die Ausgabedatei selbst oder verwenden Sie den Parameter 'global_step' beim Aufruf von' saver.save (..) '? – kaufmanu

Antwort

1

Sie können in dem Verzeichnis durch die Dateien iterieren:

import os 

dir_path = './' #change that to wherever your files are 
ckpt_files = [f for f in os.listdir(dir_path) if os.path.isfile(
    os.path.join(dir_path, f)) and 'ckpt' in f] 

for ckpt_file in ckpt_files: 
    saver.restore(sess, dir_path + ckpt_file) 
    global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1] 
    print('Succesfully loaded model from %s at step=%s.' % 
      (ckpt.model_checkpoint_path, global_step)) 

    # Do your thing 

weitere Bedingungen in dem obigen Liste Verständnis hinzuzufügen selektiver zu sein wie: and 'meta' not in f und so weiter, je nachdem, was in diesem Verzeichnis ist und die Sparer-Version haben

0

Danke dafür. Allerdings erhalte ich die Fehler

"NotFoundError (siehe oben für Traceback): Key CONV2/Vorurteile/ExponentialMovingAverage nicht in Kontrollpunkt gefunden"

wo CONV2/spannt einen variablen Umfang ist. Ich benutze die Sparversion v2.

Inzwischen habe ich versucht, eine andere (etwas einfacher Code) und bekam den gleichen Fehler:

fileBaseName = FLAGS.checkpoint_dir + '/model.ckpt-' 

    for global_step in range(0,100,10): # range over the global steps where checkpoints were saved 
    x_str = str(global_step) 
    fileName = fileBaseName+x_str 
    print(fileName) 
    ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir) 

    #restore checkpoint file 
    saver.restore(sess, fileName) 

Der Fehler der Code in diesem Stück tatsächlich auftritt (bei variables_to_restore =):

# Restore the moving average version of the learned variables for eval. 
variable_averages = tf.train.ExponentialMovingAverage(
    MOVING_AVERAGE_DECAY) 
variables_to_restore = variable_averages.variables_to_restore() 
saver = tf.train.Saver(variables_to_restore) 

Ich habe keine Ahnung, wie ich diesen Fehler beheben kann. Könnte es etwas mit der Sparversion zu tun haben? Oder muss der Fehler in dem Teil sein, in dem die Checkpoints gespeichert sind?

Vielen Dank. TheJude

Verwandte Themen