2017-08-30 3 views
2

Ich muss ein mehrdimensionales Array nach den Werten im ersten Sub-Array so schnell wie möglich sortieren (die Zeile wird millionenfach angewendet).Mehrdimensionales Array schneller Sortierung

Unten ist meine ursprüngliche Linie, und mein Versuch, seine Leistung zu verbessern, die nicht funktioniert. Soweit ich das sehen kann, sortiert mein numpy Ansatz nur das erste Sub-Array und keines der verbleibenden korrekt.

Was mache ich falsch und wie kann ich die Sortierleistung verbessern?

import numpy as np 

# Generate some random data. 
# I receive the actual data as a list, hence the .tolist() 
aa = np.random.rand(10, 2000).tolist() 

# This is the original line I need to process faster. 
b1 = zip(*sorted(zip(*aa), key=lambda x: x[0])) 

# This is my attempt at improving the above line's performance 
b2 = np.sort(np.asarray(aa).T, axis=0).T 

# Check if all sub-arrays are equal 
for a, b in zip(*[b1, b2]): 
    print(np.array_equal(a, b)) 
+1

Auf Anhieb können Sie versuchen, 'lambda x: x [0]' durch 'operator.itemgetter (0)' zu ersetzen. – chepner

+0

Danke, ich werde das jetzt versuchen. Aber warum funktioniert der 'numpy' Ansatz nicht? Was mache ich falsch? – Gabriel

Antwort

4

noch ein Neuling, wenn es um lambdas kommt, sondern von dem, was wenig ich aus dem Code zu verstehen - Sie in Ihrem lambda Methode scheint, verwenden Sie x[0] die Sortierschlüssel zu bekommen und dann diejenigen mit Werten ziehen aus jedes Element in aa. In NumPy-Begriffen bedeutet dies, dass die Sortierindizes für die erste Zeile in der Array-Version abgerufen und dann in jeder Zeile indexiert werden (da jedes Element von aa zu jeder Zeile des Arrays a wird). Das ist im Grunde eine Spalten-Indizierung. Außerdem scheint es sorted Ordnung für identische Elemente aufrechtzuerhalten. Also müssen wir argsort(kind='mergesort') verwenden.

So können wir einfach tun -

a[:, a[0].argsort(kind='mergesort')] # a = np.array(aa) 

In Ihrem NumPy Code, tun Sie nichts dieser Art, also nicht die richtigen Ergebnisse.

+0

Danke! Das ist ~ 20x schneller als mein ursprünglicher Ansatz. Kannst du erklären, was ich falsch mache, wenn ich 'numpy' benutze und transponiere? So kann ich aus meinen Fehlern lernen :) – Gabriel

+1

@Gabriel Sehen Sie, ob die Änderungen sinnvoll sind. – Divakar

+0

Vielen Dank für die Erklärung Divakar! – Gabriel