2

Eine Möglichkeit zur Verbesserung der Stabilität bei tiefen Q-Learning-Aufgaben besteht darin, eine Reihe von Zielgewichtungen für das Netzwerk beizubehalten, die langsam aktualisiert werden und zur Berechnung von Q-Wert-Zielen verwendet werden. Als Ergebnis werden zu verschiedenen Zeitpunkten in der Lernprozedur zwei unterschiedliche Sätze von Gewichtungen in dem Vorwärtsdurchlauf verwendet. Für normale DQN ist dies zu implementieren ist nicht schwer, da die Gewichte tensorflow Variablen sind, die in einem feed_dict dh eingestellt werden kann:Wie kann ich auf die Gewichte einer wiederkehrenden Zelle in Tensorflow zugreifen?

sess = tf.Session() 
input = tf.placeholder(tf.float32, shape=[None, 5]) 
weights = tf.Variable(tf.random_normal(shape=[5,4], stddev=0.1) 
bias = tf.Variable(tf.constant(0.1, shape=[4]) 
output = tf.matmul(input, weights) + bias 
target = tf.placeholder(tf.float32, [None, 4]) 
loss = ... 

... 

#Here we explicitly set weights to be the slowly updated target weights 
sess.run(output, feed_dict={input: states, weights: target_weights, bias: target_bias}) 

# Targets for the learning procedure are computed using this output. 

.... 

#Now we run the learning procedure, using the most up to date weights, 
#as well as the previously computed targets 
sess.run(loss, feed_dict={input: states, target: targets}) 

Ich möchte diese Zielnetzwerktechnik in einer wiederkehrenden Version von DQN verwenden, aber Ich weiß nicht, wie ich auf die Gewichte innerhalb einer wiederkehrenden Zelle zugreifen und sie setzen kann. Insbesondere verwende ich eine tf.nn.rnn_cell.BasicLSTMCell, aber ich würde gerne wissen, wie das für jede Art von wiederkehrenden Zelle zu tun.

Antwort

3

Das BasicLSTMCell stellt seine Variablen nicht als Teil seiner öffentlichen API zur Verfügung. Ich empfehle, dass Sie entweder nachschlagen, welche Namen diese Variablen in Ihrem Diagramm haben, und diese Namen füttern (diese Namen werden sich wahrscheinlich nicht ändern, da sie sich in den Prüfpunkten befinden und diese Namen würden die Prüfpunktkompatibilität beeinträchtigen).

Alternativ können Sie eine Kopie von BasicLSTMCell erstellen, die die Variablen verfügbar macht. Dies ist der sauberste Ansatz, denke ich.

+1

Das hat funktioniert, danke Alexandre. Für alle, die mehr Details wünschen, werden die Gewichtungs- und Bias-Variablen erzeugt, wenn Sie die wiederkehrende Zelle in 'tf.nn.dynamicrnn()' einspeisen. Nach dem Ausführen von 'tf.initialize_all_variables()' in der Sitzung, gibt es zwei neue trainierbare Tensoren, die Sie sehen können, wenn Sie 'tf.trainable_variables()' ausführen. In meinem Fall wurden sie 'RNN/BasicLSTMCell/Linear/Matrix: 0' und' RNN/BasicLSTMCell/Linear/Bias: 0' genannt. –

Verwandte Themen