2016-11-01 9 views
1

Ich versuche, ein Programm mit Tensorflow zu erstellen, etwas mit einer Sequenz zu tun. Eine Option in den Funktionen tf.nn.seq2seq ist der Parameter feed_previous. Dies kann ich zur Trainingszeit, aber nicht zur Auswertung/Laufzeit verwenden. Zur Zeit habe ich versucht, dies:Tensorflow Wiederverwendung rnn_seq2seq Modell

(outputs_dict,state_dict) = tf.nn.seq2seq.one2many_rnn_seq2seq(enc_inp,decoder_inputs_dictionary,cell, vocab_size, decoder_symbols_dictionary,embedding_size=embedding_dim) 
(evaluation_dict,evaluation_state_dict) = tf.nn.seq2seq.one2many_rnn_seq2seq(enc_inp,decoder_inputs_dictionary,cell, vocab_size, decoder_symbols_dictionary,embedding_size=embedding_dim,feed_previous=True) 

Aber ich bekomme die Fehlermeldung: "Valueerror:.? Variable one2many_rnn_seq2seq/RNN/EmbeddingWrapper/Einbettung bereits vorhanden ist, nicht erlaubt Haben Sie Wiederverwendung = True in VarScope gesetzt bedeuten"

Weiß jemand, wie man

  • mein Modell wieder verwenden, so kann ich während des Trainings feed_previous verwenden, aber nicht bei der Auswertung?
+0

Haben Sie eine [variable_scope] betrachtet (https://www.tensorflow.org/versions/r0.11/how_tos/ variable_scope/index.html)? Sie würden beide in Variablenbereiche mit demselben Namen umbrechen und für den zweiten Eintrag die Option "Wiederverwendung = True" festlegen. –

+0

Welp, das war's! Es war nicht klar, wie man den Variablenbereich in diesem Fall verwendet, schaue auf meine Antwort für den Code, den ich verwendet habe, um es zu beheben. – rmeertens

Antwort

1

Allen wies mich in die richtige Richtung!

ist der Code, den ich verwendet, um meine Probleme (glaube ich) zu lösen:

with tf.variable_scope("decoder1") as scope: 
    (outputs_dict, state_dict) = tf.nn.seq2seq.one2many_rnn_seq2seq(enc_inp, train_decoder_inputs_dictionary, 
                    cell, vocab_size, decoder_symbols_dictionary, 
                    embedding_size=embedding_dim, feed_previous=True) 
with tf.variable_scope("decoder1",reuse=True) as scope: 
    (runtime_outputs_dict, runtime_state_dict) = tf.nn.seq2seq.one2many_rnn_seq2seq(enc_inp, runtime_decoder_inputs_dictionary, 
                    cell, vocab_size, decoder_symbols_dictionary, 
                    embedding_size=embedding_dim, feed_previous=True) 
Verwandte Themen