2014-12-15 40 views
5

Ich sehe nicht, was mit meinem Code für regulierte lineare Regression falsch ist. Unreglementierten ich einfach diese, die ich ziemlich sicher bin, ist richtig:Numpy lineare Regression mit Regularisierung

import numpy as np 

def get_model(features, labels): 
    return np.linalg.pinv(features).dot(labels) 

Hier für eine gesetzlich geregelte Lösung mein Code ist, wo ich sehen, ich nicht, was es falsch ist:

def get_model(features, labels, lamb=0.0): 
    n_cols = features.shape[1] 
    return linalg.inv(features.transpose().dot(features) + lamb * np.identity(n_cols))\ 
      .dot(features.transpose()).dot(labels) 

Mit der Standardwert von 0.0 für lamm, meine Absicht ist, dass es das gleiche Ergebnis wie die (korrekte) nicht regulierte Version geben sollte, aber der Unterschied ist tatsächlich ziemlich groß.

Kann jemand sehen, was das Problem ist?

+0

tun Regularisierung Ich fange an, und würde eine lineare Regressionslinie eine Kurve erzeugen Regularisierung? – duldi

+1

Nein. Sie erhalten immer noch lineare Koeffizienten. Die Regularisierung ändert nur die Steigung. –

Antwort

6

Das Problem ist:

features.transpose().dot(features) nicht umkehrbar sein kann. Und numpy.linalg.inv funktioniert nur für die vollständige Rangliste gemäß den Dokumenten. Ein Regularisierungsausdruck (ungleich Null) macht die Gleichung jedoch immer nonsingulär.

Übrigens, Sie haben Recht mit der Implementierung. Aber es ist nicht effizient. Ein effizienter Weg, diese Gleichung zu lösen, ist die Methode der kleinsten Quadrate.

np.linalg.lstsq(features, labels) kann die Arbeit für np.linalg.pinv(features).dot(labels) tun.

In einer allgemeinen Weise können Sie diese

def get_model(A, y, lamb=0): 
    n_col = A.shape[1] 
    return np.linalg.lstsq(A.T.dot(A) + lamb * np.identity(n_col), A.T.dot(y)) 
+0

Wenn Sie np.linalg.lstsq() verwenden, wie passen Sie in den Regularisierungsbegriff 'Lamm'? –

+0

bearbeite meine Antwort. – nullas

+0

Das funktioniert gut! Vielen Dank. Ich endete mit 'np.linalg.lstsq (...) [0]' weil ich sonst ein Tupel zurückbekam. Wissen Sie auch, warum 'lstsq()' leistungsfähiger ist? –

Verwandte Themen