2017-06-25 4 views
1

Ich versuche heart.csv Dateidaten in Batches zu lesen. Im Anschluss an die Dokumentation von tensorflow Website, habe ich den folgenden Code ein Arbeits Zeile für Zeile lesenLesen von CSV-Dateien in Tensorflow 1.2.0

import tensorflow as tf 
filename_queue = tf.train.string_input_producer(["heart.csv"]) 
reader = tf.TextLineReader(skip_header_lines=1) 
_, csv_row = reader.read(filename_queue) 

record_defaults = [[0], [0.0], [0.0], [0.0], [""], [0], [0.0], [0.0], [0], [0]] 
sbp, tobacco, ldl, adiposity, famhist, typea, obesity, alcohol, age, chd = tf.decode_csv(csv_row, record_defaults=record_defaults) 
features = [sbp, tobacco, ldl, adiposity, famhist, typea, obesity, alcohol, age] 

nof_examples = 10 
with tf.Session() as sess: 
    tf.global_variables_initializer().run() 
    coord = tf.train.Coordinator() 
    threads = tf.train.start_queue_runners(coord=coord) 
    while nof_examples > 0: 
     nof_examples -= 1 
     try: 
      data_features, data_chd = sess.run([features, chd]) 
#    data_features[4] = 1 if data_features[4] == 'Present' else 0 
      print(data_features, data_chd) 
     except tf.errors.OutOfRangeError: 
      coord.request_stop() 
      coord.join(threads) 
      break 
    coord.request_stop() 
    coord.join(threads) 

Ausgang:

([160, 12.0, 5.73, 23.110001, 'Present', 49, 25.299999, 97.199997, 52], 1) 
([144, 0.0099999998, 4.4099998, 28.610001, 'Absent', 55, 28.870001, 2.0599999, 63], 1) 
([118, 0.079999998, 3.48, 32.279999, 'Present', 52, 29.139999, 3.8099999, 46], 0) 
([170, 7.5, 6.4099998, 38.029999, 'Present', 51, 31.99, 24.26, 58], 1) 
([134, 13.6, 3.5, 27.780001, 'Present', 60, 25.99, 57.34, 49], 1) 
([132, 6.1999998, 6.4699998, 36.209999, 'Present', 62, 30.77, 14.14, 45], 0) 
([142, 4.0500002, 3.3800001, 16.200001, 'Absent', 59, 20.809999, 2.6199999, 38], 0) 
([114, 4.0799999, 4.5900002, 14.6, 'Present', 62, 23.110001, 6.7199998, 58], 1) 
([114, 0.0, 3.8299999, 19.4, 'Present', 49, 24.860001, 2.49, 29], 0) 
([132, 0.0, 5.8000002, 30.959999, 'Present', 69, 30.110001, 0.0, 53], 1) 

Aber wenn ich versuche, in den Reihen zu lesen, wie in der tensorflow Dokumentation zeigte, i erhalten

TypeError: Cannot convert a list containing a tensor of dtype <dtype: 
float32'> to <dtype: 'int32'> (Tensor is: <tf.Tensor 'DecodeCSV_6:1' 
shape=() dtype=float32>) 

Batch-Verarbeitungscode

import tensorflow as tf 
batch_size = 1 
def read_my_file_format(filename_queue): 
    reader = tf.TextLineReader(skip_header_lines=1) 
    _, csv_row = reader.read(filename_queue) 
    record_defaults = [[0], [0.0], [0.0], [0.0], [""], [0], [0.0], [0.0], [0], [0]] 
    sbp, tobacco, ldl, adiposity, famhist, typea, obesity, alcohol, age, chd = tf.decode_csv(csv_row, record_defaults=record_defaults) 
    feature = [sbp, tobacco, ldl, adiposity, famhist, typea, obesity, alcohol, age] 
    label = [chd] 
    return feature, label 

