Ich habe einen Tensor logits
mit den Dimensionen [batch_size, num_rows, num_coordinates]
(d.h. jeder Logit in der Charge ist eine Matrix). In meinem Fall ist die Stapelgröße 2, es gibt 4 Zeilen und 4 Koordinaten.Wie wählt man Zeilen aus einem 3D-Tensor in TensorFlow?
logits = tf.constant([[[10.0, 10.0, 20.0, 20.0],
[11.0, 10.0, 10.0, 30.0],
[12.0, 10.0, 10.0, 20.0],
[13.0, 10.0, 10.0, 20.0]],
[[14.0, 11.0, 21.0, 31.0],
[15.0, 11.0, 11.0, 21.0],
[16.0, 11.0, 11.0, 21.0],
[17.0, 11.0, 11.0, 21.0]]])
Ich mag die erste und die zweite Zeile der ersten Partie und die zweiten und vierten Zeile der zweiten Charge auszuwählen.
indices = tf.constant([[0, 1], [1, 3]])
So würde die gewünschte Ausgabe
logits = tf.constant([[[10.0, 10.0, 20.0, 20.0],
[11.0, 10.0, 10.0, 30.0]],
[[15.0, 11.0, 11.0, 21.0],
[17.0, 11.0, 11.0, 21.0]]])
sein Wie kann ich tun dies mit TensorFlow? Ich habe versucht, tf.gather(logits, indices)
zu verwenden, aber es gab nicht zurück, was ich erwartete. Vielen Dank!
Während Ihre Antwort ist großartig, ich denke, heute kann es mit 'tf.gather_nd' ersetzt werden, die zum Zeitpunkt Ihres Schreibens wahrscheinlich noch nicht verfügbar war (siehe meine Antwort) – kaufmanu