2016-07-15 4 views
6

Als ich dritter Ordnung Momente einer Matrix X mit N Zeilen und n Spalten berechnen, ich einsum in der Regel verwenden:Alternativen numpy einsum

M3 = sp.einsum('ij,ik,il->jkl',X,X,X) /N 

Dies funktioniert in der Regel gut, aber jetzt mit größeren Werten arbeite ich, nämlich n = 120 und N = 100000 und einsum liefert folgende Fehler:

ValueError: iterator is too large

die alternative von 3 verschachtelte Schleifen tun, ist undurchführbar, so dass ich Ich frage mich, ob es irgendeine Alternative gibt.

Antwort

4

Beachten Sie, dass diese Berechnung müssen mindestens ~ n × N = 173 Milliarden Rechenoperationen (ohne Berücksichtigung von Symmetrie), dies zu tun, es wird langsam sein, es sei denn numpy Zugriff auf GPU oder etwas hat. Auf einem modernen Computer mit einer ~ 3 GHz-CPU wird erwartet, dass die gesamte Berechnung etwa 60 Sekunden dauert, wobei keine SIMD/Parallel-Beschleunigung angenommen wird.


Für die Prüfung, lassen Sie uns mit N start = 1000. Wir diese verwenden Korrektheit und Leistung zu überprüfen:

#!/usr/bin/env python3 

import numpy 
import time 

numpy.random.seed(0) 

n = 120 
N = 1000 
X = numpy.random.random((N, n)) 

start_time = time.time() 

M3 = numpy.einsum('ij,ik,il->jkl', X, X, X) 

end_time = time.time() 

print('check:', M3[2,4,6], '= 125.401852515?') 
print('check:', M3[4,2,6], '= 125.401852515?') 
print('check:', M3[6,4,2], '= 125.401852515?') 
print('check:', numpy.sum(M3), '= 218028826.631?') 
print('total time =', end_time - start_time) 

Dieser Vorgang dauert ca. 8 Sekunden. Dies ist die Grundlinie.

Beginnen sie mit der 3 verschachtelten Schleife als Alternative zu starten:

M3 = numpy.zeros((n, n, n)) 
for j in range(n): 
    for k in range(n): 
     for l in range(n): 
      M3[j,k,l] = numpy.sum(X[:,j] * X[:,k] * X[:,l]) 
# ~27 seconds 

Dies dauert etwa eine halbe Minute, nicht gut! Ein Grund dafür ist, dass es sich tatsächlich um vier verschachtelte Schleifen handelt: numpy.sum kann auch als Schleife betrachtet werden.

Wir nehmen zur Kenntnis, dass die Summe in ein Punktprodukt umgewandelt werden kann diese vierte Schleife zu entfernen:

M3 = numpy.zeros((n, n, n)) 
for j in range(n): 
    for k in range(n): 
     for l in range(n): 
      M3[j,k,l] = X[:,j] * X[:,k] @ X[:,l] 
# 14 seconds 

Viel besser jetzt aber immer noch langsam. Aber wir beachten Sie, dass das das Punktprodukt in eine Matrixmultiplikation geändert werden kann eine Schleife zu entfernen:

M3 = numpy.zeros((n, n, n)) 
for j in range(n): 
    for k in range(n): 
     M3[j,k] = X[:,j] * X[:,k] @ X 
# ~0.5 seconds 

Huh? Jetzt ist das sogar viel effizienter als einsum! Wir könnten auch überprüfen, ob die Antwort in der Tat richtig ist.

Können wir weiter gehen? Ja! Wir könnten die k Schleife beseitigen, indem sie:

M3 = numpy.zeros((n, n, n)) 
for j in range(n): 
    Y = numpy.repeat(X[:,j], n).reshape((N, n)) 
    M3[j] = (Y * X).T @ X 
# ~0.3 seconds 

wir auch Rundfunk (dh a * [b,c] == [a*b, a*c] für jede Zeile von X) verwenden, könnte die numpy.repeat (dank @Divakar) zu tun zu vermeiden:

M3 = numpy.zeros((n, n, n)) 
for j in range(n): 
    Y = X[:,j].reshape((N, 1)) 
    ## or, equivalently: 
    # Y = X[:, numpy.newaxis, j] 
    M3[j] = (Y * X).T @ X 
# ~0.16 seconds 

Wenn wir maßstabs dies zu N = 100000 das Programm wird voraussichtlich 16 Sekunden dauern, was innerhalb der theoretischen Grenze ist, so dass die Beseitigung der kann nicht zu viel helfen (aber das kann den Code wirklich schwer zu verstehen). Wir könnten dies als endgültige Lösung akzeptieren.


Hinweis: Wenn Sie Python 2 verwenden, ist a @ b-a.dot(b) gleichwertig.

+0

große antwort, danke! –

+0

Tolle Idee wirklich. Wenn ich hier ein bisschen Broadcast hinzufügen könnte, könnten wir es vermeiden, 'Y' zu erzeugen und direkt die iterative Ausgabe zu erhalten:' (X [:, None, j] * X) .T @ X'. Das sollte uns einen weiteren Leistungsschub geben. – Divakar

+0

@Divakar: Danke! Aktualisiert. – kennytm