2017-08-16 4 views
4

Ich habe 2 Möglichkeiten gefunden, ein Modell in Tensorflow zu speichern: tf.train.Saver() und SavedModelBuilder. Jedoch, Ich kann keine Dokumentation zur Verwendung des Modells finden, nachdem es auf die zweite Art geladen wurde.Wie wird ein gespeichertes Modell auf Tensorflow geladen und verwendet?

Hinweis: Ich möchte SavedModelBuilder Weg verwenden, weil ich das Modell in Python trainieren und es in einer anderen Sprache verwenden (Go), und es scheint, dass SavedModelBuilder der einzige Weg in diesem Fall ist.

Dies funktioniert gut mit tf.train.Saver() (ersten Weg):

model = tf.add(W * x, b, name="finalnode") 

# save 
saver = tf.train.Saver() 
saver.save(sess, "/tmp/model") 

# load 
saver.restore(sess, "/tmp/model") 

# IMPORTANT PART: REALLY USING THE MODEL AFTER LOADING IT 
# I CAN'T FIND AN EQUIVALENT OF THIS PART IN THE OTHER WAY. 

model = graph.get_tensor_by_name("finalnode:0") 
sess.run(model, {x: [5, 6, 7]}) 

tf.saved_model.builder.SavedModelBuilder() im Readme definiert, aber nach dem Modell mit tf.saved_model.loader.load(sess, [], export_dir) Laden) kann ich Dokumentation nicht an dem Knoten auf immer wieder finden (siehe "finalnode" im Code oben)

Antwort

4

Was fehlte der signature war

# Saving 
builder = tf.saved_model.builder.SavedModelBuilder(export_dir) 
builder.add_meta_graph_and_variables(sess, ["tag"], signature_def_map= { 
     "model": tf.saved_model.signature_def_utils.predict_signature_def(
      inputs= {"x": x}, 
      outputs= {"finalnode": model}) 
     }) 
builder.save() 

# loading 
with tf.Session(graph=tf.Graph()) as sess: 
    tf.saved_model.loader.load(sess, ["tag"], export_dir) 
    graph = tf.get_default_graph() 
    x = graph.get_tensor_by_name("x:0") 
    model = graph.get_tensor_by_name("finalnode:0") 
    print(sess.run(model, {x: [5, 6, 7, 8]})) 
0

Tensorflow ‚s bevorzugten Art des Gebäudes und ein Modell in verschiedenen Sprachen ist tensorflow serving

In diesem Fall verwenden Sie saver.save, um das Modell zu speichern. Auf diese Weise speichert es eine meta Datei, ckpt Datei und einige andere Dateien, um die Gewichtungen und Netzwerkinformationen zu speichern, Schritte trainiert usw. Dies ist die bevorzugte Art zu sparen, während Sie trainieren.

Wenn Sie jetzt mit dem Training fertig sind, sollten Sie das Diagramm mit SavedModelBuilder aus den Dateien, die Sie speichern, durch saver.save einfrieren. Dieser eingefrorene Graph enthält eine pb Datei und enthält das gesamte Netzwerk und die Gewichte.

Dieses gefrorene Modell sollte verwendet werden, um von tensorflow serving dienen und dann andere Sprachen können das Modell mit gRPC Protokoll verwenden.

Das gesamte Verfahren ist in this ausgezeichnete Tutorial beschrieben.

+0

Dank für die Antwort und den Link, aber das beantwortet nicht so sehr meine Frage ... – Thomas

+0

Der Link * nicht * hat die Antwort irgendwo nach „The lesen Schritt - Speichern Sie das Modell, aber das ist einfach zu finden, nur wenn Sie bereits wissen, wo Sie suchen müssen ... es könnte definitiv prägnanter sein, aber auch danke für den Link und die Einsichten –

Verwandte Themen