Wo finde ich die Backpropagation (durch Zeit) Code in Tensorflow (Python API)? Oder werden andere Algorithmen verwendet?Backpropagation (durch Zeit) Code in Tensorflow
Zum Beispiel, wenn ich ein LSTM-Netz erstellen.
Wo finde ich die Backpropagation (durch Zeit) Code in Tensorflow (Python API)? Oder werden andere Algorithmen verwendet?Backpropagation (durch Zeit) Code in Tensorflow
Zum Beispiel, wenn ich ein LSTM-Netz erstellen.
Die gesamte Backpropagation in TensorFlow wird implementiert, indem die Operationen im Vorwärtsdurchlauf des Netzwerks automatisch differenziert und explizite Operationen zur Berechnung des Gradienten an jedem Punkt im Netzwerk hinzugefügt werden. Die allgemeine Umsetzung kann in tf.gradients()
gefunden werden, aber die bestimmte Version verwendet wird, hängt ab, wie Ihre LSTM implementiert:
tf.gradients()
verwendet, um eine abgerollte Backpropagation Schleife in die entgegengesetzte Richtung aufzubauen.tf.while_loop()
implementiert ist, wird zusätzliche Unterstützung für das Unterscheiden von Schleifen in control_flow_grad.py
verwendet.Ich bin darüber nicht sicher, aber dies funktionieren könnte:
Da RNNs kann wie Feed-Forward-Netze trainiert werden, der Code sehr ähnlich ist. Dies ist, wie Sie ein Feed-Forward-Netz der Bahn: (X ist der Eingang)
train = tf.train.GradientDescentOptimizer(learning_rate).minimize(error)
# Session
sess = tf.Session()
sess.run(tf.initialize_all_variables())
for i in range(epochs):
sess.run(train, feed_dict={X: [[0, 0, 1], [1, 1, 1], [1, 0, 1], [0, 1, 1]], labels: [[0], [1], [1], [0]]})
Der einzige Unterschied in Backpropagation durch die Zeit ist, dass jede Epoche nun eine verschachtelte Zeitschleife hat.
Dies ist der Code eine einfache rnn zu trainieren:
train = tf.train.GradientDescentOptimizer(learning_rate).minimize(error)
time_series = [1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1]
for i in range(number_of_epochs):
for j in range(len(time_series) - 1):
curr_X = time_series[j+1]
curr_prev = time_series[j]
lbs = curr_prev
sess.run(train, feed_dict={X: [[curr_X]], prev_val: [[curr_prev]], labels: [[lbs]]})
In diesem Code die rnn eine Zeitreihe mit alternativen 1 und 0 lernen.