2016-07-06 3 views
11

Ich möchte die Variablen sehen, die in einem Tensorflow-Checkpoint zusammen mit ihren Werten gespeichert sind. Wie kann ich die Variablennamen finden, die in einem Tensorflow-Checkpoint gespeichert sind?Wie finden Sie die Variablennamen, die in einem Tensorflow-Checkpoint gespeichert sind?

EDIT:

I verwendet tf.train.NewCheckpointReader die here erklärt. Aber es ist nicht in der Dokumentation von Tensorflow angegeben. Gibt es einen anderen Weg?

`

import tensorflow as tf 
    v0 = tf.Variable([[1, 2, 3], [4, 5, 6]], dtype=tf.float32, name="v0") 
    v1 = tf.Variable([[[1], [2]], [[3], [4]], [[5], [6]]], dtype=tf.float32, 
        name="v1") 
    init_all_op = tf.initialize_all_variables() 
    save = tf.train.Saver({"v0": v0, "v1": v1}) 
    checkpoint_path = os.path.join(model_dir, "model.ckpt")  

    with tf.Session() as sess: 
     sess.run(init_all_op) 
     # Saves a checkpoint.  
     save.save(sess, checkpoint_path) 

     # Creates a reader. 
     reader = tf.train.NewCheckpointReader(checkpoint_path) 
     print('reder:\n', reader) 

     # Verifies that the tensors exist. 
     print('is exist v0?', reader.has_tensor("v0")) 
     print('is exist v1?', reader.has_tensor("v1")) 

     # Verifies that debug string contains the right strings. 
     debug_string = reader.debug_string() 
     print('\n All Variables: \n', debug_string) 

     # Verifies get_variable_to_shape_map() returns the correct information. 
     var_map = reader.get_variable_to_shape_map() 
     print('\n All Variables information :\n', var_map) 

     # Verifies get_tensor() returns the tensor value. 
     v0_tensor = reader.get_tensor("v0") 
     v1_tensor = reader.get_tensor("v1") 
     print('\n returns the v0 tensor value:\n', v0_tensor) 
     print('\n returns the v1 tensor value:\n', v1_tensor) 

`

+0

Ich sah, dass Sie die Antwort akzeptiert haben. Also, was ist der Code, den Sie geschrieben haben, um die Funktion 'print_tensors_in_checkpoint_file 'auszuführen? Ich habe versucht, das zu verwenden, aber wann immer ich' tf.python.tools.inspect_checkpoint.print_tensors_in_checkpoint_file 'python sagt, dass das Modul' tensorflow.python' keine hat Attribut 'Werkzeuge'. Ich denke, es wäre immens hilfreich, wenn Sie ein kleines Beispielskript für die Ausführung dieser Funktion zur Verfügung stellen (da diese Datei auch kein Beispiel liefert), besonders, da Sie die Antwort akzeptiert haben, also nehme ich an, dass etwas für Sie funktioniert hat. – Pinocchio

Antwort

4

Sie das inspect_checkpoint.py Tool verwenden können.

+2

Ich habe versucht, das zu verwenden, aber wann immer ich 'tf.python.tools.inspect_checkpoint.print_tensors_in_checkpoint_file' python sagt, dass das Modul 'tensorflow.python' kein Attribut' tools' hat. Ich denke, dass ti immens hilfreich wäre, wenn Sie ein kleines Beispielskript zur Verfügung stellen, wie man diese Funktion ausführt (da diese Datei auch kein Beispiel liefert) – Pinocchio

19

Beispiel Nutzung:

from tensorflow.python.tools.inspect_checkpoint import print_tensors_in_checkpoint_file 
checkpoint_path = os.path.join(model_dir, "model.ckpt") 

# List ALL tensors example output: v0/Adam (DT_FLOAT) [3,3,1,80] 
print_tensors_in_checkpoint_file(file_name=checkpoint_path, tensor_name='') 

# List contents of v0 tensor. 
# Example output: tensor_name: v0 [[[[ 9.27958265e-02 7.40226209e-02 4.52989563e-02 3.15700471e-02 
print_tensors_in_checkpoint_file(file_name=checkpoint_path, tensor_name='v0') 

# List contents of v1 tensor. 
print_tensors_in_checkpoint_file(file_name=checkpoint_path, tensor_name='v1') 

Update:all_tensors Argument wurde print_tensors_in_checkpoint_file hinzugefügt, da Tensorflow 0.12.0-rc0 so müssen Sie all_tensors=False oder all_tensors=True hinzuzufügen, falls erforderlich.

Alternative Methode:

from tensorflow.python import pywrap_tensorflow 
checkpoint_path = os.path.join(model_dir, "model.ckpt") 
reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path) 
var_to_shape_map = reader.get_variable_to_shape_map() 
for key in var_to_shape_map: 
    print("tensor_name: ", key) 
    print(reader.get_tensor(key)) # Remove this is you want to print only variable names 

Hoffe, es hilft.

+0

wirklich hilfreich, danke! – allen

1

oben Antwort hinzu:

Wenn Modell gespeichert wird mit V2-Format

model-10000.data-00000-of-00001 
model-10000.index 
model-10000.meta 

Ihr Kontrollpunkt Eingabe Name sollte nur das Präfix

print_tensors_in_checkpoint_file(file_name='/home/RNN/models/model_10000', tensor_name='',all_tensors=True) 

Quelle sein: von @LingjiaDeng bei https://github.com/tensorflow/tensorflow/issues/7696

Verwandte Themen