2017-10-24 2 views
2

Ich habe die folgenden Dateien:Tensorflow - Speichern der Checkpoint-Dateien als .pb, aber ohne Ausgabeknotennamen

model.ckpt-2400.data-00000-of-00001 
model.ckpt-2400.index 
model.ckpt-2400.meta 

Und ich möchte sie in der Form eines .pb mit der folgenden Funktion speichern:

def freeze_graph(model_dir, output_node_names): 
    """Extract the sub graph defined by the output nodes and convert all its variables into constant 
    Args: 
    model_dir: the root folder containing the checkpoint state file 
    output_node_names: a string, containing all the output node's names, 
         comma separated 
    """ 
    if not tf.gfile.Exists(model_dir): 
     raise AssertionError(
      "Export directory doesn't exists. Please specify an export " 
      "directory: %s" % model_dir) 

    if not output_node_names: 
     print("You need to supply the name of a node to --output_node_names.") 
     return -1 

    # We retrieve our checkpoint fullpath 
    checkpoint = tf.train.get_checkpoint_state(model_dir) 
    input_checkpoint = checkpoint.model_checkpoint_path 

    # We precise the file fullname of our freezed graph 
    absolute_model_dir = "/".join(input_checkpoint.split('/')[:-1]) 
    output_graph = absolute_model_dir + "/frozen_model.pb" 

    # We clear devices to allow TensorFlow to control on which device it will load operations 
    clear_devices = True 

    # We start a session using a temporary fresh Graph 
    with tf.Session(graph=tf.Graph()) as sess: 
     # We import the meta graph in the current default Graph 
     saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=clear_devices) 

     # We restore the weights 
     saver.restore(sess, input_checkpoint) 

     # We use a built-in TF helper to export variables to constants 
     output_graph_def = tf.graph_util.convert_variables_to_constants(
      sess, # The session is used to retrieve the weights 
      tf.get_default_graph().as_graph_def(), # The graph_def is used to retrieve the nodes 
      output_node_names.split(",") # The output node names are used to select the usefull nodes 
     ) 

     # Finally we serialize and dump the output graph to the filesystem 
     with tf.gfile.GFile(output_graph, "wb") as f: 
      f.write(output_graph_def.SerializeToString()) 
     print("%d ops in the final graph." % len(output_graph_def.node)) 

    return output_graph_def 

Das Problem ist, dass, wenn ich tf.get_default_graph().as_graph_def().node verwenden, ist es [] zurückgibt. Ein leeres Array. Es gibt keine Ausgabeknotennamen, die ich dafür verwenden kann.

Also, wie sonst kann ich sie als .pb speichern? Soll ich mich nur auf die tf.python.tools.freeze_graph.freeze_graph() Funktion beziehen?

+0

Ihr Code sieht gut aus. Können Sie überprüfen, ob Ihre Prüfpunktdateien gut sind? Oder Sie können diese Dateien einfach irgendwo teilen, damit andere Leute einen Blick darauf werfen können. – Mingxing

+0

Ich habe gerade festgestellt, dass ich einen 'graph.pbtxt' in meinem Modellordner habe. Im Moment habe ich versucht, '.pbtxt' in' .pb' umzuwandeln, aber ich habe es noch nicht getestet. – Gensoukyou1337

Antwort

0

Es stellt sich heraus, dass ich nur den Namen des Ausgabeknotens angeben musste ... den ich in einem anderen Teil meines Codes als Knoten für die Protokollierung festlegen musste, um die Ergebnisse zu überprüfen.

predictions = { 
     # Generate predictions (for PREDICT and EVAL mode) 
     "classes": tf.argmax(input=logits, axis=1), 
     # Add `softmax_tensor` to the graph. It is used for PREDICT and by the 
     # `logging_hook`. 
     "probabilities": tf.nn.softmax(logits, name="softmax_tensor") #This one 
    } 

In meinem Fall ist es softmax_tensor.

Verwandte Themen