2017-02-24 5 views
0

Ich bin sehr, sehr neu in Tensorflow und muss ein Skript schreiben, das ein einzelnes Beispiel auf einem Modell testet, das aus einer Prüfpunktdatei wiederhergestellt wurde.Allgemeine Methode zum Testen des wiederhergestellten Tensorflussmodells

Ich frage mich, ob es eine allgemeine Möglichkeit gab, eine Testfunktion für ein wiederhergestelltes Modell zu erstellen, ohne alle winzigen Details des Modells zu kennen.

Weiter, sieht das im letzten Abschnitt des Codes unten aus, wie ich in die richtige Richtung gehe? Wenn ja, wie baut man "y", ohne Details des Modells auswendig zu kennen?

import tensorflow as tf 
from tensorflow.python import pywrap_tensorflow 
import numpy as np 
from fuel.datasets.hdf5 import H5PYDataset 

ckpt_path='ckt/mnist/mnist_2017_02_23_17_22_50/mnist_2017_02_23_17_22_50_5000.ckpt' 

############################## 
#### Initialize Variables #### 
############################## 

reader = pywrap_tensorflow.NewCheckpointReader(ckpt_path) 
var_to_shape_map = reader.get_variable_to_shape_map() 
var=[0]*len(var_to_shape_map) 
i=0 
for key in var_to_shape_map: 
    var[i] = tf.Variable(reader.get_tensor(key), name=key) 
    #print("tensor_name: ", key) 
    #print(reader.get_tensor(key)) 
    i=i+1 
initialize=tf.global_variables_initializer() 

############################### 
####### Restore Model ######### 
############################### 

saver = tf.train.Saver() 
sess = tf.Session() 
saver.restore(sess, ckpt_path) 

############################### 
##### Get Example to Test ##### 
############################### 

test_set = H5PYDataset('../CNN3D/data/bmnist.hdf5', which_sets=('test',)) 
handle = test_set.open() 
for i in range(0,100): 
    test_data = test_set.get_data(handle, slice(i, i+1)) 
    if test_data[1][0][0]==8: 
     model_idx=i 
test_data = test_set.get_data(handle, slice(model_idx,model_idx+1)) 
data = tf.Variable(np.asarray(test_data[0][0][0]), name='data') 

############################### 
######## Test Example ######### 
############################### 

x = tf.placeholder(tf.float32,shape=[28,28]) 
y = ??? 
sess.run(initialize) 
result=sess.run(y, feed_dict={x: data}) 
print result 

Antwort

0

Die Estimator Klasse verfügt über eine komfortable Reihe von Dienstprogrammen, so dass, wenn Ihr Modell um einen Schätzer, Laden gewickelt ist und daraus die Vorhersage einfach.

Insgesamt, ohne irgendeine Art von Koordination, wird dies schwer sein.

Verwandte Themen