2017-12-29 4 views
1

Mein Ziel ist einfach und klar: Nach der Graph teilweise modifiziert wird, wie die unveränderten Variablen/Parameter aus früherem Protokoll der Checkpoint-Datei (besser mit MonitoredTrainingSession)In Tensorflow, wenn Grafik geändert wird, wie "MonitoredTrainingSession" nur einen Teil des Prüfpunkts wiederherstellen?

ich auf dem Code, um einen Test von hier machen wiederzuherzustellen: https://github.com/tensorflow/models/tree/master/research/resnet

In resnet_model.py, Linie 116-118 der ursprüngliche Code (oder Grafik) ist:

with tf.variable_scope('logit'): 
    logits = self._fully_connected(x, self.hps.num_classes) 
    self.predictions = tf.nn.softmax(logits) 
with tf.variable_scope('costs'): 
    xent = tf.nn.softmax_cross_entropy_with_logits(
    logits=logits, labels=self.labels) 
    self.cost = tf.reduce_mean(xent, name='xent') 
    self.cost += self._decay() 

nach dem ersten Training, ich Prüfpunktdateien erhalten. Dann modifizierte ich den Code zu:

with tf.variable_scope('logit_modified'): 
    logits_modified = self._fully_connected('fc_1',x, 48) 
    #self.predictions = tf.nn.softmax(logits)  
with tf.variable_scope('logit_2'): 
    logits_2 = self._fully_connected('fc_2', logits_modified, 
    self.hps.num_classes) 
    self.predictions = tf.nn.softmax(logits_2) 
with tf.variable_scope('costs'): 
    xent = tf.nn.softmax_cross_entropy_with_logits(
    logits=logits_2, labels=self.labels) 
    self.cost = tf.reduce_mean(xent, name='xent') 
    self.cost += self._decay() 

Dann versuche ich, die latested API tf.train.MonitoredTrainingSession zu verwenden, um den Kontrollpunkt in der ersten Ausbildung erhalten wiederherzustellen. Ich habe mehrere Methoden ausprobiert, aber keiner von ihnen funktioniert.

Versuchen 1: Wenn ich Gerüst nicht in MonitoredTrainingSession verwenden:

with tf.train.MonitoredTrainingSession(
    checkpoint_dir=FLAGS.log_root, 
    #scaffold=scaffold, 
    hooks=[logging_hook, _LearningRateSetterHook()], 
    chief_only_hooks=[summary_hook], 
    save_checkpoint_secs = 600, 
    # Since we provide a SummarySaverHook, we need to disable default 
    # SummarySaverHook. To do that we set save_summaries_steps to 0. 
    save_summaries_steps=None, 
    save_summaries_secs=None, 
    config=tf.ConfigProto(allow_soft_placement=True), 
    stop_grace_period_secs=120, 
    log_step_count_steps=100) as mon_sess: 
while not mon_sess.should_stop(): 
    mon_sess.run(_train_op) 

Die Fehlermeldungen sind:

2017-12-29 10:33:30.699061: W tensorflow/core/framework/op_kernel.cc:1192] Not found: Key logit_modified/fc_1/biases/Momentum not found in checkpoint ...

Während Es scheint, dass die Sitzung nach dem modifizierten wiederherzustellen Trys Graphen, aber nicht die Variablen, die sowohl im neuen Diagramm als auch in der vorherigen Prüfpunktdatei existieren (dh alle Ebenen schließen die letzten 2 aus).

Versuchen 2: Inspiriert durch den Transfer Lerncode tf.train.Supervisor hier: https://github.com/kwotsin/transfer_learning_tutorial/blob/master/train_flowers.py, von der Linie 251.

Zuerst habe ich den Code in resnet_model.py geändert, fügen Sie diese Zeile:

self.variables_to_restore = tf.contrib.framework.get_variables_to_restore(
exclude=["logit_modified", "logit_2"]) 

Dann wird das Gerüst in MonitoredTrainingSession geändert:

saver = tf.train.Saver(variables_to_restore) 
def restore_fn(sess): 
    return saver.restore(sess, FLAGS.log_root) 
scaffold = tf.train.Scaffold(saver=saver, init_fn = restore_fn) 

Leider ist die followiing Fehlermeldung wurde angezeigt:

RuntimeError: Init operations did not make model ready for local_init. Init op: group_deps, init fn: at 0x7f0ec26f4320>, error: Variables not initialized: logit_modified/fc_1/DW, ...

scheint, wie die letzten 2 Schichten werden nicht richtig gestellt, so dass die übrigen Schichten werden nicht gestellt.

Versuchen Sie 3: Ich habe auch versucht, Methoden, die hier auflisten: How to use tf.train.MonitoredTrainingSession to restore only certain variables, aber keiner von ihnen funktioniert.

Ich weiß, es gibt andere Methoden zur Wiederherstellung wie der Code in https://github.com/tensorflow/models/blob/6fb14a790c283a922119b19632e3f7b8e5c0a729/research/inception/inception/inception_model.py, aber sie sind verschachtelt und nicht allgemein genug, um leicht auf andere Modelle angewendet werden. Aus diesem Grund möchte ich "MonitoredTrainingSession" verwenden.

So wie "MonitoredTrainingSession" verwenden, um nur einen Teil des Prüfpunkts in Tensorflow wiederherzustellen?

+0

traurig, dass die Fehlermeldung hat immer ein Einrückungsproblem auf StackOverflow. Die Grundidee der Fehlermeldung wurde stattdessen in der Frage erläutert. – GhostPotato

Antwort

0

OK, endlich finde ich es heraus.

Nach dem Lesen der überwachten_session.py hier: https://github.com/tensorflow/tensorflow/blob/4806cb0646bd21f713722bd97c0d0262c575f7e0/tensorflow/python/training/monitored_session.py, Ich habe festgestellt, der Schlüssel (und sehr knifflig) Punkt ist, zu einem neuen leeren Checkpoint-Verzeichnis zu ändern, so dass die MonitoredTrainingSession init_op oder init_fn nicht ignorieren. Dann können Sie den folgenden Code verwenden, um Ihre init_fn zu bauen (um Kontrollpunkt wiederherzustellen) als Gerüst auch:

variables_to_restore = tf.contrib.framework.get_variables_to_restore(
    exclude=['XXX'])  
init_assign_op, init_feed_dict = tf.contrib.framework.assign_from_checkpoint(
    ckpt.model_checkpoint_path, variables_to_restore) 
def InitAssignFn(scaffold,sess): 
    sess.run(init_assign_op, init_feed_dict) 

scaffold = tf.train.Scaffold(saver=tf.train.Saver(), init_fn=InitAssignFn) 

Erinnern Sie sich die ckpt.model_checkpoint_path oben ist Ihre alten Checkpoint Pfad mit vortrainierte Dateien darin. Der neue leere Kontrollpunkt Verzeichnis, die ich oben erwähnt habe, bedeutet den Parameter „checkpoint_dir“ von MonitoredTrainingSession hier:

with tf.train.MonitoredTrainingSession(
    checkpoint_dir=FLAGS.log_root_2,...) as mon_sess: 
while not mon_sess.should_stop(): 
    mon_sess.run(_train_op) 

Der erste Absatz der Code von mir geändert leitet sich von learning.py in tf.slim, von der Linie 134 : https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/slim/python/slim/learning.py

plus-: Dank dieser Q & A für Inspiration, obwohl die Lösung ein wenig anders: What's the recommend way of restoring only parts model in distributed tensorflow

Verwandte Themen