2017-08-02 5 views
5

Ich habe das Einführungsmodell mit einem neuen Datensatz verfeinert und als ".h5" -Modell in Keras gespeichert. Jetzt ist mein Ziel, mein Modell auf Android Tensorflow, die nur ". Pb" Erweiterung akzeptiert. Frage ist das, gibt es irgendeine Bibliothek in Keras oder Tensorflow, um diese Umwandlung zu machen? Ich habe diesen Beitrag bisher gesehen: https://blog.keras.io/keras-as-a-simplified-interface-to-tensorflow-tutorial.html aber kann noch nicht herausfinden.Wie Keras .h5 zu Tensorflow. Pb exportieren?

Antwort

9

Keras enthält keine Mittel, um ein TensorFlow-Diagramm als Protokollpufferdatei zu exportieren, aber Sie können dies mit normalen TensorFlow-Dienstprogrammen tun. Here ist ein Blogbeitrag, der erklärt, wie man es unter Verwendung des in TensorFlow enthaltenen Hilfsskriptes freeze_graph.py macht, was die "typische" Art ist, wie es gemacht wird.

Allerdings finde ich persönlich ein Ärgernis einen Kontrollpunkt zu müssen und dann ein externes Skript ausführen, ein Modell zu erhalten, und stattdessen lieber von meinem eigenen Python-Code zu tun, so dass ich eine Funktion wie folgt aus:

def freeze_session(session, keep_var_names=None, output_names=None, clear_devices=True): 
    """ 
    Freezes the state of a session into a pruned computation graph. 

    Creates a new computation graph where variable nodes are replaced by 
    constants taking their current value in the session. The new graph will be 
    pruned so subgraphs that are not necessary to compute the requested 
    outputs are removed. 
    @param session The TensorFlow session to be frozen. 
    @param keep_var_names A list of variable names that should not be frozen, 
          or None to freeze all the variables in the graph. 
    @param output_names Names of the relevant graph outputs. 
    @param clear_devices Remove the device directives from the graph for better portability. 
    @return The frozen graph definition. 
    """ 
    from tensorflow.python.framework.graph_util import convert_variables_to_constants 
    graph = session.graph 
    with graph.as_default(): 
     freeze_var_names = list(set(v.op.name for v in tf.global_variables()).difference(keep_var_names or [])) 
     output_names = output_names or [] 
     output_names += [v.op.name for v in tf.global_variables()] 
     input_graph_def = graph.as_graph_def() 
     if clear_devices: 
      for node in input_graph_def.node: 
       node.device = "" 
     frozen_graph = convert_variables_to_constants(session, input_graph_def, 
                 output_names, freeze_var_names) 
     return frozen_graph 

Das ist in der Implementierung von freeze_graph.py inspiriert. Die Parameter sind dem Skript ähnlich. session ist das TensorFlow-Sitzungsobjekt. keep_var_names wird nur benötigt, wenn Sie einige Variable nicht eingefroren (z. B. für Stateful-Modelle) behalten möchten, also in der Regel nicht. output_names ist eine Liste mit den Namen der Operationen, die die gewünschten Ausgaben erzeugen. clear_devices entfernt nur alle Geräterichtlinien, um den Graphen portabler zu machen. Also, für eine typische Keras model mit einem Ausgang, würden Sie so etwas wie:

from keras import backend as K 

# Create, compile and train model... 

frozen_graph = freeze_session(K.get_session(), output_names=[model.output.op.name]) 

Dann Sie die Grafik in eine Datei wie gewohnt mit tf.train.write_graph schreiben:

tf.train.write_graph(frozen_graph, "some_directory", "my_model.pb", as_text=False) 
1

Die freeze_session Methode funktioniert gut . Aber verglichen mit dem Speichern in einer Prüfpunktdatei erscheint mir die Verwendung des freeze_graph-Tools, das mit TensorFlow geliefert wird, einfacher, da es einfacher zu warten ist. Alles, was Sie sind die folgenden zwei Schritte tun müssen:

Fügen Sie zunächst nach Ihrem Keras Code model.fit(...) und trainieren Sie Ihr Modell:

from keras import backend as K 
import tensorflow as tf 
print(model.output.op.name) 
saver = tf.train.Saver() 
saver.save(K.get_session(), '/tmp/keras_model.ckpt') 

Dann cd zu Ihrem TensorFlow Root-Verzeichnis zu starten, führen

python tensorflow/python/tools/freeze_graph.py \ 
--input_meta_graph=/tmp/keras_model.ckpt.meta \ 
--input_checkpoint=/tmp/keras_model.ckpt \ 
--output_graph=/tmp/keras_frozen.pb \ 
--output_node_names="<output_node_name_printed_in_step_1>" \ 
--input_binary=true 
Verwandte Themen