2017-06-27 2 views
0

Ich versuche ein CNN-Modell mit einigen Bildern zu trainieren, die ich beschriftet habe. Ich bin neu in TensorFlow. und hier ist das, was ich getan habe:Wie erstellt man `input_fn`, um es als Argument in` .fit() `zu verwenden?

def read_labeled_image_list(image_list_file): 
    f = open(image_list_file, 'r') 
    filenames = [] 
    labels = [] 
    for line in f: 
     filename, label = line[:-1].split(' ') 
     filenames.append(filename) 
     index0 = 1 if int(label) == 0 else 0 
     index1 = 1 if int(label) == 1 else 0 
     labels.append([index0, index1]) 
    return filenames, labels 

def read_images_from_disk(input_queue): 
    label = input_queue[1] 
    file_contents = tf.read_file(input_queue[0]) 
    example = tf.image.decode_jpeg(file_contents, channels=1) 
    return example, label 

die Verwendung von "read_images_from_disk", wie mein input_fn:

image_list, label_list = 
      read_labeled_image_list("./images_training/training_list.txt") 

images = tf.constant(image_list, dtype=tf.string) 
labels = tf.constant(label_list, dtype=tf.int32) 

# Makes an input queue 
input_queue = tf.train.slice_input_producer([images, labels], 
              num_epochs=30, 
               shuffle=True) 

image, label = read_images_from_disk(input_queue) 

# Train the model 
graph_classifier.fit(
    input_fn=read_images_from_disk(input_queue), 
    steps=20000, 
    monitors=[logging_hook]) 

Ich erhalte die folgende Fehlermeldung:

features, labels = input_fn() 
TypeError: 'tuple' object is not callable 

Antwort

0

Der Grund für die Fehler ist, dass das input_fn Argument in der fit Methode eine aufrufbar sein soll. Sie könnten dann versuchen:

def read_images_from_disk(input_queue): 
    label = input_queue[1] 
    file_contents = tf.read_file(input_queue[0]) 
    example = tf.image.decode_jpeg(file_contents, channels=1) 
    return example, label 

def my_input_func(): 
return read_images_from_disk(input_queue) 

# Train the model 
graph_classifier.fit(
    input_fn=my_input_func, 
    steps=20000, 
    monitors=[logging_hook]) 

Ich würde auch sorgfältig the official doc auf input_func zu lesen empfehlen.

Verwandte Themen