2013-07-08 7 views
5

Ich versuche herauszufinden, wie man eine Python-Funktion beschleunigt, die numpy verwendet. Die Ausgabe, die ich von lineprofiler erhalten habe, ist unten, und dies zeigt, dass die meiste Zeit auf der Linie ind_y, ind_x = np.where(seg_image == i) verbracht wird.Beschleunigen Sie numpy.where zum Extrahieren von Integer-Segmenten?

seg_image ist ein Integer-Array, das das Ergebnis der Segmentierung eines Bildes ist, und so die Pixel findet, in denen seg_image == i ein bestimmtes segmentiertes Objekt extrahiert. Ich schlängele mich durch viele dieser Objekte (im Code unten durchlaufe ich nur 5 Tests, aber ich werde tatsächlich mehr als 20.000 durchlaufen), und es dauert sehr lange!

Gibt es eine Möglichkeit, wie der np.where Anruf beschleunigt werden kann? Oder alternativ, dass die vorletzte Zeile (die auch einen großen Teil der Zeit benötigt) beschleunigt werden kann?

Die ideale Lösung wäre, den Code auf dem gesamten Array auf einmal auszuführen, anstatt Schleifen, aber ich glaube nicht, dass dies möglich ist, da einige der Funktionen, die ich ausführen muss (z Beispielsweise kann das Aufteilen eines segmentierten Objekts dazu führen, dass es mit der nächsten Region kollidiert und somit später falsche Ergebnisse liefert.

Hat jemand irgendwelche Ideen?

Line #  Hits   Time Per Hit % Time Line Contents 
============================================================== 
    5           def correct_hot(hot_image, seg_image): 
    6   1  239810 239810.0  2.3  new_hot = hot_image.copy() 
    7   1  572966 572966.0  5.5  sign = np.zeros_like(hot_image) + 1 
    8   1  67565 67565.0  0.6  sign[:,:] = 1 
    9   1  1257867 1257867.0  12.1  sign[hot_image > 0] = -1 
    10           
    11   1   150 150.0  0.0  s_elem = np.ones((3, 3)) 
    12           
    13            #for i in xrange(1,seg_image.max()+1): 
    14   6   57  9.5  0.0  for i in range(1,6): 
    15   5  6092775 1218555.0  58.5   ind_y, ind_x = np.where(seg_image == i) 
    16           
    17             # Get the average HOT value of the object (really simple!) 
    18   5   2408 481.6  0.0   obj_avg = hot_image[ind_y, ind_x].mean() 
    19           
    20   5   333  66.6  0.0   miny = np.min(ind_y) 
    21             
    22   5   162  32.4  0.0   minx = np.min(ind_x) 
    23             
    24           
    25   5   369  73.8  0.0   new_ind_x = ind_x - minx + 3 
    26   5   113  22.6  0.0   new_ind_y = ind_y - miny + 3 
    27           
    28   5   211  42.2  0.0   maxy = np.max(new_ind_y) 
    29   5   143  28.6  0.0   maxx = np.max(new_ind_x) 
    30           
    31             # 7 is + 1 to deal with the zero-based indexing, + 2 * 3 to deal with the 3 cell padding above 
    32   5   217  43.4  0.0   obj = np.zeros((maxy+7, maxx+7)) 
    33           
    34   5   158  31.6  0.0   obj[new_ind_y, new_ind_x] = 1 
    35           
    36   5   2482 496.4  0.0   dilated = ndimage.binary_dilation(obj, s_elem) 
    37   5   1370 274.0  0.0   border = mahotas.borders(dilated) 
    38           
    39   5   122  24.4  0.0   border = np.logical_and(border, dilated) 
    40           
    41   5   355  71.0  0.0   border_ind_y, border_ind_x = np.where(border == 1) 
    42   5   136  27.2  0.0   border_ind_y = border_ind_y + miny - 3 
    43   5   123  24.6  0.0   border_ind_x = border_ind_x + minx - 3 
    44           
    45   5   645 129.0  0.0   border_avg = hot_image[border_ind_y, border_ind_x].mean() 
    46           
    47   5  2167729 433545.8  20.8   new_hot[seg_image == i] = (new_hot[ind_y, ind_x] + (sign[ind_y, ind_x] * np.abs(obj_avg - border_avg))) 
    48   5  10179 2035.8  0.1   print obj_avg, border_avg 
    49           
    50   1   4  4.0  0.0  return new_hot 

Antwort

4

EDIT Ich habe meine ursprüngliche Antwort am unteren Rand der Ordnung halber links, aber ich habe tatsächlich in Ihren Code im Detail über das Mittagessen sah, und ich denke, dass np.where mit ein großer Fehler:

In [63]: a = np.random.randint(100, size=(1000, 1000)) 

In [64]: %timeit a == 42 
1000 loops, best of 3: 950 us per loop 

In [65]: %timeit np.where(a == 42) 
100 loops, best of 3: 7.55 ms per loop 

Sie könnten ein boolesches Array (das Sie für die Indizierung verwenden können) in 1/8 der Zeit erhalten, die Sie benötigen, um die tatsächlichen Koordinaten der Punkte zu erhalten !!!

Es gibt natürlich das Zuschneiden der Features, die Sie tun, aber ndimage hat eine find_objects Funktion, die umschließende Scheiben zurückgibt und sehr schnell zu sein scheint:

In [66]: %timeit ndimage.find_objects(a) 
100 loops, best of 3: 11.5 ms per loop 

Diese eine Liste von Tupeln von Scheiben zurück Einschließen alle Ihrer Objekte, in 50% mehr Zeit, die es braucht, um die Indizes ein einzelnes Objekt zu finden.

Es kann nicht aus der Box arbeitet, wie ich es jetzt nicht testen können, aber ich würde Ihren Code in etwa wie folgt neu strukturieren:

def correct_hot_bis(hot_image, seg_image): 
    # Need this to not index out of bounds when computing border_avg 
    hot_image_padded = np.pad(hot_image, 3, mode='constant', 
           constant_values=0) 
    new_hot = hot_image.copy() 
    sign = np.ones_like(hot_image, dtype=np.int8) 
    sign[hot_image > 0] = -1 
    s_elem = np.ones((3, 3)) 

    for j, slice_ in enumerate(ndimage.find_objects(seg_image)): 
     hot_image_view = hot_image[slice_] 
     seg_image_view = seg_image[slice_] 
     new_shape = tuple(dim+6 for dim in hot_image_view.shape) 
     new_slice = tuple(slice(dim.start, 
           dim.stop+6, 
           None) for dim in slice_) 
     indices = seg_image_view == j+1 

     obj_avg = hot_image_view[indices].mean() 

     obj = np.zeros(new_shape) 
     obj[3:-3, 3:-3][indices] = True 

     dilated = ndimage.binary_dilation(obj, s_elem) 
     border = mahotas.borders(dilated) 
     border &= dilated 

     border_avg = hot_image_padded[new_slice][border == 1].mean() 

     new_hot[slice_][indices] += (sign[slice_][indices] * 
            np.abs(obj_avg - border_avg)) 

    return new_hot 

Sie würden nach wie vor die Notwendigkeit, um herauszufinden, Kollisionen, aber man konnte durch die Berechnung alle Indizes über ein 2-facher Geschwindigkeit-up erhalten gleichzeitig einen np.unique basierten Ansatz:

a = np.random.randint(100, size=(1000, 1000)) 

def get_pos(arr): 
    pos = [] 
    for j in xrange(100): 
     pos.append(np.where(arr == j)) 
    return pos 

def get_pos_bis(arr): 
    unq, flat_idx = np.unique(arr, return_inverse=True) 
    pos = np.argsort(flat_idx) 
    counts = np.bincount(flat_idx) 
    cum_counts = np.cumsum(counts) 
    multi_dim_idx = np.unravel_index(pos, arr.shape) 
    return zip(*(np.split(coords, cum_counts) for coords in multi_dim_idx)) 

In [33]: %timeit get_pos(a) 
1 loops, best of 3: 766 ms per loop 

In [34]: %timeit get_pos_bis(a) 
1 loops, best of 3: 388 ms per loop 

Beachten Sie, dass die Pixel für jedes o bject werden in einer anderen Reihenfolge zurückgegeben, sodass Sie nicht einfach die Rückgabewerte beider Funktionen vergleichen können, um die Gleichheit zu beurteilen. Aber sie sollten beide dasselbe zurückgeben.

+0

Das ist wunderbar, fantastisch und erstaunlich - danke! Das erste Mal, als ich es ausführte, stellte ich fest, dass es tatsächlich langsamer war als mein ursprünglicher Code, aber dann habe ich etwas von deinem Code so geändert, dass er die ganze Arbeit (Dilation, Borders usw.) in einem kleinen Array anstatt dem riesigen Array erledigt hat. durch Modifizieren, wie die neue_Form berechnet wurde. Ich habe jetzt eine enorme Geschwindigkeitssteigerung erlebt. Bei einem der Bilder, mit denen ich arbeite, dauerte die alte Version zweieinhalb Stunden, die neue dauerte 11 Sekunden! – robintw

+0

Hoppla! Ja, es sieht so aus, als müsste der Generatorausdruck 'new_shape = tuple (dim + 6 für dim in hot_image_view.shape)' sein und nicht 'new_shape = tuple (dim + 6 für dim in hot_image.shape)' '. Hast du dich verändert? Bitte, fühlen Sie sich frei, meine Antwort zu bearbeiten, um den Arbeitscode widerzuspiegeln. – Jaime

2

Eine Sache, die Sie gleichen ein wenig Zeit tun könnte, ist das Ergebnis der seg_image == i zu speichern, damit Sie es nicht zweimal berechnen müssen. Sie berechnen es in Zeilen 15 & 47, Sie könnten seg_mask = seg_image == i hinzufügen und dann dieses Ergebnis wiederverwenden (Es könnte auch gut sein, dieses Stück für Profiling-Zwecke zu trennen).

Während Sie noch ein paar andere Dinge tun können, um ein wenig Leistung zu erzielen, liegt das Hauptproblem darin, dass Sie einen O (M * N) Algorithmus verwenden, wobei M die Anzahl der Segmente und N ist ist die Größe Ihres Bildes. Es ist für mich aus Ihrem Code nicht offensichtlich, ob es einen schnelleren Algorithmus gibt, um dasselbe zu erreichen, aber das ist der erste Ort, an dem ich versuchen würde, nach einer Beschleunigung zu suchen.