bezogen auf: https://www.tensorflow.org/how_tos/reading_data/Eingabesequenzen unter Verwendung von CSV-Reader und Warteschlangen in Tensorflow
ZIEL:
Ich mag ein LSTM Verwendung vorhergehende Sequenzen in der folgenden Form trainieren: [t0 t1 t2] , [t1 t2 t3], [t2 t3 t4] ...
Außerdem sollten diese Sequenzen auch gemischt werden. z.B. [t2 t3 t4], [t0 t1 t2], [t1 t2 t3] ...
Meine Daten werden in einer csv-Datei gespeichert, wobei jede Zeile einen Zeitschritt darstellt. Die Spalten enthalten die verschiedenen Merkmale & Zielwert.
Frage:
Gibt es eine Weise gemischt kohärente Sequenzen zu füttern Verwendung CSV-Leser und Warteschlangen (nicht Platzhalter und feed_dict) in Tensorflow? Ich kann nicht eine Möglichkeit finden, das mit zu implementieren: tf.TextLineReader() und tf.train.shuffle_batch().
Meine Abhilfe tut, was es zu sein, angenommen hat, ist aber schrecklich langsam:
train_filename_queue = tf.train.string_input_producer([path])
rand_ind_q = tf.train.range_input_producer(data_len-seq_len, shuffle=True)
def read_csv(filename_queue, ncols, header_lines):
''' returns a list of tensors with content of csv-file
'''
# content <- [(data_len,) ... ncols ... (data_len,)]
whole_reader = tf.WholeFileReader()
_, content = whole_reader.read(filename_queue)
content = tf.string_split([content], delimiter='\n').values[header_lines:]
record_defaults = ncols*[[0.]]
content = tf.decode_csv(content, record_defaults, field_delim=',')
return content
def slice_seq(q, content, seq_len):
''' returns a list of tensors with sequences
'''
# seq <- [(1,seq_len,) ... ncols ... (1,seq_len,)]
start_ind = q.dequeue()
seq = list(map(lambda tensor: tf.slice(tensor, [start_ind], [seq_len]), content))
seq = list(map(lambda tensor: tf.reshape(tensor, (1,seq_len,)), seq))
return seq
Danke, definitiv schneller als die WholeFilerReader-Version! Müssen Sie noch etwas herumspielen, um zu überprüfen, ob das schneller ist als mit Platzhalter + Feed_dict. – phtephanx
Die Sequenzlänge sieht hier wie 'block_size-seq_len + 1' aus, was 'seq_len' hätte sein sollen, oder? Zweitens, was ist block_size? Gesamtzahl der in der CSV-Datei gespeicherten Samples? – iliTheFallen