2016-10-08 12 views
1

Als Teil eines Projekts, an dem ich gerade arbeite, muss ich den mittleren quadratischen Fehler zwischen 2m Vektoren berechnen.Speicherfehler bei der Matrixmultiplikation

Im Grunde habe ich zwei Matrizen x und xhat sind beide Größe m von n und die Vektoren mich interessiert sind die Reihen dieser Vektoren.

ich die MSE mit diesem Code

def cost(x, xhat): #mean squared error between x the data and xhat the output of the machine 
    return (1.0/(2 * m)) * np.trace(np.dot(x-xhat,(x-xhat).T)) 

Es funktioniert korrekt zu berechnen, ist diese Formel korrekt. Das Problem ist, dass in meinem speziellen Fall meine m und n sehr groß sind. speziell, m = 60000 und n = 785. Wenn ich also meinen Code starte und diese Funktion eingibt, erhalte ich einen Speicherfehler.

Gibt es eine bessere Möglichkeit, den MSE zu berechnen? Ich würde lieber Schleifen vermeiden und mich stark auf die Matrixmultiplikation konzentrieren, aber Matrixmultiplikation scheint hier extrem verschwenderisch. Vielleicht etwas in der Anzahl, die mir nicht bewusst ist?

+2

können Sie versuchen, die Möglichkeiten, hier vorgeschlagen: http: // Stackoverflow. com/questions/16774849/mean-squared-Fehler-in-numpy – jerry

+3

dotting zwei Arrays und dann nur mit der Spur ist unmöglich verschwenderisch – Julius

Antwort

5

Der Ausdruck np.dot(x-xhat,(x-xhat).T) erstellt ein Array mit Form (m, m). Sie sagen, m ist 60000, so dass das Array fast 29 Gigabyte ist.

Sie nehmen die Spur des Arrays, die nur die Summe der diagonalen Elemente ist, so dass der größte Teil dieses riesigen Arrays nicht verwendet wird. Wenn Sie sich genau die np.trace(np.dot(x-xhat,(x-xhat).T)) ansehen, werden Sie sehen, dass es nur die Summe der Quadrate aller Elemente von x - xhat ist. So eine einfachere Möglichkeit, np.trace(np.dot(x-xhat,(x-xhat).T)) zu berechnen, die das riesige Zwischenarray nicht benötigt, ist ((x - xhat)**2).sum(). Zum Beispiel

In [44]: x 
Out[44]: 
array([[ 0.87167186, 0.96838389, 0.72545457], 
     [ 0.05803253, 0.57355625, 0.12732163], 
     [ 0.00874702, 0.01555692, 0.76742386], 
     [ 0.4130838 , 0.89307633, 0.49532327], 
     [ 0.15929044, 0.27025289, 0.75999848]]) 

In [45]: xhat 
Out[45]: 
array([[ 0.20825392, 0.63991699, 0.28896932], 
     [ 0.67658621, 0.64919721, 0.31624655], 
     [ 0.39460861, 0.33057769, 0.24542263], 
     [ 0.10694332, 0.28030777, 0.53177585], 
     [ 0.21066692, 0.53096774, 0.65551612]]) 

In [46]: np.trace(np.dot(x-xhat,(x-xhat).T)) 
Out[46]: 2.2352330441581061 

In [47]: ((x - xhat)**2).sum() 
Out[47]: 2.2352330441581061 

Weitere Vorstellungen über die MSE Berechnung finden Sie in die link provided by user1984065 in einem Kommentar.

1

Wenn Sie Leistung gehen auch, ein alternativer Ansatz Summe der quadrierten Differenzen zu berechnen, könnte np.einsum verwenden, wie so -

subs = x-xhat 
out = np.einsum('ij,ij',subs,subs) 
Verwandte Themen