2017-02-11 6 views
0

Nehmen Sie values und Tensor T beide haben Form (N,K). Nun, wenn wir sie in Bezug auf Matrizen denken, möchte ich für jede Zeile von T das Zeilenelement erhalten, das dem Index entspricht, wobei values sein Maximum hat. Ich kann leicht diese Indizes mitSammeln Sie Elemente entlang der zweiten Dimension des Tensors

max_indicies = tf.argmax(T, 1) 

finden, die (N) einen Tensor Form zurückkehrt. Nun, wie kann ich diese Indizes von T so sammeln, dass ich etwas von Form N bekomme? Ich habe versucht,

result = tf.gather(T,max_indices) 

aber nicht das Richtige tun - es gibt etwas von Form (N,K) was bedeutet, dass es nicht alles aufzusammeln hat.

Antwort

2

Sie können tf.gather_nd verwenden.

Zum Beispiel

import tensorflow as tf 

sess = tf.InteractiveSession() 

values = tf.constant([[0, 0, 0, 1], 
         [0, 1, 0, 0], 
         [0, 0, 1, 0]]) 

T = tf.constant([[0, 1, 2 , 3], 
       [4, 5, 6 , 7], 
       [8, 9, 10, 11]]) 

max_indices = tf.argmax(values, axis=1) 
# If T.get_shape()[0] is None, you can replace it with tf.shape(T)[0]. 
result = tf.gather_nd(T, tf.stack((tf.range(T.get_shape()[0], 
              dtype=max_indices.dtype), 
            max_indices), 
            axis=1)) 

print(result.eval()) 

jedoch, wenn die Reihen der values und T höher sind, die Verwendung von tf.gather_nd wird ein wenig umständlich sein. Ich habe meine aktuelle Lösung unter this question veröffentlicht. Im Fall von hochdimensionalen values und T könnte es eine bessere Lösung geben.

+0

Vielen Dank, mein Herr. – Pueggel

Verwandte Themen