Ich versuche, mit dieser RNN tutorial on medium folgen, Refactoring es als ich entlang gehen. Wenn ich meinen Code ausführe, scheint es zu funktionieren, aber als ich versuchte, die current state
Variable auszudrucken, um zu sehen, was innerhalb des neuronalen Netzes geschieht, bekam ich alle 1
s. Ist das erwartetes Verhalten? Wird der Status aus irgendeinem Grund nicht aktualisiert? Von dem, was ich verstehe, sollte die current state
die neuesten Werte in der versteckten Ebene für alle Chargen enthalten, so dass es nicht alle 1
s sein sollte. Jede Hilfe würde sehr geschätzt werden.Tensorflow versteckten Zustand scheint nicht zu ändern
def __train_minibatch__(self, batch_num, sess, current_state):
"""
Trains one minibatch.
:type batch_num: int
:param batch_num: the current batch number.
:type sess: tensorflow Session
:param sess: the session during which training occurs.
:type current_state: numpy matrix (array of arrays)
:param current_state: the current hidden state
:type return: (float, numpy matrix)
:param return: (the calculated loss for this minibatch, the updated hidden state)
"""
start_index = batch_num * self.settings.truncate
end_index = start_index + self.settings.truncate
batch_x = self.x_train_batches[:, start_index:end_index]
batch_y = self.y_train_batches[:, start_index:end_index]
total_loss, train_step, current_state, predictions_series = sess.run(
[self.total_loss_fun, self.train_step_fun, self.current_state, self.predictions_series],
feed_dict={
self.batch_x_placeholder:batch_x,
self.batch_y_placeholder:batch_y,
self.hidden_state:current_state
})
return total_loss, current_state, predictions_series
# End of __train_minibatch__()
def __train_epoch__(self, epoch_num, sess, current_state, loss_list):
"""
Trains one full epoch.
:type epoch_num: int
:param epoch_num: the number of the current epoch.
:type sess: tensorflow Session
:param sess: the session during training occurs.
:type current_state: numpy matrix
:param current_state: the current hidden state.
:type loss_list: list of floats
:param loss_list: holds the losses incurred during training.
:type return: (float, numpy matrix)
:param return: (the latest incurred lost, the latest hidden state)
"""
self.logger.info("Starting epoch: %d" % (epoch_num))
for batch_num in range(self.num_batches):
# Debug log outside of function to reduce number of arguments.
self.logger.debug("Training minibatch : ", batch_num, " | ", "epoch : ", epoch_num + 1)
total_loss, current_state, predictions_series = self.__train_minibatch__(batch_num, sess, current_state)
loss_list.append(total_loss)
# End of batch training
self.logger.info("Finished epoch: %d | loss: %f" % (epoch_num, total_loss))
return total_loss, current_state, predictions_series
# End of __train_epoch__()
def train(self):
"""
Trains the given model on the given dataset, and saves the losses incurred
at the end of each epoch to a plot image.
"""
self.logger.info("Started training the model.")
self.__unstack_variables__()
self.__create_functions__()
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
loss_list = []
current_state = np.zeros((self.settings.batch_size, self.settings.hidden_size), dtype=float)
for epoch_idx in range(1, self.settings.epochs + 1):
total_loss, current_state, predictions_series = self.__train_epoch__(epoch_idx, sess, current_state, loss_list)
print("Shape: ", current_state.shape, " | Current output: ", current_state)
# End of epoch training
self.logger.info("Finished training the model. Final loss: %f" % total_loss)
self.__plot__(loss_list)
self.generate_output()
# End of train()
aktualisiert
Nach den second part of the tutorial Abschluss und mit dem eingebauten in RNN api, ist das Problem weg, was bedeutet, dass es entweder etwas falsch mit der Art, wie ich zu meinen current_state
Variable oder Änderungen verwenden Die Tensorflow-API hat etwas Verrücktes verursacht (ich bin mir ziemlich sicher, dass es das Erste ist). Lassen Sie die Frage offen, falls jemand eine definitive Antwort hat.
Der Testfehler wird niedriger - von ~ 6,7 auf ~ 3,4. Ich habe deinen Vorschlag, 'current_state' auf der Minibatch-Ebene zu betrachten, versucht und ich bekomme alle 1s in jedem Minibatch. Mein Code ist auf [einem GitHub Repo] (https://github.com/ffrankies/tf-terry), wenn es hilft (Ich versuche, das Tutorial auf einem textbasierten Dataset ausführen zu lassen, also habe ich musste den Tutorial-Code ein wenig anpassen). Außerdem habe ich Daten an den Enden meiner Matrixzeilen, aber das Problem bleibt bestehen, auch wenn ich die Daten in ein langes Array umwandle und das in Minibatches umgruppiere. – frankie