2016-12-28 2 views
7

Ich füttere die Daten mit Eingabepipelinemethoden zum Diagramm und tf.train.shuffle_batch wird implementiert, um Batchdaten zu generieren. Mit fortschreitendem Training wird der Tensorflow jedoch für spätere Iterationen immer langsamer. Ich bin verwirrt darüber, was der wesentliche Grund dafür ist? Vielen Dank! Mein Code-Schnipsel ist:Das Tensorflow-Training wird langsamer und langsamer, wenn die Iteration mehr als 10.000 beträgt. Warum?

def main(argv=None): 

# define network parameters 
# weights 
# bias 

# define graph 
# graph network 

# define loss and optimization method 
# data = inputpipeline('*') 
# loss 
# optimizer 

# Initializaing the variables 
init = tf.initialize_all_variables() 

# 'Saver' op to save and restore all the variables 
saver = tf.train.Saver() 

# Running session 
print "Starting session... " 
with tf.Session() as sess: 

    # initialize the variables 
    sess.run(init) 

    # initialize the queue threads to start to shovel data 
    coord = tf.train.Coordinator() 
    threads = tf.train.start_queue_runners(coord=coord) 

    print "from the train set:" 
    for i in range(train_set_size * epoch): 
     _, d, pre = sess.run([optimizer, depth_loss, prediction]) 

    print "Training Finished!" 

    # Save the variables to disk. 
    save_path = saver.save(sess, model_path) 
    print("Model saved in file: %s" % save_path) 

    # stop our queue threads and properly close the session 
    coord.request_stop() 
    coord.join(threads) 
    sess.close() 
+0

Es ist schwer das Programm zu sagen, ohne zu sehen, aber ich vermute, dass etwas in Ihrem Trainingsschleife Knoten des Graphen hinzufügt. Wenn dies der Fall ist, könnten Sie auch an einem Speicherleck leiden, so [diese Dokumentation] (http://stackoverflow.com/documentation/tensorflow/3883/how-to- debug-a-memory-leak-in- tensorflow/13426/use-graph-finalize-to-catch-nodes-hinzugefügte-to-the-graph # t = 201612280201558374055) hat eine potentielle Debugging-Technik. – mrry

+0

Klingt wie ein Shlemiel The Painter-Algorithmus. Sind Sie vielleicht in der Lage, andere Metadaten zu verfolgen, indem Sie diese an eine Datenstruktur mit O (n) einfügen oder verketten? Die –

+0

Ich habe mein Code-Snippet, vielen Dank! – Lei

Antwort

1

Beim Training sollten Sie sess.run nur einmal tun. Empfehlen Sie so etwas wie dieses versuchen, hoffe, es hilft:

with tf.Session() as sess: 
    sess.run(tf.global_variables_initializer()) 
    for i in range(train_set_size * epoch): 
    train_step.run([optimizer, depth_loss, prediction]) 
Verwandte Themen