2015-12-28 26 views
8

Ich versuche, einen Vorschlag von Antworten zu implementieren: Tensorflow: how to save/restore a model?tensorflow: Speichern und Wiederherstellen Sitzung

Ich habe ein Objekt, das ein tensorflow Modell in einem sklearn Stil wickelt.

import tensorflow as tf 
class tflasso(): 
    saver = tf.train.Saver() 
    def __init__(self, 
       learning_rate = 2e-2, 
       training_epochs = 5000, 
        display_step = 50, 
        BATCH_SIZE = 100, 
        ALPHA = 1e-5, 
        checkpoint_dir = "./", 
      ): 
     ... 

    def _create_network(self): 
     ... 


    def _load_(self, sess, checkpoint_dir = None): 
     if checkpoint_dir: 
      self.checkpoint_dir = checkpoint_dir 

     print("loading a session") 
     ckpt = tf.train.get_checkpoint_state(self.checkpoint_dir) 
     if ckpt and ckpt.model_checkpoint_path: 
      self.saver.restore(sess, ckpt.model_checkpoint_path) 
     else: 
      raise Exception("no checkpoint found") 
     return 

    def fit(self, train_X, train_Y , load = True): 
     self.X = train_X 
     self.xlen = train_X.shape[1] 
     # n_samples = y.shape[0] 

     self._create_network() 
     tot_loss = self._create_loss() 
     optimizer = tf.train.AdagradOptimizer(self.learning_rate).minimize(tot_loss) 

     # Initializing the variables 
     init = tf.initialize_all_variables() 
     " training per se" 
     getb = batchgen(self.BATCH_SIZE) 

     yvar = train_Y.var() 
     print(yvar) 
     # Launch the graph 
     NUM_CORES = 3 # Choose how many cores to use. 
     sess_config = tf.ConfigProto(inter_op_parallelism_threads=NUM_CORES, 
                  intra_op_parallelism_threads=NUM_CORES) 
     with tf.Session(config= sess_config) as sess: 
      sess.run(init) 
      if load: 
       self._load_(sess) 
      # Fit all training data 
      for epoch in range(self.training_epochs): 
       for (_x_, _y_) in getb(train_X, train_Y): 
        _y_ = np.reshape(_y_, [-1, 1]) 
        sess.run(optimizer, feed_dict={ self.vars.xx: _x_, self.vars.yy: _y_}) 
       # Display logs per epoch step 
       if (1+epoch) % self.display_step == 0: 
        cost = sess.run(tot_loss, 
          feed_dict={ self.vars.xx: train_X, 
            self.vars.yy: np.reshape(train_Y, [-1, 1])}) 
        rsq = 1 - cost/yvar 
        logstr = "Epoch: {:4d}\tcost = {:.4f}\tR^2 = {:.4f}".format((epoch+1), cost, rsq) 
        print(logstr) 
        self.saver.save(sess, self.checkpoint_dir + 'model.ckpt', 
         global_step= 1+ epoch) 

      print("Optimization Finished!") 
     return self 

Wenn ich laufen:

tfl = tflasso() 
tfl.fit(train_X, train_Y , load = False) 

ich Ausgang:

Epoch: 50 cost = 38.4705 R^2 = -1.2036 
    b1: 0.118122 
Epoch: 100 cost = 26.4506 R^2 = -0.5151 
    b1: 0.133597 
Epoch: 150 cost = 22.4330 R^2 = -0.2850 
    b1: 0.142261 
Epoch: 200 cost = 20.0361 R^2 = -0.1477 
    b1: 0.147998 

aber wenn ich versuche, um die Parameter zu erholen (auch ohne das Objekt zu töten): tfl.fit(train_X, train_Y , load = True)

Ich bekomme seltsame Ergebnisse. Zuallererst entspricht der geladene Wert nicht dem gespeicherten Wert.

Was ist der richtige Weg zu laden, und wahrscheinlich zuerst die gespeicherten Variablen zu überprüfen?

