Beachten Sie, dass diese Berechnung müssen mindestens ~ n × N = 173 Milliarden Rechenoperationen (ohne Berücksichtigung von Symmetrie), dies zu tun, es wird langsam sein, es sei denn numpy Zugriff auf GPU oder etwas hat. Auf einem modernen Computer mit einer ~ 3 GHz-CPU wird erwartet, dass die gesamte Berechnung etwa 60 Sekunden dauert, wobei keine SIMD/Parallel-Beschleunigung angenommen wird.
Für die Prüfung, lassen Sie uns mit N start = 1000. Wir diese verwenden Korrektheit und Leistung zu überprüfen:
#!/usr/bin/env python3
import numpy
import time
numpy.random.seed(0)
n = 120
N = 1000
X = numpy.random.random((N, n))
start_time = time.time()
M3 = numpy.einsum('ij,ik,il->jkl', X, X, X)
end_time = time.time()
print('check:', M3[2,4,6], '= 125.401852515?')
print('check:', M3[4,2,6], '= 125.401852515?')
print('check:', M3[6,4,2], '= 125.401852515?')
print('check:', numpy.sum(M3), '= 218028826.631?')
print('total time =', end_time - start_time)
Dieser Vorgang dauert ca. 8 Sekunden. Dies ist die Grundlinie.
Beginnen sie mit der 3 verschachtelten Schleife als Alternative zu starten:
M3 = numpy.zeros((n, n, n))
for j in range(n):
for k in range(n):
for l in range(n):
M3[j,k,l] = numpy.sum(X[:,j] * X[:,k] * X[:,l])
# ~27 seconds
Dies dauert etwa eine halbe Minute, nicht gut! Ein Grund dafür ist, dass es sich tatsächlich um vier verschachtelte Schleifen handelt: numpy.sum
kann auch als Schleife betrachtet werden.
Wir nehmen zur Kenntnis, dass die Summe in ein Punktprodukt umgewandelt werden kann diese vierte Schleife zu entfernen:
M3 = numpy.zeros((n, n, n))
for j in range(n):
for k in range(n):
for l in range(n):
M3[j,k,l] = X[:,j] * X[:,k] @ X[:,l]
# 14 seconds
Viel besser jetzt aber immer noch langsam. Aber wir beachten Sie, dass das das Punktprodukt in eine Matrixmultiplikation geändert werden kann eine Schleife zu entfernen:
M3 = numpy.zeros((n, n, n))
for j in range(n):
for k in range(n):
M3[j,k] = X[:,j] * X[:,k] @ X
# ~0.5 seconds
Huh? Jetzt ist das sogar viel effizienter als einsum
! Wir könnten auch überprüfen, ob die Antwort in der Tat richtig ist.
Können wir weiter gehen? Ja! Wir könnten die k
Schleife beseitigen, indem sie:
M3 = numpy.zeros((n, n, n))
for j in range(n):
Y = numpy.repeat(X[:,j], n).reshape((N, n))
M3[j] = (Y * X).T @ X
# ~0.3 seconds
wir auch Rundfunk (dh a * [b,c] == [a*b, a*c]
für jede Zeile von X) verwenden, könnte die numpy.repeat
(dank @Divakar) zu tun zu vermeiden:
M3 = numpy.zeros((n, n, n))
for j in range(n):
Y = X[:,j].reshape((N, 1))
## or, equivalently:
# Y = X[:, numpy.newaxis, j]
M3[j] = (Y * X).T @ X
# ~0.16 seconds
Wenn wir maßstabs dies zu N = 100000 das Programm wird voraussichtlich 16 Sekunden dauern, was innerhalb der theoretischen Grenze ist, so dass die Beseitigung der kann nicht zu viel helfen (aber das kann den Code wirklich schwer zu verstehen). Wir könnten dies als endgültige Lösung akzeptieren.
Hinweis: Wenn Sie Python 2 verwenden, ist a @ b
-a.dot(b)
gleichwertig.
große antwort, danke! –
Tolle Idee wirklich. Wenn ich hier ein bisschen Broadcast hinzufügen könnte, könnten wir es vermeiden, 'Y' zu erzeugen und direkt die iterative Ausgabe zu erhalten:' (X [:, None, j] * X) .T @ X'. Das sollte uns einen weiteren Leistungsschub geben. – Divakar
@Divakar: Danke! Aktualisiert. – kennytm