2017-03-16 7 views
0

Kann jemand herausfinden, wie man eine Reihe von Einbettungen ergreift?Eine Liste dynamischer Indizes sammeln

Ich habe einige Code, der die Wahrscheinlichkeit eines jeden Index prognostiziert, und wählt dann die max:

# U is batch_size x max_sentence_length x embedding_size 
scores_per_index = find_start_preds(U ...) # batch_size x max_sentence_length x 1 
start_preds = tf.argmax(alpha, axis=1) # batch_size x 1 

Ich möchte, wenn möglich, das Wort Einbettungen mit jeder Vorhersage beginnen assoziiert wieder greifen. Ist das möglich? Das ist, was ich denke, aber es funktioniert nicht :(

u_s = U[:, start_preds, :] 

Antwort

1

Sie sollten tf.gather für verwenden können, was Sie wollen, aber es funktioniert nur auf den führenden Indizes, so dass Sie neu ordnen müssen:

U2 = tf.transpose(U, [1, 0, 2]) 
u_s = tf.gather(U2, start_preds) 
u_s = tf.transpose(u_s, [1, 0, 2]) 
Verwandte Themen