2016-04-28 15 views
2

Ich habe folgende Situation:TensorFlow: Zugriff auf Inhalt einer Variablen nach Name

Ich habe bereits gebaut, trainiert und gespeichert mein Netz. Jetzt versuche ich das Netz wiederherzustellen und die Gewichtsmatrizen zu visualisieren.

Ich kenne alle Namen für die Variablen, aber ich habe keine Python-Markierung zugewiesen, um die Variable zur Auswertung an die Sitzung zu übergeben. Wie kann ich die Daten in der Variablen abrufen?

Hier ist mein Code Situation:

dataset_params = nn_params.mnist_dataset_params 
design = nn_designs.mnist_net_A_design 
## Build Housing Object 
mnist_nn = nn_class.CNN(**dataset_params) 
mnist_nn.build_net(design['design']) 
mnist_nn.__setattr__('saved_path',saved_model) 
mnist_nn_epoch_file = saved_model+'_epochs_completed.txt' 
mnist_nn.__setattr__('epoch_file',mnist_nn_epoch_file) 


# evaluate weight variables 
session = tf.Session() 
saver = tf.train.Saver() 
session.run(tf.initialize_all_variables()) 
saver.restore(session,saved_model) 




session.close() 

Was soll ich zu Sitzung passieren, um die Gewichte zu ziehen? (Ein Beispiel für einen Gewichtungsnamen lautet: 'conv_w_1')?

Antwort

6

Sie können dies unter Verwendung der tf.get_collection() Lookup Methode, um die gewünschte Variable zu erhalten.

weight_var = tf.get_collection(tf.GraphKeys.VARIABLES, "conv_w_1")[0] 

weight_var_value = session.run(weight_var) 
0

Oder Sie können das Ergebnis erhalten, indem Funktion tf.get_default_graph mit() get_tensor_by_name:

valua_of_conv_w_1 = session.run(tf.get_default_graph().get_tensor_by_name("conv_w_1:0")) 
Verwandte Themen