Ich habe ein CNN für die Bildklassifizierung gebaut. Während des Trainings habe ich mehrere Checkpoints gespeichert. Die Daten werden über ein feed_dictionary in das Netzwerk eingespeist.Tensorflow beschwert sich über fehlende Feed_dict während Graph Restore
Jetzt möchte ich das Modell wiederherstellen, das fehlschlägt und ich kann nicht herausfinden, warum. Die wichtigen Codezeilen sind wie folgt:
with tf.Graph().as_default():
....
if checkpoint_dir is not None:
checkpoint_saver = tf.train.Saver()
session_hooks.append(tf.train.CheckpointSaverHook(checkpoint_dir,
save_secs=flags.save_interval_secs,
saver=checkpoint_saver))
....
with tf.train.MonitoredTrainingSession(
save_summaries_steps=flags.save_summaries_steps,
hooks=session_hooks,
config=tf.ConfigProto(
log_device_placement=flags.log_device_placement)) as mon_sess:
checkpoint = tf.train.get_checkpoint_state(checkpoint_dir)
if checkpoint and checkpoint.model_checkpoint_path:
# restoring from the checkpoint file
checkpoint_saver.restore(mon_sess, checkpoint.model_checkpoint_path)
global_step_restore = checkpoint.model_checkpoint_path.split('/')[-1].split('-')[-1]
print("Model restored from checkpoint: global_step = %s" % global_step_restore)
The Line "checkpoint_saver.restore" wirft einen Fehler:
Traceback (jüngste Aufforderung zuletzt): Datei „C: \ Programme \ Anaconda3 \ envs \ tensorflow \ lib \ site-packages \ tensorfluss \ python \ client \ session.py ", Zeile 1022, in _do_call return fn (* args) Datei" C: \ Programme \ Anaconda3 \ envs \ tensorflow \ lib \ site-packages \ tensorflow \ python \ client \ session.py ", Zeile 1004, in _run_fn status, run_metadata) Datei" C: \ Programme \ Anaconda3 \ envs \ tensorflow \ lib \ contextlib.py ", Zeile 6 6, in Ausfahrt nächste (self.gen) Datei "C: \ Programme \ Anaconda3 \ envs \ Tensorfluss \ lib \ Site-Pakete \ Tensorfluss \ Python \ Framework \ Fehler_impl.py", Zeile 469, in Raise_exception_on_not_ok_status pywrap_tensorflow.TF_GetCode (Status)) tensorflow.python.framework.errors_impl.InvalidArgumentError: Sie müssen einen Wert für Platzhalter Tensor 'input_images' füttern mit dtype Schwimmer [[Knoten: input_images = Placeholderdtype = DT_FLOAT, Form = [], _device = "/ job: localhost/replik: 0/task: 0/cpu: 0"]]
Jeder weiß, wie man das löst? Warum brauche ich ein gefülltes feed_dictionary nur für die Wiederherstellung der Grafik?
Vielen Dank im Voraus!
Update:
Dies ist der Code des Wiederherstellungsmethode des Schoners Objekt:
def restore(self, sess, save_path):
"""Restores previously saved variables.
This method runs the ops added by the constructor for restoring variables.
It requires a session in which the graph was launched. The variables to
restore do not have to have been initialized, as restoring is itself a way
to initialize variables.
The `save_path` argument is typically a value previously returned from a
`save()` call, or a call to `latest_checkpoint()`.
Args:
sess: A `Session` to use to restore the parameters.
save_path: Path where parameters were previously saved.
"""
if self._is_empty:
return
sess.run(self.saver_def.restore_op_name,
{self.saver_def.filename_tensor_name: save_path})
Was ich nicht bekommen: Warum der Graph sofort ausgeführt wird? Benutze ich die falsche Methode? Ich möchte nur alle trainierbaren Vars wiederherstellen.
Benennen Sie alle Variablen und Platzhalter. Ist das hilfreich? http://stackoverflow.com/questions/34793978/tensorflow-complaining-about-placeholder-after-model-restore – hars
Alle Vars sind benannt. Der Input-Feed für meinen Bildtensor fehlt. Ich denke, dass das Problem durch die kombinierte Verwendung von MonitoredTrainingSession und einem feed_dict verursacht wird. MonitoredTrainingSession ist für größere Setups gedacht und möglicherweise nicht kompatibel mit Feed-Dictionarys?!?. Ich versuche, einen Testfall für mein benutzerdefiniertes "Training Framework" zu erstellen.Daher möchte ich das Beispielmodell leichtgewichtig halten (benutze ein feed_dict statt einer Import-Queue) – monchi