2017-11-20 1 views
1

Mein Code:Erhalten Indizes N höchsten Werte in numpy Array

import numpy as np 
N = 2 
a = np.array([[0.5, 0.3, 0.2], 
       [0.2, 0.6, 0.2], 
       [0.3, 0.2, 0.7], 
       [np.nan, 0.2, 0.8],      
       [np.nan, np.nan, 0.8]      
       ]) 

ind = np.argsort(np.where(np.isnan(a), -1, a), axis=1)[:, -N:] 


a 
Out[2]: 
array([[ 0.5, 0.3, 0.2], 
     [ 0.2, 0.6, 0.2], 
     [ 0.3, 0.2, 0.7], 
     [ nan, 0.2, 0.8], 
     [ nan, nan, 0.8]]) 

ind 
Out[3]: 
array([[1, 0], 
     [2, 1], 
     [0, 2], 
     [1, 2], 
     [1, 2]], dtype=int64) 

ind [:, 1] die höchste und ind [: 0] zweithöchste

Welche außer in dem Fall in Ordnung mit 2 Nans in der letzten Reihe. Wie kann der zweithöchste Wert ignoriert werden, wenn er nicht groß ist? gewünschte Ausgabe wäre:

array([[1, 0], 
     [2, 1], 
     [0, 2], 
     [1, 2], 
     [nan, 2]], dtype=int64) 

Bonus Frage: Wie zufällig eine Krawatte im Falle eines [1 ,:] brechen?

Antwort

1

Advanced-index und überprüfen NaNs uns eine Maske zu geben, die dann mit np.where verwendet werden, um die Wahl zu tun, wie so -

In [244]: a_ind = a[np.arange(ind.shape[0])[:,None],ind] 

In [245]: mask = np.isnan(a_ind) 

In [246]: np.where(mask, np.nan, ind) 
Out[246]: 
array([[ 1., 0.], 
     [ 2., 1.], 
     [ 0., 2.], 
     [ 1., 2.], 
     [ nan, 2.]]) 

Beachten Sie, dass ein Array NaN zu float dtype umgewandelt würden haben daher wäre die endgültige Ausgabe auch von float dtype.

Verwandte Themen