3

ich ein RNN-Modell für die ptb example mit tensorflow des ptb_word trainiert haben. Bellow Ich habe einen Code, wo ich versuche, ein paar Beispiele drucken trainiert das Modell zu testen. Ich erhalte einen Fehler TypeError: 'Tensor' object is not callable, wenn ich diesen Code auf der Linie führe, die ich mache probs, state = sess.run([mtest.output_probs(), mtest._final_state], feed_dict=feed_dict)Tensorflow Session.Run() Tensor Objekt ist nicht aufrufbar

Was genau verursacht diesen Fehler?

hier ist der Code:

import numpy as np 
import os 
import tensorflow as tf 
from ptb_word_lm import * 
from tensorflow.models.rnn.ptb import reader 
from tensorflow.python.platform import gfile 

data_path = "/home/usr/simple-examples/data/" 
raw_data = reader.ptb_raw_data(data_path) 
train_data, valid_data, test_data, vocabulary = raw_data 

test_path = os.path.join(data_path, "ptb.test.txt") 
word_to_id = reader._build_vocab(test_path) 


eval_config = get_config() 
eval_config.batch_size = 1 
eval_config.num_steps = 1 

sess = tf.Session() 

initializer = tf.random_uniform_initializer(-eval_config.init_scale, 
              eval_config.init_scale) 
test_input = PTBInput(config=eval_config, data=test_data, name="TestInput") 
with tf.variable_scope("model", reuse=None, initializer=initializer): 
    mtest = PTBModel(is_training=False, config=eval_config, input_=test_input) 

sess.run(tf.initialize_all_variables()) 

saver = tf.train.import_meta_graph('/home/usr/models/medium/model.ckpt-50979.meta') 

ckpt = tf.train.get_checkpoint_state('/home/usr/models/medium/') 
if ckpt and gfile.Exists(ckpt.model_checkpoint_path): 
    msg = 'Reading model parameters from %s' % ckpt.model_checkpoint_path 
    print(msg) 
    saver.restore(sess, ckpt.model_checkpoint_path) 

def pick_from_weight(weight, pows=1.0): 
    weight = weight**pows 
    t = np.cumsum(weight) 
    s = np.sum(weight) 
    return int(np.searchsorted(t, np.random.rand(1) * s)) 

while True: 
    number_of_sentences = 10 
    sentence_cnt = 0 
    text = '\n' 
    end_of_sentence_char = word_to_id['<eos>'] 
    input_char = np.array([[end_of_sentence_char]]) 
    state = sess.run(mtest.initial_state) 
    for attr in mtest.__dict__: 
     print attr 
    print 'all attributes above' 
    while sentence_cnt < number_of_sentences: 
     feed_dict = {mtest._input: input_char, 
        mtest.initial_state: state} 

     probs, state = sess.run([mtest.output_probs(), mtest._final_state], feed_dict=feed_dict) 

     print 'after state' 
     sampled_char = pick_from_weight(probs[0]) 
     print sampled_char 
     if sampled_char == end_of_sentence_char: 
      text += '.\n' 
      sentence_cnt += 1 
     else: 
      text += ' ' + id_to_word[sampled_char] 
     input_char = np.array([[sampled_char]]) 
    print(text) 
    raw_input('press any key to continue ...') 

Antwort

1

am referenzierten Code auf GitHub Blick ich nicht output_props, so vielleicht die Versionen unterscheiden sich finden können. Da jedoch mtest.initial_state ein @property ist, nehme ich an, dass mtest.output_props man als gut ist. Das heißt, versuchen Sie statt dessen, d. H. Ohne die Klammern zu verwenden.

Auch mtest._final_state ist eine interne Variable und sollte nicht direkt verwendet werden. Sie wollen wahrscheinlich mtest.final_state stattdessen verwenden.

+0

Wenn ich beide Änderungen vorgenommen habe, bekam ich einen weiteren Fehler: 'TypeError: Kann den feed_dict Schlüssel nicht als Tensor interpretieren: Kann einen PTBInput nicht in einen Tensor konvertieren ' – smith

+1

Das scheint zu sein, weil' mtest._input' [tatsächlich] ist (https : //github.com/tensorflow/models/blob/520b557e095b008bfb023da1c749b3d0eabc521c/tutorials/rnn/ptb/ptb_word_lm.py#L101) Ihre 'test_input' Referenz. Ich denke, "mtest.input.input_data" könnte funktionieren. – sunside

+0

Vielen Dank. Ich bekomme immer noch Fehler im Code, aber das war die Linie, mit der ich feststeckte. Vielen Dank für Ihre Zeit. – smith

Verwandte Themen