Ich habe die folgenden Dateien:Tensorflow - Speichern der Checkpoint-Dateien als .pb, aber ohne Ausgabeknotennamen
model.ckpt-2400.data-00000-of-00001
model.ckpt-2400.index
model.ckpt-2400.meta
Und ich möchte sie in der Form eines .pb
mit der folgenden Funktion speichern:
def freeze_graph(model_dir, output_node_names):
"""Extract the sub graph defined by the output nodes and convert all its variables into constant
Args:
model_dir: the root folder containing the checkpoint state file
output_node_names: a string, containing all the output node's names,
comma separated
"""
if not tf.gfile.Exists(model_dir):
raise AssertionError(
"Export directory doesn't exists. Please specify an export "
"directory: %s" % model_dir)
if not output_node_names:
print("You need to supply the name of a node to --output_node_names.")
return -1
# We retrieve our checkpoint fullpath
checkpoint = tf.train.get_checkpoint_state(model_dir)
input_checkpoint = checkpoint.model_checkpoint_path
# We precise the file fullname of our freezed graph
absolute_model_dir = "/".join(input_checkpoint.split('/')[:-1])
output_graph = absolute_model_dir + "/frozen_model.pb"
# We clear devices to allow TensorFlow to control on which device it will load operations
clear_devices = True
# We start a session using a temporary fresh Graph
with tf.Session(graph=tf.Graph()) as sess:
# We import the meta graph in the current default Graph
saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=clear_devices)
# We restore the weights
saver.restore(sess, input_checkpoint)
# We use a built-in TF helper to export variables to constants
output_graph_def = tf.graph_util.convert_variables_to_constants(
sess, # The session is used to retrieve the weights
tf.get_default_graph().as_graph_def(), # The graph_def is used to retrieve the nodes
output_node_names.split(",") # The output node names are used to select the usefull nodes
)
# Finally we serialize and dump the output graph to the filesystem
with tf.gfile.GFile(output_graph, "wb") as f:
f.write(output_graph_def.SerializeToString())
print("%d ops in the final graph." % len(output_graph_def.node))
return output_graph_def
Das Problem ist, dass, wenn ich tf.get_default_graph().as_graph_def().node
verwenden, ist es []
zurückgibt. Ein leeres Array. Es gibt keine Ausgabeknotennamen, die ich dafür verwenden kann.
Also, wie sonst kann ich sie als .pb speichern? Soll ich mich nur auf die tf.python.tools.freeze_graph.freeze_graph()
Funktion beziehen?
Ihr Code sieht gut aus. Können Sie überprüfen, ob Ihre Prüfpunktdateien gut sind? Oder Sie können diese Dateien einfach irgendwo teilen, damit andere Leute einen Blick darauf werfen können. – Mingxing
Ich habe gerade festgestellt, dass ich einen 'graph.pbtxt' in meinem Modellordner habe. Im Moment habe ich versucht, '.pbtxt' in' .pb' umzuwandeln, aber ich habe es noch nicht getestet. – Gensoukyou1337