2016-03-18 7 views
6

Ich habe einen Tensor logits mit den Dimensionen [batch_size, num_rows, num_coordinates] (d.h. jeder Logit in der Charge ist eine Matrix). In meinem Fall ist die Stapelgröße 2, es gibt 4 Zeilen und 4 Koordinaten.Wie wählt man Zeilen aus einem 3D-Tensor in TensorFlow?

logits = tf.constant([[[10.0, 10.0, 20.0, 20.0], 
         [11.0, 10.0, 10.0, 30.0], 
         [12.0, 10.0, 10.0, 20.0], 
         [13.0, 10.0, 10.0, 20.0]], 
        [[14.0, 11.0, 21.0, 31.0], 
         [15.0, 11.0, 11.0, 21.0], 
         [16.0, 11.0, 11.0, 21.0], 
         [17.0, 11.0, 11.0, 21.0]]]) 

Ich mag die erste und die zweite Zeile der ersten Partie und die zweiten und vierten Zeile der zweiten Charge auszuwählen.

indices = tf.constant([[0, 1], [1, 3]]) 

So würde die gewünschte Ausgabe

logits = tf.constant([[[10.0, 10.0, 20.0, 20.0], 
         [11.0, 10.0, 10.0, 30.0]], 
        [[15.0, 11.0, 11.0, 21.0], 
         [17.0, 11.0, 11.0, 21.0]]]) 

sein Wie kann ich tun dies mit TensorFlow? Ich habe versucht, tf.gather(logits, indices) zu verwenden, aber es gab nicht zurück, was ich erwartete. Vielen Dank!

Antwort

7

Dies ist in TensorFlow möglich, aber etwas unbequem, da tf.gather() derzeit nur mit eindimensionalen Indizes funktioniert und nur Schichten aus der 0ten Dimension eines Tensors auswählt. Allerdings ist es immer noch möglich ist, Ihr Problem effizient zu lösen, indem Sie die Argumente verwandeln, so dass sie zu tf.gather() weitergegeben werden können:

logits = ... # [2 x 4 x 4] tensor 
indices = tf.constant([[0, 1], [1, 3]]) 

# Use tf.shape() to make this work with dynamic shapes. 
batch_size = tf.shape(logits)[0] 
rows_per_batch = tf.shape(logits)[1] 
indices_per_batch = tf.shape(indices)[1] 

# Offset to add to each row in indices. We use `tf.expand_dims()` to make 
# this broadcast appropriately. 
offset = tf.expand_dims(tf.range(0, batch_size) * rows_per_batch, 1) 

# Convert indices and logits into appropriate form for `tf.gather()`. 
flattened_indices = tf.reshape(indices + offset, [-1]) 
flattened_logits = tf.reshape(logits, tf.concat(0, [[-1], tf.shape(logits)[2:]])) 

selected_rows = tf.gather(flattened_logits, flattened_indices) 

result = tf.reshape(selected_rows, 
        tf.concat(0, [tf.pack([batch_size, indices_per_batch]), 
            tf.shape(logits)[2:]])) 

Bitte beachte, dass, da diese verwenden tf.reshape() und nicht tf.transpose(), es braucht nicht zu ändern die (potenziell großen) Daten in der logits Tensor, so sollte es ziemlich effizient sein.

+0

Während Ihre Antwort ist großartig, ich denke, heute kann es mit 'tf.gather_nd' ersetzt werden, die zum Zeitpunkt Ihres Schreibens wahrscheinlich noch nicht verfügbar war (siehe meine Antwort) – kaufmanu

4

mrry Antwort ist groß, aber ich denke, mit der Funktion tf.gather_nd das Problem mit viel weniger Codezeilen gelöst werden kann (wahrscheinlich diese Funktion zum Zeitpunkt der mrry des Schreibens noch nicht verfügbar war):

logits = tf.constant([[[10.0, 10.0, 20.0, 20.0], 
         [11.0, 10.0, 10.0, 30.0], 
         [12.0, 10.0, 10.0, 20.0], 
         [13.0, 10.0, 10.0, 20.0]], 
        [[14.0, 11.0, 21.0, 31.0], 
         [15.0, 11.0, 11.0, 21.0], 
         [16.0, 11.0, 11.0, 21.0], 
         [17.0, 11.0, 11.0, 21.0]]]) 

indices = tf.constant([[[0, 0], [0, 1]], [[1, 1], [1, 3]]]) 

result = tf.gather_nd(logits, indices) 
with tf.Session() as sess: 
    print(sess.run(result)) 

Dies druckt

[[[ 10. 10. 20. 20.] 
    [ 11. 10. 10. 30.]] 

[[ 15. 11. 11. 21.] 
    [ 17. 11. 11. 21.]]] 

tf.gather_nd sollte ab v0.10 verfügbar sein. Weitere Informationen hierzu finden Sie unter this github issue.

+0

Wie hast du Indizes von 2D in 3d geändert (wie in Frage gestellt)? – Tulsi

+0

@Tulsi Ich verstehe deine Frage nicht. 3D-Indizes werden in der Frage nicht erwähnt, oder? – kaufmanu

Verwandte Themen