+0

tensorflow Dokumentation ohne ziemlich grundlegende Beispiele ist, müssen Sie in den Beispielen Ordner graben und das Gefühl der es meist auf eigene Faust – diffeomorphism

Antwort

10

TL; DR: Sie sollten versuchen, diese Klasse zu überarbeiten, so dass self.create_network() (i) nur einmal aufgerufen wird, und (ii), bevor die tf.train.Saver() aufgebaut.

Hier gibt es zwei kleine Probleme, die auf die Codestruktur und das Standardverhalten des tf.train.Saver constructor zurückzuführen sind. Wenn Sie einen Sparer ohne Argumente konstruieren (wie in Ihrem Code), sammelt er den aktuellen Satz von Variablen in Ihrem Programm und fügt dem Graphen Operationen zum Speichern und Wiederherstellen hinzu. Wenn Sie in Ihrem Code tflasso() aufrufen, wird ein Sparer erstellt, und es wird keine Variablen geben (weil create_network() wurde noch nicht aufgerufen). Daher sollte der Prüfpunkt leer sein.

Das zweite Problem ist, dass — standardmäßig — das Format eines gespeicherten Prüfpunkts ist eine Karte von name property of a variable zu seinem aktuellen Wert. Wenn Sie zwei Variablen mit demselben Namen erstellen, werden sie automatisch „uniquified“ werden, indem TensorFlow:

v = tf.Variable(..., name="weights") 
assert v.name == "weights" 
w = tf.Variable(..., name="weights") 
assert v.name == "weights_1" # The "_1" is added by TensorFlow. 

Die Folge davon ist, dass, wenn Sie self.create_network() im zweiten Aufruf zu tfl.fit() aufrufen, werden alle Variablen verschiedene Namen von den Namen, die im Checkpoint — gespeichert sind oder wenn der Sparer nach dem Netzwerk erstellt worden wäre. (Sie können dieses Verhalten vermeiden, indem ein Namen- Variable Wörterbuch zu den Sparer Konstruktor übergeben, aber das ist in der Regel recht umständlich.)

Es gibt zwei Haupt Abhilfen:

  1. In jedem Aufruf tflasso.fit() erstellen das gesamte Modell neu, indem Sie ein neues tf.Graph definieren, dann in diesem Diagramm, das das Netz bildet und ein tf.train.Saver verursacht.

  2. EMPFOHLEN das Netzwerk erstellen, dann ist die tf.train.Saver im tflasso Konstruktor, und diese Grafik Wiederverwendung bei jedem Aufruf von tflasso.fit().Beachten Sie, dass Sie möglicherweise noch etwas mehr tun müssen, um Dinge zu reorganisieren (insbesondere bin ich mir nicht sicher, was Sie mit self.X und self.xlen machen), aber es sollte möglich sein, dies mit placeholders und Feeding zu erreichen.

+0

danke zu machen! Der 'xlen' wird in 'self._create_network()' verwendet, um die Eingabegröße von 'X' zu setzen (Platzhalter init:' self.vars.xx = tf.placeholder ("float", shape = [None, self.xlen ]) '). Von dem, was Sie sagen, ist der bevorzugte Weg, 'xlen' an den Initialisierer zu übergeben. –

+0

Gibt es eine Möglichkeit, unifilter zurücksetzen/alte Variablen bei Neuinitialisierung des Objekts löschen? –

+1

Dazu müssen Sie einen neuen 'tf.Graph' erstellen und ihn zum Standard machen, bevor Sie (i) das Netzwerk erstellen und (ii) einen' Saver' erstellen. Wenn Sie den Textkörper von 'tflasso.fit()' in '' tf.Graph(). As_default(): 'blockieren und die' Saver'-Konstruktion in diesen Block verschieben, sollten die Namen jedes Mal gleich sein rufe 'fit()' auf. – mrry

Verwandte Themen