2017-12-26 4 views
1

Ich habe ein Modell in Tensorflow gespeichert und möchte es für die weitere Verwendung wiederherstellen, aber ich habe einen Fehler erhalten. Der Code ist wie folgt:Wie stelle ich einen Platzhalter im Tensorflow wieder her?

import tensorflow as tf 
def input_func(dim): 
    input_ = tf.placeholder(tf.float32,[1,dim]) 
    return input_ 
def fully_connect(input_,out_dimension): 
    out=tf.layers.dense(input_, out_dimension,\ 
     kernel_initializer=tf.contrib.layers.xavier_initializer(uniform=False)) 
    return tf.reduce_sum(out) 
def train(real_input, input_dim, out_dimension): 
    input_ = input_func(input_dim) 
    output = fully_connect(input_, out_dimension) 
    with tf.Session() as sess: 
     sess.run(tf.global_variables_initializer()) 
     for epoch in range(10): 
      sess.run(output, {input_:real_input}) 

     tf.add_to_collection('input_',input_) 
     tf.add_to_collection('output',output) 
     tf.train.Saver().save(sess,'./save/expression') 
dim=3 
out_dimension=2 
real_input=[[1,2,3]] 
with tf.Graph().as_default(): 
    train(real_input, dim, out_dimension) 

Jetzt ist das Modell gebaut und gespeichert.

später das Modell wiederherzustellen ich den folgenden Code verwendet:

with tf.Session() as sess: 
    loader = tf.train.import_meta_graph('./save/expression.ckpt.meta') 
    loader.restore(sess, './save/expression.ckpt') 
    input_=tf.get_collection('input_') 
    print(input_) 
    output=tf.get_collection('output') 
    print(sess.run(output, {input_:[[4,5,6]]})) 

Aber ich ist ein Fehler aufgetreten:

INFO:tensorflow:Restoring parameters from ./save/expression.ckpt 
[] 

--------------------------------------------------------------------------- 
TypeError         Traceback (most recent call last) 
<ipython-input-98-6cfbdc96438e> in <module>() 
     5  print(input_) 
     6  output=tf.get_collection('output') 
----> 7  print(sess.run(output, {input_:[[4,5,6]]})) 

TypeError: unhashable type: 'list' 

Es scheint, dass der Platzhalter input_ nicht gespeichert!

Kann mir jemand dabei helfen?

+0

Platzhalter können nicht gespeichert werden. –

Antwort

1

Sie müssen den Platzhalter wiederherstellen und ihm den entsprechenden Wert geben. Idealerweise sollten Sie Ihren Platzhalter beim Erstellen benannt haben. Da Sie es nicht benannt haben, müssen Sie den Namen aus Ihrem Diagramm finden. Nachdem das Modell wiederhergestellt wurde, drucken Sie den Namen der Knoten in Ihrem Diagramm aus. Der Platzhalter wird zuerst gedruckt. Sie können dies tun, mit

with tf.Session() as sess: 
    loader = tf.train.import_meta_graph('./save/expression.ckpt.meta') 
    loader.restore(sess, './save/expression.ckpt') 
    graph = tf.get_default_graph() 
    for op in graph.get_operations(): 
     print(op.name) 

ich, dass der Eingang Platzhalter erraten der Standardnamen „Platzhalter“ gegeben werden. Nachdem Sie seinen Namen gefunden haben, müssen Sie diesen Tensor wiederherstellen und ihm einen Wert geben. Wenn der Name Placeholder ist, können Sie es mit

graph.get_tensor_by_name('Placeholder:0')

wiederherstellen Sie sollten den Namen Ihres Ausgangsknoten in der gleichen Weise lokalisieren. Es sollte etwas wie fully_connected_1/matmul... sein, nehmen wir an, dass der Name outputNodeName ist. Dann können Sie Ihr Diagramm als

with tf.Session() as sess: 
    loader = tf.train.import_meta_graph('./save/expression.ckpt.meta') 
    loader.restore(sess, './save/expression.ckpt') 
    graph = tf.get_default_graph() 
    input_= graph.get_tensor_by_name('Placeholder:0') 
    output=tf.get_collection('outputNodeName:0') 
    print(sess.run(output, {input_:[[4,5,6]]})) 
Verwandte Themen