0

Ich habe ein Modell mit Tensorflow trainiert und während des Trainings die Batch-Normalisierung verwendet. Die Batch-Normalisierung erfordert, dass der Benutzer einen booleschen Wert is_training übergibt, um festzulegen, ob sich das Modell in der Trainings- oder Testphase befindet.Das trainierte Tensorflow-Modell wiederherstellen, den mit einem Knoten verknüpften Wert bearbeiten und speichern

Wenn das Modell trainiert wurde, is_training als Konstante gesetzt wurde, wie unten

is_training = tf.constant(True, dtype=tf.bool, name='is_training') 
ich das trainierte Modell gespeichert haben

gezeigt, enthalten die Dateien Kontrollpunkt, .meta Datei, .index-Datei und eine .data . Ich möchte das Modell wiederherstellen und daraus Inferenz ziehen. Das Modell kann nicht umgeschult werden. Daher möchte ich das vorhandene Modell wiederherstellen, den Wert is_training auf False setzen und dann das Modell speichern. Wie kann ich den booleschen Wert bearbeiten, der diesem Knoten zugeordnet ist, und das Modell erneut speichern?

+0

es einfacher gewesen wäre, wenn Sie 'is_training verwendet = tf.Variable..' eher als Konstante –

+0

Gibt es einen Grund, warum' is_training' eine tensorflow konstant sein muss? Kann es nicht ein Python Bool sein? Beachten Sie, dass das Ändern von "is_training" in ein Python-Bool beim Wiederherstellen des Modells keine Fehler verursachen sollte. – GeertH

+0

@GeertH Es kann sein, die Frage ist, wie setze ich "is_training" auf "False" nach dem Laden des Modells, dann speichern Sie es zurück. Wenn der Knoten also wiederhergestellt wird, hat er den Wert 'False'. – dpk

Antwort

1

Sie können das Argument input_map von tf.train.import_meta_graph verwenden, um den Graphtensor auf einen aktualisierten Wert neu zuzuordnen.

config = tf.ConfigProto(allow_soft_placement=True) 
with tf.Session(config=config) as sess: 
    # define the new is_training tensor 
    is_training = tf.constant(False, dtype=tf.bool, name='is_training') 

    # now import the graph using the .meta file of the checkpoint 
    saver = tf.train.import_meta_graph(
    '/path/to/model.meta', input_map={'is_training:0':is_training}) 

    # restore all weights using the model checkpoint 
    saver.restore(sess, '/path/to/model') 

    # save updated graph and variables values 
    saver.save(sess, '/path/to/new-model-name') 
+0

Der obige Code löst einen Fehler aus: "ValueError: tf.import_graph_def() benötigt einen nicht leeren Namen, wenn input_map verwendet wird" – dpk

+0

Ich habe diesen Code mit 'tensorflow == 1.2.0' getestet, hoffe es hilft; ALSO ist es nicht "tf.import_graph_def". Siehe meinen Code. –

+0

Ich habe Ihren Code so probiert, wie er ist, der Fehler wird von dieser Zeile ausgelöst, 'saver = tf.train.import_meta_graph (r'D: \ code \ iprings \ k-falten-modell \ VanillaCNN_24.0000.meta ', input_map = {'is_training': is_training}) ' – dpk

Verwandte Themen