2016-04-04 11 views
1

Ich baue einen Bildklassifizierer in TensorFlow und es gibt eine Klassenungleichheit in meinen Trainingsdaten. Daher muss ich, wenn der Verlust berechnet wird, den Verlust für jede Klasse mit der inversen Häufigkeit dieser Klasse in den Trainingsdaten gewichten.TensorFlow: Definition von Variablen im globalen Gültigkeitsbereich

So, hier ist mein Code:

# Get the softmax from the final layer of the network 
softmax = tf.nn.softmax(final_layer) 
# Weight the softmax by the inverse frequency of the weights 
weighted_softmax = tf.mul(softmax, class_weights) 
# Compute the cross entropy 
cross_entropy = -tf.reduce_sum(y_ * tf.log(softmax)) 
# Define the optimisation 
train_step = tf.train.AdamOptimizer(1e-5).minimize(cross_entropy) 

# Run the training 
session.run(tf.initialize_all_variables()) 
for i in range(10000): 
    # Get the next batch 
    batch = datasets.train.next_batch(64) 
    # Run a training step 
    train_step.run(feed_dict = {x: batch[0], y_: batch[1]}) 

Meine Frage ist: Kann ich class_weights als nur ein tf.constant(...) in globalem Bereich gespeichert werden? Oder muss ich es als ein Parameter weitergeben, wenn cross_entropy berechnet wird?

Der Grund, warum ich mich wundere, ist, dass class_weights für jede Charge unterschiedlich ist. Daher bin ich besorgt, dass, wenn es nur im globalen Gültigkeitsbereich definiert wird, wenn das Tensor Flow-Diagramm konstruiert wird, es nur die Anfangswerte in class_weights nimmt und sie dann nie aktualisiert. Wenn ich die class_weights unter Verwendung der feed_dict beim Berechnen von weighted_softmax passiere, dann sage ich Tensor Flow ausdrücklich, die aktuellen, aktualisierten Werte in class_weights zu verwenden.

Jede Hilfe wäre willkommen. Vielen Dank!

Antwort

1

Ich denke, class_weights eine tf.constant ist in Ordnung. Die Klassengewichtung sollte für den gesamten Datensatz und nicht für jeden Minibatch erfolgen.

Auch ein anderer Ansatz, den Sie in Betracht ziehen sollten, ist das Sampling, so dass jeder Batch die gleiche Anzahl von jeder Klasse hat?

+0

Wenn ich Minibatches mit unausgeglichenen Klassen habe, sollte ich noch nach den globalen Klassenstatistiken gewichten? Ich hätte gedacht, dass es Sinn macht, nach den Klassenstatistiken in diesem Minibatch zu gewichten ...? Die berechneten Gradienten sind nur für diesen Minibatch und deshalb sollten sie sich nicht um die globalen Klassenstatistiken kümmern ... – Karnivaurus

+1

Sie machen einen guten Punkt. Eine Sache ist, dass die Mini-Chargen aus der breiteren Klassenverteilung entnommen werden. Würden Sie mit einer Minibuchse gewichten, würde das Gewicht über eine vollständige Iteration dem Gewicht der breiteren Verteilung entsprechen. Also denke ich, dass es effizienter ist, class_weights nur einmal zu berechnen. Ich wünschte, ich könnte eine Art Zitat dafür finden, aber bis jetzt kann ich nicht ... –

Verwandte Themen