2016-03-29 4 views
0

Wie wird die Kreuzvalidierung von Tensorflow korrekt durchgeführt? Im Folgenden sind mein Code-Schnipsel:
Gespeicherte Tensorflow-Modelle werden größer, wenn Kreuzvalidierung ausgeführt wird

class TextCNN: 
    ... 
    def train(self): 
    saver = tf.train.Saver(tf.all_variables()) 
    with tf.Session() as sess: 
     ... 
     # training loop 
     ... 
     # training finished 
     path = saver.save(sess, "{:s}/model.{:d}".format(self.checkpoint_dir, self.test_fold)) 

if __name__ == "__main__": 
    for i in range(CV_SIZE): 
    cnn = TextCNN(i) 
    cnn.train() 

Die gespeicherte Modellgröße für Falte 0 ist um 2M. Aber für Falte 1 um 4M, Falte 2 um 6M, und so weiter.

Antwort

3

Meine Vermutung ist, dass der TextCNN Konstruktor und train() Methode Knoten auf den Standard Graph hinzufügen (tf.get_default_graph()) und dem gespeicherten Modell alle bisherigen Graphen enthält, so ist es „versehentlich quadratische“ und mit jeder Iteration der __main__ Schleife wächst.

Die Lösung ist zum Glück, einfach. Schreiben Sie einfach Ihre Hauptschleife wie folgt:

if __name__ == "__main__": 
    for i in range(CV_SIZE): 
    with tf.Graph().as_default(): # Performs training in a new, empty graph. 
     cnn = TextCNN(i) 
     cnn.train() 

Dies wird ein neues, leeres Diagramm für jede Iteration der Schleife erstellen. Daher enthält das gespeicherte Modell keine Knoten (und Variablen) aus der vorherigen Iteration, und die Modellgröße sollte konstant bleiben.

Beachten Sie, dass Sie, wenn möglich, versuchen sollten, das gleiche Diagramm für alle Iterationen wiederzuverwenden. Ich weiß jedoch, dass dies möglicherweise nicht möglich ist, wenn sich die Struktur des Graphen von einer Iteration zur nächsten ändert.

Verwandte Themen