2016-08-13 9 views
0

Ich versuche, einen Gradienten-Abstieg-Algorithmus zu implementieren, der zuvor in Matlab in Python mit numpy geschrieben wurde, aber ich bekomme eine Reihe von ähnlichen, aber unterschiedlichen Ergebnissen.Verschiedene Ergebnisse in numpy vs Matlab

Hier ist der MATLAB-Code

function [theta] = gradientDescentMulti(X, y, theta, alpha, num_iters) 

m = length(y); 
num_features = size(X,2); 
for iter = 1:num_iters; 
    temp_theta = theta; 
    for i = 1:num_features 
     temp_theta(i) = theta(i)-((alpha/m)*(X * theta - y)'*X(:,i)); 
    end 
    theta = temp_theta; 
end 


end 

und meine Python-Version

def gradient_descent(X,y, alpha, trials): 

    m = X.shape[0] 
    n = X.shape[1] 
    theta = np.zeros((n, 1)) 

    for i in range(trials): 

     temp_theta = theta 
     for p in range(n): 
      thetaX = np.dot(X, theta) 
      tMinY = thetaX-y 
      temp_theta[p] = temp_theta[p]-(alpha/m)*np.dot(tMinY.T, X[:,p:p+1]) 

     theta = temp_theta 

    return theta 

Testfall und die Ergebnisse in Matlab

X = [1 2 1 3; 1 7 1 9; 1 1 8 1; 1 3 7 4] 
y = [2 ; 5 ; 5 ; 6]; 
[theta] = gradientDescentMulti(X, y, zeros(4,1), 0.01, 1); 

theta = 

    0.0450 
    0.1550 
    0.2225 
    0.2000 

Testfall und führen in Python

test_X = np.array([[1,2,1,3],[1,7,1,9],[1,1,8,1],[1,3,7,4]]) 
test_y = np.array([[2], [5], [5], [6]]) 
theta, cost = gradient_descent(test_X, test_y, 0.01, 1) 
print theta 
>>[[ 0.045  ] 
    [ 0.1535375 ] 
    [ 0.20600144] 
    [ 0.14189214]] 
+0

@Kartik "MATLAB Ergebnisse möglicherweise falsch" ist wirklich kein hilfreicher Vorschlag, ohne einen detaillierten Grund. – dbliss

+0

Mein Kommentar wurde missverstanden. Ich habe meine Erfahrungen mit Ihnen geteilt, und ich habe vorgeschlagen, dass Sie eine andere Software verwenden könnten, um dies zu beheben, wenn Sie können. (Ich bin mir bewusst, dass Sie möglicherweise keinen Zugang zu einer anderen Software haben.) Warum MATLAB-Ergebnisse falsch waren, als ich ein einfaches Histogramm ausprobierte, hatte ich zu diesem Zeitpunkt noch nicht herausgefunden. Ich machte es für die verschlossene Quelle von MATLAB verantwortlich und vermutete, dass etwas durch ihre Tests gerutscht war, und fuhr mit Python fort, das ich viel mehr "zu Hause" benutzte. – Kartik

Antwort

7

Diese Zeile in Python:

temp_theta = theta 

nicht tut, was Sie denken, es tut. Es erstellt keine Kopie von theta und "weist" es der "Variablen" temp_theta zu - es sagt nur "temp_theta ist jetzt ein neuer Name für das Objekt, das derzeit von theta benannt wird".

Also, wenn Sie temp_theta hier ändern:

 temp_theta[p] = temp_theta[p]-(alpha/m)*np.dot(tMinY.T, X[:,p:p+1]) 

Sie tatsächlich theta modifizieren - denn es gibt nur die eine Reihe, jetzt mit zwei Namen.

Wenn Sie stattdessen

temp_theta = theta.copy() 

schreiben werden Sie so etwas wie

(3.5) [email protected]:~/coding$ python peter.py 
[[ 0.045 ] 
[ 0.155 ] 
[ 0.2225] 
[ 0.2 ]] 

erhalten, die Ihre Matlab Ergebnisse übereinstimmt.

Verwandte Themen