2016-05-04 5 views
0

Ich bin ein k-Mittel auf der Kugel, ausgehend von @ dga gist. Die Unit-Norm-Constraint bedeutet im Wesentlichen die Verwendung von inneren Produkten anstelle von paarweisen Distanzen unter Verwendung von argmax anstelle von argmin und Summe + Normalisierung anstelle der Mittelwertbildung, um die Schwerpunkte zu aktualisieren.Ersetzen Sie einige Zeilen einer Variablen in Tensorflow mit boolescher Maske und Indexliste

Jetzt versuche ich tote Zentroide durch die am wenigsten gut vertretenen Datenpunkte zu ersetzen. unsorted_segment_sum wird eine Summe von 0 für tot Centroide zurück:

total = tf.unsorted_segment_sum(points, best_centroids, K) 

Daraus ich eine boolean Maske von toten Centroide erhalten:

deads = tf.equal(total, 0) 

... eine Zählung der Toten Centroide:

dead_count = tf.reduce_sum(tf.as_type(deads, 'int64')) 

... und schließlich eine Liste mit den Indizes der Datenpunkte, die vom aktuellen Modell am schlechtesten dargestellt werden:

_, dead_replacement_idx = tf.nn.top_k(-assignment_qualities, 
             k=dead_count, sorted=False) 

Wie kann ich die toten Zentroide ersetzen? In numpy würde dies nun kommen dazu etwa:

means[deads] = points[dead_replacement_idx] 

Wie kann ich etwas ähnliches in Tensorflow tun?

+0

https://github.com/tensorflow/tensorflow/issues/206 ... Ich glaube, jemand beginnt damit zu arbeiten –

Antwort

0

Wenn Sie Ihre Mittel speichern in Variables können Sie verwenden scatter_update

tf.reset_default_graph() 

means = tf.Variable(np.array([[1,1],[2,2],[3,3]]), dtype=np.float32) 
indices = tf.constant([0, 2]) 
new_mean = tf.constant([-1, -1], dtype=np.float32) 

new_mean_matrix = tf.reshape(new_mean, [1, -1]) 
tile_shape = tf.pack([tf.size(indices), 1]) 
new_mean_matrix_tiled = tf.tile(new_mean_matrix, tile_shape) 
update_op = tf.scatter_update(means, indices, new_mean_matrix_tiled) 

sess = tf.InteractiveSession() 
sess.run(tf.initialize_all_variables()) 
print "Before update" 
print sess.run(means) 
print "Updating rows", indices.eval(), "to", new_mean.eval() 
sess.run(update_op) 
print "After update" 
print sess.run(means) 

Ergebnis

Before update 
[[ 1. 1.] 
[ 2. 2.] 
[ 3. 3.]] 
Updating rows [0 2] to [-1. -1.] 
After update 
[[-1. -1.] 
[ 2. 2.] 
[-1. -1.]] 
+0

In meinem Fall ist 'means' ein Tensor * abgeleitet von * der' centroids' Variable. Wenn sich die Zuweisungen ändern, werden 'centroids' und' cluster_assignments' in '' control_dependencies(): ... tf.group() '' assign() 'ed. Ich nehme an, dass ich dann das 'scatter_update()' danach machen könnte, aber ich würde ein anderes 'control_dependencies()' brauchen, um sicherzustellen, dass die toten Zentroide * nach * dem 'assign' ersetzt werden? –

+0

Der einfachste Weg, um sicherzustellen, dass ein Satz von Berechnungen nach den anderen geschieht, ist die Verwendung von zwei verschiedenen Laufaufrufen. IE, 'sess.run (update_variables); sess.run (do_calculation) '. Wenn Sie 'with_dependencies' verwenden, vergewissern Sie sich, dass alle relevanten Ops im Block' with_dependencies' erstellt wurden. –

Verwandte Themen