Ich bin sehr, sehr neu in Tensorflow und muss ein Skript schreiben, das ein einzelnes Beispiel auf einem Modell testet, das aus einer Prüfpunktdatei wiederhergestellt wurde.Allgemeine Methode zum Testen des wiederhergestellten Tensorflussmodells
Ich frage mich, ob es eine allgemeine Möglichkeit gab, eine Testfunktion für ein wiederhergestelltes Modell zu erstellen, ohne alle winzigen Details des Modells zu kennen.
Weiter, sieht das im letzten Abschnitt des Codes unten aus, wie ich in die richtige Richtung gehe? Wenn ja, wie baut man "y", ohne Details des Modells auswendig zu kennen?
import tensorflow as tf
from tensorflow.python import pywrap_tensorflow
import numpy as np
from fuel.datasets.hdf5 import H5PYDataset
ckpt_path='ckt/mnist/mnist_2017_02_23_17_22_50/mnist_2017_02_23_17_22_50_5000.ckpt'
##############################
#### Initialize Variables ####
##############################
reader = pywrap_tensorflow.NewCheckpointReader(ckpt_path)
var_to_shape_map = reader.get_variable_to_shape_map()
var=[0]*len(var_to_shape_map)
i=0
for key in var_to_shape_map:
var[i] = tf.Variable(reader.get_tensor(key), name=key)
#print("tensor_name: ", key)
#print(reader.get_tensor(key))
i=i+1
initialize=tf.global_variables_initializer()
###############################
####### Restore Model #########
###############################
saver = tf.train.Saver()
sess = tf.Session()
saver.restore(sess, ckpt_path)
###############################
##### Get Example to Test #####
###############################
test_set = H5PYDataset('../CNN3D/data/bmnist.hdf5', which_sets=('test',))
handle = test_set.open()
for i in range(0,100):
test_data = test_set.get_data(handle, slice(i, i+1))
if test_data[1][0][0]==8:
model_idx=i
test_data = test_set.get_data(handle, slice(model_idx,model_idx+1))
data = tf.Variable(np.asarray(test_data[0][0][0]), name='data')
###############################
######## Test Example #########
###############################
x = tf.placeholder(tf.float32,shape=[28,28])
y = ???
sess.run(initialize)
result=sess.run(y, feed_dict={x: data})
print result