def input_pipeline(filenames, batch_size, num_epochs=None): 
    filename_queue = tf.train.string_input_producer(filenames, 
                num_epochs=num_epochs, 
                shuffle=True) 
    feature, label = read_my_file_format(filename_queue) 
    min_after_dequeue = 10000 
    capacity = min_after_dequeue + 3 * batch_size 
    feature_batch, label_batch = tf.train.shuffle_batch([feature, label], 
                 batch_size=batch_size, 
                 capacity=capacity, 
                 min_after_dequeue=min_after_dequeue) 
    return feature_batch, label_batch 

features, labels = input_pipeline(['heart.csv'], batch_size) 

with tf.Session() as sess: 
    tf.global_variables_initializer().run() 

    # start populating filename queue 
    coord = tf.train.Coordinator() 
    threads = tf.train.start_queue_runners(coord=coord) 

    try: 
     while not coord.should_stop(): 
      feature_batch, label_batch = sess.run([features, labels]) 
      print(feature_batch) 
    except tf.errors.OutOfRangeError: 
     print('Done training, epoch reached') 
    finally: 
     coord.request_stop() 
    coord.join(threads) 

Das Lesen von CSV-Dateien mit Tensorflow scheint etwas umständlich, aber ich bin mir sicher, dass es wichtig ist, dass die Bibliothek ein verteiltes System ist. Ich fand es verwirrend und brauchte mehr als 60 Minuten zu lesen und zu verstehen, wie die Lese-Feed-Pipeline für CSV-Dateien funktioniert. Vielleicht sollte Dokumentation besser sein und mehr Visuals benötigt werden.

+0

Ist der Code, den Sie gab [MCVE]? – boardrider

+0

@boardrider Ja, der Code ist vollständig – bicepjai

+0

Aber, ist es Minimal? – boardrider

Antwort

1

hatte ich einen Blick auf den Code, und es scheint, dass eine der internen Funktion in tf.train.shuffle_batch, dass die gleichen dtype alle Tensoren in der Reihe erfordert haben (aus dem ersten Elemente zu entnehmen, in Sie ein tf.int32 Fall). Sie könnten sie in einer Zeichenfolge dekodieren und später im richtigen Typ konvertieren. Nicht sehr praktisch.

Aber was ich empfehlen würde, wenn Sie TensorFlow 1.2.0 verwenden, ist die Verwendung der neuen DataSet-API, die die neue Methode zum Umgang mit Daten ist (siehe zum Beispiel this answer).

auf die zitierte Antwort Basierend, hier ist ein Beispiel für die neue API, die funktionieren sollte:

def read_row(csv_row): 
    record_defaults = [[0], [0.0], [0.0], [0.0], [""], [0], [0.0], [0.0], [0], [0]] 
    row = tf.decode_csv(csv_row, record_defaults=record_defaults) 
    return row[:-1], row[-1] 

def input_pipeline(filenames, batch_size): 
    # Define a `tf.contrib.data.Dataset` for iterating over one epoch of the data. 
    dataset = (tf.contrib.data.TextLineDataset(filenames) 
       .skip(1) 
       .map(lambda line: read_row(line)) 
       .shuffle(buffer_size=10) # Equivalent to min_after_dequeue=10. 
       .batch(batch_size)) 

    # Return an *initializable* iterator over the dataset, which will allow us to 
    # re-initialize it at the beginning of each epoch. 
    return dataset.make_initializable_iterator() 

iterator = input_pipeline(['heart.csv'], batch_size) 
features, labels = iterator.get_next() 


nof_examples = 10 
with tf.Session() as sess: 
    tf.global_variables_initializer().run() 
    sess.run(iterator.initializer) 
    while nof_examples > 0: 
     nof_examples -= 1 
     try: 
      data_features, data_labels = sess.run([features, labels]) 
      print(data_features) 
     except tf.errors.OutOfRangeError: 
      pass 
+0

Nach dem Hinzufügen der Codeänderungen funktionierte es gut. Vielen Dank – bicepjai

Verwandte Themen