2017-07-28 1 views
1

In Tensorflow finde ich die API tf.add_to_collcetion, um der Sammlung einen Wert wie Code unten hinzuzufügen.Tensorflow Reset oder Clear Collection

def accuracy_rate(logits, labels): 
    correct = tf.nn.in_top_k(logits, labels, 1) 
    # Return the accuracy of true entries. 
    accuracy = tf.reduce_mean(tf.cast(correct, tf.float32)) 
    return accuracy 
with tf.Session() as sess: 
    logits, labels = ... 
    accuracy = accuracy_rate(logits, labels) 
    tf.add_to_collection('total_accuracy', sess.run(accuracy)) 

Was ich nicht in der API finden kann, ist, dass, wie kann ich alle Werte klar, dass ich bereits in einer Sammlung gespeichert haben?

+0

Ich weiß, dass Sie eine alternative Lösung gefunden haben, aber Sie können auch 'tf.get_default_graph(). Clear_collection ('total_accuracy')' für diesen Zweck betrachten. Da eine Sammlung auch 'Variablen',' trainable_variables' und 'train_op' enthält, empfiehlt es sich, die Werte nach Schlüssel zu löschen. –

Antwort

3

können Sie tf.get_collection_ref verwenden, um einen veränderbaren Verweis auf die Sammlung zu erhalten, die Sie freigeben können (es ist nur eine Python-Liste).

+0

Meinst du, dass ich die Werte in der Sammlung löschen kann? – user6932206

+1

Ja, es ist eine Python-Liste, die Sie löschen können –

0

Finden Sie eine alternative Lösung, die verschiedene tf.Graph() verwendet.

0

Ich denke, das könnte das sein, wonach Sie suchen?

In [2]: import tensorflow as tf 
In [3]: w = tf.Variable([[1,2,3], [4,5,6], [7,8,9], [3,1,5], [4,1,7]], collections=[tf.GraphKeys.WEIGHTS, tf.GraphKeys.GLOBAL_VARIABLES], dtype=tf.float32) 
In [4]: params = tf.get_collection_ref(tf.GraphKeys.WEIGHTS) 
In [5]: del params[:] 
In [6]: tf.get_collection_ref(tf.GraphKeys.WEIGHTS)                                         
Out[6]: [] 
In [10]: params = tf.get_collection_ref(tf.GraphKeys.GLOBAL_VARIABLES) 
In [11]: params 
Out[11]: [<tf.Variable 'Variable:0' shape=(5, 3) dtype=float32_ref>]