import tensorflow as tf
x = tf.constant([[1, 0, 2, 3, 4], [2, 3, 4, 4, 4], [2, 3, 4, 5, 4]])
cond = tf.cast(tf.equal(x, 4), tf.int8)
idx4_ = tf.reshape(tf.argmax(cond, axis=1, output_type=tf.int32), (-1,1))
anzuwenden
Optional, wenn alle Zeilen mindestens einen Wert gleich 4:
idx4 = tf.where(
tf.equal(tf.reduce_max(cond, axis=1, keep_dims=True), 1),
idx4_,
tf.constant(-1, shape=idx4_.shape)
)
erstellen Sie die Maske durch den Index des ersten 4 mit einem 1D-Entfernungs-Index zu vergleichen:
mask = idx4 >= tf.range(x.shape[1])
with tf.Session() as sess:
print(sess.run(mask))
#[[ True True True True True]
# [ True True True False False]
# [ True True True False False]]
Nutzen sequence_mask
:
import tensorflow as tf
x = tf.constant([[1, 0, 2, 3, 4], [2, 3, 4, 4, 4], [2, 3, 4, 5, 4]])
cond = tf.cast(tf.equal(x, 4), tf.int8)
idx4_ = tf.argmax(cond, axis=1, output_type=tf.int32)
idx4 = tf.where(
tf.equal(tf.reduce_max(cond, axis=1), 1),
idx4_,
tf.constant(-1, shape=idx4_.shape)
)
with tf.Session() as sess:
print(sess.run(tf.sequence_mask(idx4+1, x.shape[1])))
#[[ True True True True True]
# [ True True True False False]
# [ True True True False False]]
Wenn x ist ein Platzhalter mit unbekannter Form vor der Hand:
import tensorflow as tf
x = tf.placeholder(tf.int32, shape=[None,None])
cond = tf.cast(tf.equal(x, 4), tf.int8)
idx4_ = tf.argmax(cond, axis=1, output_type=tf.int32)
idx4 = tf.where(
tf.equal(tf.reduce_max(cond, axis=1), 1),
idx4_,
tf.fill(tf.shape(idx4_), -1)
)
mask = tf.sequence_mask(idx4+1, tf.shape(x)[-1])
with tf.Session() as sess:
print(sess.run(mask, {x: [[1, 0, 2, 3, 4], [2, 3, 4, 4, 4], [2, 3, 4, 5, 4]]}))
#[[ True True True True True]
# [ True True True False False]
# [ True True True False False]]
So eine Batch-Form ist [BATCH_SIZE, 3, 5] al Wege? und was meinst du mit den ersten 4, um wahr zu sein? Ich sehe das nicht in der resultierenden Grafik? –
Die Batch-Größe ist wie im Beispiel angegeben festgelegt. Alles vor den ersten 4 (einschließlich der ersten 4), die gesehen werden, sollte als True maskiert werden. 4 ist der Elementwert in der Charge, da meine Charge Zahlen enthält – user3669481