2017-08-18 4 views
1

Ich versuche, mit dieser RNN tutorial on medium folgen, Refactoring es als ich entlang gehen. Wenn ich meinen Code ausführe, scheint es zu funktionieren, aber als ich versuchte, die current state Variable auszudrucken, um zu sehen, was innerhalb des neuronalen Netzes geschieht, bekam ich alle 1 s. Ist das erwartetes Verhalten? Wird der Status aus irgendeinem Grund nicht aktualisiert? Von dem, was ich verstehe, sollte die current state die neuesten Werte in der versteckten Ebene für alle Chargen enthalten, so dass es nicht alle 1 s sein sollte. Jede Hilfe würde sehr geschätzt werden.Tensorflow versteckten Zustand scheint nicht zu ändern

def __train_minibatch__(self, batch_num, sess, current_state): 
    """ 
    Trains one minibatch. 

    :type batch_num: int 
    :param batch_num: the current batch number. 

    :type sess: tensorflow Session 
    :param sess: the session during which training occurs. 

    :type current_state: numpy matrix (array of arrays) 
    :param current_state: the current hidden state 

    :type return: (float, numpy matrix) 
    :param return: (the calculated loss for this minibatch, the updated hidden state) 
    """ 
    start_index = batch_num * self.settings.truncate 
    end_index = start_index + self.settings.truncate 

    batch_x = self.x_train_batches[:, start_index:end_index] 
    batch_y = self.y_train_batches[:, start_index:end_index] 
    total_loss, train_step, current_state, predictions_series = sess.run(
     [self.total_loss_fun, self.train_step_fun, self.current_state, self.predictions_series], 
     feed_dict={ 
      self.batch_x_placeholder:batch_x, 
      self.batch_y_placeholder:batch_y, 
      self.hidden_state:current_state 
     }) 
    return total_loss, current_state, predictions_series 
# End of __train_minibatch__() 

def __train_epoch__(self, epoch_num, sess, current_state, loss_list): 
    """ 
    Trains one full epoch. 

    :type epoch_num: int 
    :param epoch_num: the number of the current epoch. 

    :type sess: tensorflow Session 
    :param sess: the session during training occurs. 

    :type current_state: numpy matrix 
    :param current_state: the current hidden state. 

    :type loss_list: list of floats 
    :param loss_list: holds the losses incurred during training. 

    :type return: (float, numpy matrix) 
    :param return: (the latest incurred lost, the latest hidden state) 
    """ 
    self.logger.info("Starting epoch: %d" % (epoch_num)) 

    for batch_num in range(self.num_batches): 
     # Debug log outside of function to reduce number of arguments. 
     self.logger.debug("Training minibatch : ", batch_num, " | ", "epoch : ", epoch_num + 1) 
     total_loss, current_state, predictions_series = self.__train_minibatch__(batch_num, sess, current_state) 
     loss_list.append(total_loss) 
    # End of batch training 

    self.logger.info("Finished epoch: %d | loss: %f" % (epoch_num, total_loss)) 
    return total_loss, current_state, predictions_series 
# End of __train_epoch__() 

def train(self): 
    """ 
    Trains the given model on the given dataset, and saves the losses incurred 
    at the end of each epoch to a plot image. 
    """ 
    self.logger.info("Started training the model.") 
    self.__unstack_variables__() 
    self.__create_functions__() 
    with tf.Session() as sess: 
     sess.run(tf.global_variables_initializer()) 
     loss_list = [] 

     current_state = np.zeros((self.settings.batch_size, self.settings.hidden_size), dtype=float) 
     for epoch_idx in range(1, self.settings.epochs + 1): 
      total_loss, current_state, predictions_series = self.__train_epoch__(epoch_idx, sess, current_state, loss_list) 
      print("Shape: ", current_state.shape, " | Current output: ", current_state) 
      # End of epoch training 

    self.logger.info("Finished training the model. Final loss: %f" % total_loss) 
    self.__plot__(loss_list) 
    self.generate_output() 
# End of train() 

aktualisiert

Nach den second part of the tutorial Abschluss und mit dem eingebauten in RNN api, ist das Problem weg, was bedeutet, dass es entweder etwas falsch mit der Art, wie ich zu meinen current_state Variable oder Änderungen verwenden Die Tensorflow-API hat etwas Verrücktes verursacht (ich bin mir ziemlich sicher, dass es das Erste ist). Lassen Sie die Frage offen, falls jemand eine definitive Antwort hat.

Antwort

0

Zuerst sollten Sie sicherstellen, dass "es scheint zu funktionieren" ist wahr und Ihr Testfehler wird wirklich niedriger.

Eine Hypothese, die ich habe, ist, dass der allerletzte Stapel am Ende mit Nullen verfälscht wird, weil die Länge der Daten total_series_length/batch_size kein Vielfaches von truncated_backprop_length ist. (Ich habe nicht überprüft, dass es mit Nullen gefüllt ist. Der Code im Tutorial ist zu alt, um auf meiner tf-Version ausgeführt zu werden, und wir haben Ihren Code nicht.) Dieser letzte Minibatch mit nur Nullen am Ende könnte die endgültige current_state zu allen Einsen konvergieren führen. Auf jedem anderen Mini-Batch current_state wären nicht alle Einsen.

Sie könnten versuchen, die current_state jedes Mal, wenn Sie sess.run, in __train_minibatch__ ausführen. Oder drucken Sie es einfach alle 1000 Mini-Chargen.

+0

Der Testfehler wird niedriger - von ~ 6,7 auf ~ 3,4. Ich habe deinen Vorschlag, 'current_state' auf der Minibatch-Ebene zu betrachten, versucht und ich bekomme alle 1s in jedem Minibatch. Mein Code ist auf [einem GitHub Repo] (https://github.com/ffrankies/tf-terry), wenn es hilft (Ich versuche, das Tutorial auf einem textbasierten Dataset ausführen zu lassen, also habe ich musste den Tutorial-Code ein wenig anpassen). Außerdem habe ich Daten an den Enden meiner Matrixzeilen, aber das Problem bleibt bestehen, auch wenn ich die Daten in ein langes Array umwandle und das in Minibatches umgruppiere. – frankie

Verwandte Themen