Ich möchte ein GAN mit Tensorflow entwickeln, wobei der Generator ein Autoencoder und der Diskriminator ein Convolutional Neural Net mit binärer Ausgabe ist. Es gibt kein Problem, einen Autoencoder und das CNN zu entwickeln, aber meine Idee ist es, 1 Epoche für jede der Komponenten (Diskriminator und Generator) zu trainieren und diesen Zyklus für 1000 Epochen zu wiederholen, wobei die Ergebnisse (Gewichte) der vorherigen Trainingsepoche beibehalten werden für den nächsten. Wie kann ich das operationalisieren?Wie wird der GAN-Generator und -Diskriminator asynchron im Tensorflow aktualisiert?
Antwort
Wenn Sie zwei ops genannt train_step_generator
und train_step_discriminator
(von denen jeder zum Beispiel in der Form tf.train.AdamOptimizer().minimize(loss)
mit einem entsprechenden Verlust für jeden), dann Ihre Trainingsschleife etwas ähnlich der folgenden Struktur sein sollte:
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
for epoch in range(1000):
if epoch%2 == 0: # train discriminator on even epochs
for i in range(training_set_size/batch_size):
z_ = np.random.normal(0,1,batch_size) # this is the input to the generator
batch = get_next_batch(batch_size)
sess.run(train_step_discriminator,feed_dict={z:z_, x:batch})
else: # train generator on odd epochs
for i in range(training_set_size/batch_size):
z_ = np.random.normal(0,1,batch_size) # this is the input to the generator
sess.run(train_step_generator,feed_dict={z:z_})
Die Gewichte bleiben zwischen den Iterationen erhalten.
Ich löste das Problem. Tatsächlich möchte ich, dass die Ausgabe des Autoencoders die Eingabe des CNN ist, das GAN verbindet und Gewichte im Verhältnis 1: 1 aktualisiert. Ich bemerkte, dass ich besonders vorsichtig sein musste, um die Verluste des Generators und des Diskriminators zu unterscheiden, sonst wird zu Beginn der zweiten Schleife der Tensorverlust des Generators durch einen Float ersetzt, der letzte von Discriminator erzeugte Verlust.
Heres der Code:
with tf.Session() as sess:
sess.run(init)
for i in range(1, num_steps+1):
hier der Generator Ausbildung
batch_x, batch_y=next_batch(batch_size, x_train_noisy, x_train)
_, l = sess.run([optimizer, loss], feed_dict={X: batch_x.reshape(n,784),
Y:batch_y})
if i % display_step == 0 or i == 1:
print('Epoch %i: Denoising Loss: %f' % (i, l))
hier der Ausgang des Generators wird als Eingang für die
Diskriminator verwendet werdenoutput=sess.run([decoder_op],feed_dict={X: x_train})
x_train2=np.array(output).reshape(n,784).astype(np.float64)
hier das Diskriminatortraining
batch_x2, batch_y2 = next_batch(batch_size, x_train2, y_train)
sess.run(train_op, feed_dict={X2: batch_x2.reshape(n,784), Y2: batch_y2, keep_prob: 0.8})
if i % display_step == 0 or i == 1:
loss3, acc = sess.run([loss_op2, accuracy], feed_dict={X2: batch_x2,
Y2: batch_y2,
keep_prob: 1.0})
print("Epoch " + str(i) + ", CNN Loss= " + \
"{:.4f}".format(loss3) + ", Training Accuracy= " + "{:.3f}".format(acc))
diese Weise die asynchrone Aktualisierung kann im Verhältnis 1 operationalisierbar: 1, 1: 5, 5: 1:
(Discriminator Generator) oder irgendeine andere Art und Weise- 1. Wie wird der Verlust im Tensorflow berechnet?
- 2. aktualisiert Tabellenzelle Bild asynchron
- 3. Erstellen Sie Inhalt, der asynchron aktualisiert PHP
- 4. Wie funktioniert der Diskriminator auf DCGAN?
- 5. ListView der Objekte, die asynchron aktualisiert werden
- 6. Roslyn: Wie man mehrere Projekte asynchron aktualisiert?
- 7. Wie wird ein Gleitfenster im Tensorflow implementiert?
- 8. Wie wird die Sitzung im Tensorflow wiederhergestellt?
- 9. Wie wird der Primärschlüssel aktualisiert?
- 10. Wie wird eine Variable im Tensorflow inkrementiert?
- 11. Wie wird SaveFileDialog asynchron verwendet?
- 12. Wie Diskriminator Spalte in JPA
- 13. Wie aktualisiert das Tensorflow-word2vec-Tutorial Einbettungen?
- 14. Wie wird der Streudiagramm aktualisiert?
- 15. Wie wird der Status aktualisiert?
- 16. Was ist der "Diskriminator" in addr2line?
- 17. Wie wird der Status der Eltern aktualisiert?
- 18. Wie wird die Dosierung im Tensorflow-Serving durchgeführt?
- 19. Wird Cassandra asynchron ausgelöst?
- 20. JPA Diskriminator Spalte Problem
- 21. Wie TensorArray und while_loop im Tensorflow zusammenarbeiten?
- 22. Korrektur der Dekonvolutionsebene im Tensorflow
- 23. Style wird nicht aktualisiert, wenn der Status aktualisiert wird
- 24. Wie wird protokolliert, was im Eingabefeld select2 eingegeben wird und eventuell ein anderes Eingabefeld aktualisiert wird?
- 25. Wie wird DrawingVisuals asynchron mit Async und Await geladen?
- 26. Sortieren nach Diskriminator - EF
- 27. FFmpeg creating mp4 wird asynchron
- 28. Wie wird ein PHP-Skript asynchron ausgeführt?
- 29. Erläuterung der GRU-Zelle im Tensorflow?
- 30. Doctrine2/Abfrage Diskriminator Wert